"""Test utilities."""
#
# (C) Pywikibot team, 2013-2024
#
# Distributed under the terms of the MIT license.
#
from __future__ import annotations
import inspect
import os
import sys
import unittest
import warnings
from contextlib import contextmanager, suppress
from subprocess import PIPE, Popen, TimeoutExpired
from typing import Any
import pywikibot
from pywikibot import config
from pywikibot.backports import Sequence
from pywikibot.data.api import CachedRequest
from pywikibot.data.api import Request as _original_Request
from pywikibot.exceptions import APIError
from pywikibot.login import LoginStatus
from pywikibot.site import Namespace
from pywikibot.tools import PYTHON_VERSION
from pywikibot.tools.collections import EMPTY_DEFAULT
from tests import _pwb_py
OSWIN32 = (sys.platform == 'win32')
[docs]
def expected_failure_if(expect):
"""Unit test decorator to expect failure under conditions.
:param expect: Flag to check if failure is expected
:type expect: bool
"""
if expect:
return unittest.expectedFailure
return lambda orig: orig
[docs]
def fixed_generator(iterable):
"""Return a dummy generator ignoring all parameters.
This can be used to overwrite a generator method and yield
predefined items:
>>> from tests.utils import fixed_generator
>>> site = pywikibot.Site()
>>> page = pywikibot.Page(site, 'Any page')
>>> list(page.linkedPages(total=1))
[]
>>> gen = fixed_generator([
... pywikibot.Page(site, 'User:BobBot/Redir'),
... pywikibot.Page(site, 'Main Page')])
>>> page.linkedPages = gen
>>> list(page.linkedPages(total=1))
[Page('Benutzer:BobBot/Redir'), Page('Main Page')]
"""
def gen(*args, **kwargs):
yield from iterable
return gen
[docs]
def entered_loop(iterable):
"""Return True if iterable contains items."""
for _ in iterable:
return True
return False
[docs]
class WarningSourceSkipContextManager(warnings.catch_warnings):
"""Warning context manager that adjusts source of warning.
The source of the warning will be moved further down the
stack to skip a list of objects that have been monkey
patched into the call stack.
"""
def __init__(self, skip_list):
"""Initializer.
:param skip_list: List of objects to be skipped. The source of any
warning that matches the skip_list won't be adjusted.
:type skip_list: list of object or (obj, str, int, int)
"""
super().__init__(record=True)
self.skip_list = skip_list
@property
def skip_list(self):
"""Return list of filename and line ranges to skip.
:rtype: list of (obj, str, int, int)
"""
return self._skip_list
@skip_list.setter
def skip_list(self, value):
"""Set list of objects to be skipped.
:param value: List of objects to be skipped
:type value: list of object or (obj, str, int, int)
"""
self._skip_list = []
for item in value:
if isinstance(item, tuple):
self._skip_list.append(item)
else:
filename = inspect.getsourcefile(item)
code, first_line = inspect.getsourcelines(item)
last_line = first_line + len(code)
self._skip_list.append(
(item, filename, first_line, last_line))
def __enter__(self):
"""Enter the context manager."""
def detailed_show_warning(*args, **kwargs):
"""Replacement handler for warnings.showwarning."""
warn_msg = warnings.WarningMessage(*args, **kwargs)
skip_frames = 0
a_frame_has_matched_warn_msg = False
# The following for-loop will adjust the warn_msg only if the
# warning does not match the skip_list.
for _, frame_filename, frame_lineno, *_ in inspect.stack():
if any(start <= frame_lineno <= end
for (_, skip_filename, start, end) in self.skip_list
if skip_filename == frame_filename):
# this frame matches to one of the items in the skip_list
if a_frame_has_matched_warn_msg:
continue
skip_frames += 1
if frame_filename == warn_msg.filename \
and frame_lineno == warn_msg.lineno:
if not skip_frames:
break
a_frame_has_matched_warn_msg = True
if a_frame_has_matched_warn_msg:
if not skip_frames:
# adjust the warn_msg
warn_msg.filename = frame_filename
warn_msg.lineno = frame_lineno
break
skip_frames -= 1
# Ignore socket IO warnings (T183696, T184996)
if issubclass(warn_msg.category, ResourceWarning) \
and str(warn_msg.message).startswith(
('unclosed <ssl.SSLSocket', 'unclosed <socket.socket')):
return
log.append(warn_msg)
log = super().__enter__()
self._module.showwarning = detailed_show_warning
return log
[docs]
class AssertAPIErrorContextManager:
"""Context manager to assert certain APIError exceptions.
This is build similar to the :py:obj:`unittest.TestCase.assertError`
implementation which creates a context manager. It then calls
:py:obj:`handle` which either returns this manager if no executing
object given or calls the callable object.
"""
def __init__(self, code, info, msg, test_case, regex=None):
"""Create instance expecting the code and info."""
self.code = code
self.info = info
self.msg = msg
self.test_case = test_case
self.regex = regex
def __enter__(self):
"""Enter this context manager and the unittest's context manager."""
if self.regex:
self.cm = self.test_case.assertRaisesRegex(APIError, self.regex,
msg=self.msg)
else:
self.cm = self.test_case.assertRaises(APIError, msg=self.msg)
self.cm.__enter__()
return self.cm
def __exit__(self, exc_type, exc_value, tb):
"""Exit the context manager and assert code and optionally info."""
result = self.cm.__exit__(exc_type, exc_value, tb)
assert result is isinstance(exc_value, APIError)
if result:
self.test_case.assertEqual(exc_value.code, self.code)
if self.info:
self.test_case.assertEqual(exc_value.info, self.info)
return result
[docs]
def handle(self, callable_obj, args, kwargs):
"""Handle the callable object by returning itself or using itself."""
if callable_obj is None:
return self
with self:
callable_obj(*args, **kwargs)
return None
[docs]
class DryParamInfo(dict):
"""Dummy class to use instead of :py:obj:`data.api.ParamInfo`."""
def __init__(self, *args, **kwargs):
"""Initializer."""
super().__init__(*args, **kwargs)
self.action_modules = set()
self.query_modules = set()
[docs]
def fetch(self, modules, _init=False):
"""Load dry data."""
return [self[mod] for mod in modules]
[docs]
def parameter(self, module, param_name):
"""Load dry data."""
return self[module].get(param_name)
def __getitem__(self, name):
"""Return dry data or a dummy parameter block."""
with suppress(KeyError):
return super().__getitem__(name)
result = {'name': name, 'prefix': '', 'limit': {}}
if name == 'query+templates':
result['limit'] = {'max': 1}
return result
[docs]
class DummySiteinfo:
"""Dummy Siteinfo class.
To be used instead of :class:`pywikibot.site.Siteinfo
<pywikibot.site._siteinfo.Siteinfo>`.
"""
def __init__(self, cache):
"""Initializer."""
self._cache = {key: (item, False) for key, item in cache.items()}
def __getitem__(self, key):
"""Get item."""
return self.get(key, False)
def __setitem__(self, key, value):
"""Set item."""
self._cache[key] = (value, False)
[docs]
def get(self, key, get_default=True, cache=True, expiry=False):
"""Return dry data."""
# Default values are always expired, so only expiry=False doesn't force
# a reload
force = expiry is not False
if not force and key in self._cache:
loaded = self._cache[key]
if not loaded[1] and not get_default:
raise KeyError(key)
return loaded[0]
if get_default:
default = EMPTY_DEFAULT
if cache:
self._cache[key] = (default, False)
return default
raise KeyError(key)
def __contains__(self, key):
"""Return False."""
return False
[docs]
def is_cached(self, key: str) -> bool:
"""Return whether the key is cached.
.. versionadded:: 8.3
"""
return key in self._cache
[docs]
def is_recognised(self, key):
"""Return None."""
return
[docs]
def get_requested_time(self, key):
"""Return False."""
return False
[docs]
class DryRequest(CachedRequest):
"""Dummy class to use instead of :py:obj:`data.api.Request`."""
def __init__(self, *args, **kwargs):
"""Initializer."""
_original_Request.__init__(self, *args, **kwargs)
[docs]
@classmethod
def create_simple(cls, req_site, **kwargs):
"""Skip CachedRequest implementation."""
return _original_Request.create_simple(req_site, **kwargs)
def _expired(self, dt):
"""Never invalidate cached data."""
return False
def _write_cache(self, data) -> None:
"""Never write data but just do nothing."""
[docs]
def submit(self):
"""Prevented method."""
raise Exception(f'DryRequest rejecting request: {self._params!r}')
[docs]
class DrySite(pywikibot.site.APISite):
"""Dummy class to use instead of :py:obj:`pywikibot.site.APISite`."""
_loginstatus = LoginStatus.NOT_ATTEMPTED
def __init__(self, code, fam, user):
"""Initializer."""
super().__init__(code, fam, user)
self._userinfo = pywikibot.tools.collections.EMPTY_DEFAULT
self._paraminfo = DryParamInfo()
# setup a default siteinfo used by dry tests
default_siteinfo = {'fileextensions': [{'ext': 'jpg'}]}
self._siteinfo = DummySiteinfo(default_siteinfo)
self._siteinfo._cache['lang'] = (code, True)
self._siteinfo._cache['case'] = (
'case-sensitive' if self.family.name == 'wiktionary' else
'first-letter', True)
self._siteinfo._cache['mainpage'] = 'Main Page'
extensions = []
if self.family.name == 'wikisource':
extensions.append({'name': 'ProofreadPage'})
self._siteinfo._cache['extensions'] = (extensions, True)
# TODO: Not all follow that scheme (e.g. "BrokenRedirects")
aliases = [{'realname': alias.capitalize(), 'aliases': [alias]}
for alias in ('PrefixIndex', )]
self._siteinfo._cache['specialpagealiases'] = (aliases, True)
self._msgcache = {'*': 'dummy entry', 'hello': 'world'}
def _build_namespaces(self):
ns_dict = Namespace.builtin_namespaces(case=self.siteinfo['case'])
if hasattr(self.family, 'authornamespaces'):
assert len(self.family.authornamespaces[self.code]) <= 1
if self.family.authornamespaces[self.code]:
author_ns = self.family.authornamespaces[self.code][0]
assert author_ns not in ns_dict
ns_dict[author_ns] = Namespace(
author_ns, 'Author', case=self.siteinfo['case'])
return ns_dict
[docs]
def linktrail(self):
"""Return default linkrail."""
return '[a-z]*'
@property
def userinfo(self):
"""Return dry data."""
return self._userinfo
[docs]
def version(self):
"""Return a big dummy version string."""
return '999.999'
[docs]
def image_repository(self):
"""Return Site object for image repository e.g. commons."""
code, fam = self.shared_image_repository()
if code or fam:
return pywikibot.Site(code, fam, self.username(),
interface=self.__class__)
return None
[docs]
def data_repository(self):
"""Return Site object for data repository e.g. Wikidata."""
if self.hostname().endswith('.beta.wmflabs.org'):
# TODO: Use definition for beta cluster's wikidata
code, fam = None, None
fam_name = self.hostname().split('.')[-4]
else:
code, fam = 'wikidata', 'wikidata'
fam_name = self.family.name
# Only let through valid entries
if fam_name not in ('commons', 'wikibooks', 'wikidata', 'wikinews',
'wikipedia', 'wikiquote', 'wikisource',
'wikivoyage'):
code, fam = None, None
if code or fam:
return pywikibot.Site(code, fam, self.username(),
interface=DryDataSite)
return None
[docs]
def login(self, *args, cookie_only=False, **kwargs):
"""Overwrite login which is called when a site is initialized.
.. versionadded:: 8.0.4
"""
if cookie_only:
return
raise Exception(f'Attempting to login with {type(self).__name__}')
[docs]
class DryDataSite(DrySite, pywikibot.site.DataSite):
"""Dummy class to use instead of :py:obj:`pywikibot.site.DataSite`."""
def _build_namespaces(self):
namespaces = super()._build_namespaces()
namespaces[0].defaultcontentmodel = 'wikibase-item'
namespaces[120] = Namespace(id=120,
case='first-letter',
canonical_name='Property',
defaultcontentmodel='wikibase-property')
return namespaces
[docs]
class DryPage(pywikibot.Page):
"""Dummy class that acts like a Page but avoids network activity."""
_pageid = 1
_disambig = False
_isredir = False
[docs]
def isDisambig(self): # noqa: N802
"""Return disambig status stored in _disambig."""
return self._disambig
[docs]
class FakeLoginManager(pywikibot.login.ClientLoginManager):
"""Loads a fake password."""
@property
def password(self):
"""Get the fake password."""
return 'foo'
@password.setter
def password(self, value):
"""Ignore password changes."""
[docs]
def execute(command: list[str], *, data_in=None, timeout=None):
"""Execute a command and capture outputs.
.. versionchanged:: 8.2
*error* parameter was removed.
.. versionchanged:: 9.1
parameters except *command* are keyword only.
:param command: executable to run and arguments to use
"""
if PYTHON_VERSION < (3, 8):
command.insert(1, '-W ignore::FutureWarning:pywikibot:97')
env = os.environ.copy()
# Prevent output by test package; e.g. 'max_retries reduced from x to y'
env['PYWIKIBOT_TEST_QUIET'] = '1'
# sys.path may have been modified by the test runner to load dependencies.
pythonpath = os.pathsep.join(sys.path)
env['PYTHONPATH'] = pythonpath
env['PYTHONIOENCODING'] = config.console_encoding
# PYWIKIBOT_USERINTERFACE_LANG will be assigned to
# config.userinterface_lang
if pywikibot.config.userinterface_lang:
env['PYWIKIBOT_USERINTERFACE_LANG'] \
= pywikibot.config.userinterface_lang
# Set EDITOR to an executable that ignores all arguments and does nothing.
env['EDITOR'] = 'break' if OSWIN32 else 'true'
p = Popen(command, env=env, stdout=PIPE, stderr=PIPE,
stdin=PIPE if data_in is not None else None)
if data_in is not None:
data_in = data_in.encode(config.console_encoding)
try:
stdout_data, stderr_data = p.communicate(input=data_in,
timeout=timeout)
except TimeoutExpired:
p.kill()
stdout_data, stderr_data = p.communicate()
return {'exit_code': p.returncode,
'stdout': stdout_data.decode(config.console_encoding),
'stderr': stderr_data.decode(config.console_encoding)}
[docs]
def execute_pwb(args: list[str], *,
data_in: Sequence[str] | None = None,
timeout: int | float | None = None,
overrides: dict[str, str] | None = None) -> dict[str, Any]:
"""Execute the pwb.py script and capture outputs.
.. versionchanged:: 8.2
the *error* parameter was removed.
.. versionchanged:: 9.1
parameters except *args* are keyword only.
:param args: list of arguments for pwb.py
:param overrides: mapping of pywikibot symbols to test replacements
"""
command = [sys.executable]
if overrides:
command.append('-c')
overrides = '; '.join(
f'{key} = {value}' for key, value in overrides.items())
command.append(
f'import pwb; import pywikibot; {overrides}; pwb.main()')
else:
command.append(_pwb_py)
return execute(command=command + args, data_in=data_in, timeout=timeout)
[docs]
@contextmanager
def empty_sites():
"""Empty pywikibot _sites and _code_fam_from_url cache on entry point."""
pywikibot._sites = {}
pywikibot._code_fam_from_url.cache_clear()
yield
[docs]
@contextmanager
def skipping(*exceptions: BaseException,
msg: str | None = None,
code: str | None = None):
"""Context manager to skip test on specified exceptions.
For example Eventstreams raises ``NotImplementedError`` if no
``streams`` parameter was given. Skip the following tests in that
case::
with skipping(NotImplementedError):
self.es = comms.eventstreams.EventStreams(streams=None)
self.assertIsInstance(self.es, tools.collections.GeneratorWrapper)
The exception message is used for the ``SkipTest`` reason. To use a
custom message, add a ``msg`` parameter::
with skipping(AssertionError, msg='T304786'):
self.assertEqual(self.get_mainpage().oldest_revision.text, text)
Multiple context expressions may also be used::
with (
skipping(OtherPageSaveError),
self.assertRaisesRegex(SpamblacklistError, 'badsite.com'),
):
page.save()
.. note:: The last sample uses Python 3.10 syntax.
.. versionadded:: 6.2
.. versionchanged:: 9.3
*code* parameter was added
:param exceptions: Exceptions to let test skip
:param msg: Optional skipping reason
:param code: if *exceptions* is a single :exc:`APIError` you can
specify the :attr:`APIError.code` for the right match to be
skipped.
"""
try:
yield
except exceptions as e:
if code:
if len(exceptions) != 1 \
or not hasattr(e, 'code') \
or e.code != code:
raise
msg = e.info
if msg is None:
msg = e
raise unittest.SkipTest(msg)