# Source code for tests.utils

"""Test utilities."""
# (C) Pywikibot team, 2013-2024
# Distributed under the terms of the MIT license.
from __future__ import annotations

import inspect
import os
import sys
import unittest
import warnings
from contextlib import contextmanager, suppress
from subprocess import PIPE, Popen, TimeoutExpired
from typing import Any

import pywikibot
from pywikibot import config
from pywikibot.backports import Sequence
from pywikibot.data.api import CachedRequest
from pywikibot.data.api import Request as _original_Request
from pywikibot.exceptions import APIError
from pywikibot.login import LoginStatus
from pywikibot.site import Namespace
from pywikibot.tools.collections import EMPTY_DEFAULT
from tests import _pwb_py

# True when the tests run on Windows, i.e. ``sys.platform`` equals 'win32'.
OSWIN32 = (sys.platform == 'win32')

def expected_failure_if(expect):
    """Conditionally mark a unit test as an expected failure.

    :param expect: Flag telling whether the decorated test is expected
        to fail
    :type expect: bool
    :return: :py:obj:`unittest.expectedFailure` if *expect* is true,
        otherwise a decorator that returns the test unchanged
    """
    def unchanged(orig):
        return orig

    return unittest.expectedFailure if expect else unchanged
def fixed_generator(iterable):
    """Build a generator function that always yields the given items.

    The returned function accepts arbitrary positional and keyword
    arguments and ignores all of them. It can be used to overwrite a
    generator method and yield predefined items:

    >>> from tests.utils import fixed_generator
    >>> site = pywikibot.Site()
    >>> page = pywikibot.Page(site, 'Any page')
    >>> list(page.linkedPages(total=1))
    []
    >>> gen = fixed_generator([
    ...     pywikibot.Page(site, 'User:BobBot/Redir'),
    ...     pywikibot.Page(site, 'Main Page')])
    >>> page.linkedPages = gen
    >>> list(page.linkedPages(total=1))
    [Page('Benutzer:BobBot/Redir'), Page('Main Page')]

    :param iterable: items to be yielded on every call
    :return: generator function yielding the items of *iterable*
    """
    def gen(*_args, **_kwargs):
        for item in iterable:
            yield item

    return gen
def entered_loop(iterable):
    """Tell whether iterating *iterable* yields at least one item.

    At most one item is consumed from *iterable*.
    """
    return any(True for _ in iterable)
class WarningSourceSkipContextManager(warnings.catch_warnings):

    """
    Warning context manager that adjusts source of warning.

    The source of the warning will be moved further down the stack
    to skip a list of objects that have been monkey patched into the call
    stack.
    """

    def __init__(self, skip_list):
        """
        Initializer.

        The parent ``catch_warnings`` is set up with ``record=True``.

        :param skip_list: List of objects to be skipped. The source of any
            warning that matches the skip_list won't be adjusted.
        :type skip_list: list of object or (obj, str, int, int)
        """
        super().__init__(record=True)
        self.skip_list = skip_list

    @property
    def skip_list(self):
        """
        Return list of filename and line ranges to skip.

        :rtype: list of (obj, str, int, int)
        """
        return self._skip_list

    @skip_list.setter
    def skip_list(self, value):
        """
        Set list of objects to be skipped.

        Tuples are stored unchanged; any other object is expanded to
        ``(obj, source filename, first line, last line)`` using
        :py:mod:`inspect`.

        :param value: List of objects to be skipped
        :type value: list of object or (obj, str, int, int)
        """
        self._skip_list = []
        for item in value:
            if isinstance(item, tuple):
                # already in (obj, filename, first_line, last_line) form
                self._skip_list.append(item)
            else:
                filename = inspect.getsourcefile(item)
                code, first_line = inspect.getsourcelines(item)
                # first line number plus the number of source lines
                last_line = first_line + len(code)
                self._skip_list.append(
                    (item, filename, first_line, last_line))

    def __enter__(self):
        """Enter the context manager.

        Installs a replacement ``showwarning`` handler on ``self._module``
        and returns the object the warnings are appended to.
        """
        def detailed_show_warning(*args, **kwargs):
            """Replacement handler for warnings.showwarning."""
            warn_msg = warnings.WarningMessage(*args, **kwargs)

            # number of matched skip_list frames still to be stepped over
            skip_frames = 0
            # set once a frame with the warning's filename/lineno was seen
            a_frame_has_matched_warn_msg = False

            # The following for-loop will adjust the warn_msg only if the
            # warning does not match the skip_list.
            for _, frame_filename, frame_lineno, *_ in inspect.stack():
                # inclusive line range check against entries of same file
                if any(start <= frame_lineno <= end
                       for (_, skip_filename, start, end) in self.skip_list
                       if skip_filename == frame_filename):
                    # this frame matches to one of the items in the skip_list
                    if a_frame_has_matched_warn_msg:
                        continue

                    skip_frames += 1

                if frame_filename == warn_msg.filename \
                   and frame_lineno == warn_msg.lineno:
                    # no skip_list frame seen before the warning's own
                    # frame: leave warn_msg untouched
                    if not skip_frames:
                        break
                    a_frame_has_matched_warn_msg = True

                if a_frame_has_matched_warn_msg:
                    if not skip_frames:
                        # adjust the warn_msg
                        warn_msg.filename = frame_filename
                        warn_msg.lineno = frame_lineno
                        break

                    skip_frames -= 1

            # Ignore socket IO warnings (T183696, T184996)
            if issubclass(warn_msg.category, ResourceWarning) \
               and str(warn_msg.message).startswith(
                   ('unclosed <ssl.SSLSocket', 'unclosed <socket.socket')):
                return

            log.append(warn_msg)

        # ``log`` is bound before the handler can run; the closure above
        # appends the (possibly adjusted) messages to it.
        log = super().__enter__()
        self._module.showwarning = detailed_show_warning
        return log
class AssertAPIErrorContextManager:

    """
    Context manager to assert certain APIError exceptions.

    This is built similar to the :py:obj:`unittest.TestCase.assertRaises`
    implementation which creates a context manager. It then calls
    :py:obj:`handle` which either returns this manager if no executing
    object given or calls the callable object.
    """

    def __init__(self, code, info, msg, test_case, regex=None):
        """Create instance expecting the code and info.

        :param code: expected ``code`` attribute of the raised APIError
        :param info: expected ``info`` attribute of the raised APIError;
            a false value disables the info check
        :param msg: failure message handed to the unittest assertion
        :param test_case: test case providing the ``assert*`` methods
        :param regex: optional pattern; if given ``assertRaisesRegex`` is
            used instead of ``assertRaises``
        """
        self.code = code
        self.info = info
        self.msg = msg
        self.test_case = test_case
        self.regex = regex

    def __enter__(self):
        """Enter this context manager and the unittest's context manager."""
        if self.regex:
            self.cm = self.test_case.assertRaisesRegex(APIError, self.regex,
                                                       msg=self.msg)
        else:
            self.cm = self.test_case.assertRaises(APIError, msg=self.msg)
        return self.cm.__enter__()

    def __exit__(self, exc_type, exc_value, tb):
        """Exit the context manager and assert code and optionally info.

        :return: result of the wrapped unittest context manager's
            ``__exit__``
        """
        result = self.cm.__exit__(exc_type, exc_value, tb)
        # the unittest manager swallows the exception exactly when it is
        # an APIError
        assert result is isinstance(exc_value, APIError)
        if result:
            self.test_case.assertEqual(exc_value.code, self.code)
            if self.info:
                self.test_case.assertEqual(exc_value.info, self.info)
        return result

    def handle(self, callable_obj, args, kwargs):
        """Handle the callable object by returning itself or using itself.

        :param callable_obj: callable to run inside this context manager
            or None
        :param args: positional arguments for *callable_obj*
        :param kwargs: keyword arguments for *callable_obj*
        :return: this manager if *callable_obj* is None, otherwise None
        """
        if callable_obj is None:
            return self
        with self:
            callable_obj(*args, **kwargs)
        return None
class DryParamInfo(dict):

    """Dummy class to use instead of :py:obj:`data.api.ParamInfo`."""

    def __init__(self, *args, **kwargs):
        """Create the mapping with empty action and query module sets."""
        super().__init__(*args, **kwargs)
        self.query_modules = set()
        self.action_modules = set()

    def fetch(self, modules, _init=False):
        """Return the dry parameter blocks of all given modules.

        :param modules: iterable of module names
        :return: list with one block per module, in the given order
        """
        return [self[module_name] for module_name in modules]

    def parameter(self, module, param_name):
        """Return entry *param_name* of a module's block or None."""
        block = self[module]
        return block.get(param_name)

    def __getitem__(self, name):
        """Return dry data or a dummy parameter block.

        Unknown names are not stored; a fresh dummy block is built for
        them on every access.
        """
        try:
            return super().__getitem__(name)
        except KeyError:
            pass

        limit = {'max': 1} if name == 'query+templates' else {}
        return {'name': name, 'prefix': '', 'limit': limit}
class DummySiteinfo:

    """Dummy Siteinfo class.

    To be used in dry tests instead of the real siteinfo container of a
    site. Entries are kept in ``_cache`` as ``(value, flag)`` pairs.
    """

    def __init__(self, cache):
        """Initializer.

        :param cache: mapping of keys to values; every value is stored
            with a ``False`` flag
        """
        self._cache = {name: (value, False) for name, value in cache.items()}

    def __getitem__(self, key):
        """Get item; never fall back to a default value."""
        return self.get(key, get_default=False)

    def __setitem__(self, key, value):
        """Set item, stored with a ``False`` flag."""
        self._cache[key] = (value, False)

    def get(self, key, get_default=True, cache=True, expiry=False):
        """Return dry data.

        :param key: name of the entry
        :param get_default: return ``EMPTY_DEFAULT`` instead of raising
            if no usable entry is found
        :param cache: store the default value in the cache
        :param expiry: anything but ``False`` bypasses the cache
        :raises KeyError: no usable entry and *get_default* is false
        """
        # Default values are always expired, so only expiry=False doesn't
        # force a reload
        use_cache = expiry is False
        if use_cache and key in self._cache:
            entry = self._cache[key]
            if entry[1] or get_default:
                return entry[0]
            raise KeyError(key)

        if not get_default:
            raise KeyError(key)

        if cache:
            self._cache[key] = (EMPTY_DEFAULT, False)
        return EMPTY_DEFAULT

    def __contains__(self, key):
        """Return False."""
        return False

    def is_cached(self, key: str) -> bool:
        """Return whether the key is cached.

        .. versionadded:: 8.3
        """
        return key in self._cache

    def is_recognised(self, key):
        """Return None."""
        return None

    def get_requested_time(self, key):
        """Return False."""
        return False
class DryRequest(CachedRequest):

    """Dummy class to use instead of :py:obj:`data.api.Request`.

    Submitting such a request always raises an exception.
    """

    def __init__(self, *args, **kwargs):
        """Initializer.

        All arguments are handed to the initializer of the original
        Request class, which is called directly instead of ``super()``.
        """
        _original_Request.__init__(self, *args, **kwargs)

    @classmethod
    def create_simple(cls, req_site, **kwargs):
        """Skip CachedRequest implementation.

        Delegates to ``create_simple`` of the original Request class.
        """
        return _original_Request.create_simple(req_site, **kwargs)

    def _expired(self, dt):
        """Never invalidate cached data."""
        return False

    def _write_cache(self, data) -> None:
        """Never write data but just do nothing."""

    def submit(self):
        """Prevented method.

        :raises Exception: always; the message shows ``self._params``
        """
        raise Exception(f'DryRequest rejecting request: {self._params!r}')
class DrySite(pywikibot.site.APISite):

    """Dummy class to use instead of :py:obj:`pywikibot.site.APISite`.

    The site is set up with dry parameter info, a dummy siteinfo and a
    small message cache; :meth:`login` refuses real logins.
    """

    _loginstatus = LoginStatus.NOT_ATTEMPTED

    def __init__(self, code, fam, user):
        """Initializer.

        :param code: site code; also stored as siteinfo ``lang``
        :param fam: family, passed on to the parent initializer
        :param user: user name, passed on to the parent initializer
        """
        super().__init__(code, fam, user)
        self._userinfo = EMPTY_DEFAULT
        self._paraminfo = DryParamInfo()
        # setup a default siteinfo used by dry tests
        default_siteinfo = {'fileextensions': [{'ext': 'jpg'}]}
        self._siteinfo = DummySiteinfo(default_siteinfo)
        # DummySiteinfo.get() indexes cache entries as (value, flag) pairs;
        # a true flag makes the entry available through item access too.
        self._siteinfo._cache['lang'] = (code, True)
        self._siteinfo._cache['case'] = (
            'case-sensitive' if self.family.name == 'wiktionary' else
            'first-letter', True)
        # Stored as a pair like its siblings; a bare string would be
        # indexed character-wise and yield 'M' as value.
        self._siteinfo._cache['mainpage'] = ('Main Page', True)
        extensions = []
        if self.family.name == 'wikisource':
            extensions.append({'name': 'ProofreadPage'})
        self._siteinfo._cache['extensions'] = (extensions, True)
        # TODO: Not all follow that scheme (e.g. "BrokenRedirects")
        aliases = [{'realname': alias.capitalize(), 'aliases': [alias]}
                   for alias in ('PrefixIndex', )]
        self._siteinfo._cache['specialpagealiases'] = (aliases, True)
        self._msgcache = {'*': 'dummy entry', 'hello': 'world'}

    def _build_namespaces(self):
        """Return builtin namespaces plus the family's author namespace.

        The author namespace is only added if the family provides an
        ``authornamespaces`` attribute with an entry for this site code.
        """
        ns_dict = Namespace.builtin_namespaces(case=self.siteinfo['case'])
        if hasattr(self.family, 'authornamespaces'):
            assert len(self.family.authornamespaces[self.code]) <= 1
            if self.family.authornamespaces[self.code]:
                author_ns = self.family.authornamespaces[self.code][0]
                assert author_ns not in ns_dict
                ns_dict[author_ns] = Namespace(
                    author_ns, 'Author', case=self.siteinfo['case'])
        return ns_dict

    def linktrail(self):
        """Return default linktrail."""
        return '[a-z]*'

    @property
    def userinfo(self):
        """Return dry data."""
        return self._userinfo

    def version(self):
        """Return a big dummy version string."""
        return '999.999'

    def image_repository(self):
        """Return Site object for image repository e.g. commons.

        :return: a site created with this class as interface, or None if
            ``shared_image_repository()`` gives neither code nor family
        """
        code, fam = self.shared_image_repository()
        if code or fam:
            return pywikibot.Site(code, fam, self.username(),
                                  interface=self.__class__)
        return None

    def data_repository(self):
        """Return Site object for data repository e.g. Wikidata.

        :return: a :class:`DryDataSite` based site, or None for beta
            cluster hosts and for families not in the list below
        """
        # NOTE(review): host suffix reconstructed from the split index
        # below and the beta cluster TODO; confirm against upstream.
        if self.hostname().endswith('.beta.wmflabs.org'):
            # TODO: Use definition for beta cluster's wikidata
            code, fam = None, None
            fam_name = self.hostname().split('.')[-4]
        else:
            code, fam = 'wikidata', 'wikidata'
            fam_name = self.family.name

        # Only let through valid entries
        if fam_name not in ('commons', 'wikibooks', 'wikidata', 'wikinews',
                            'wikipedia', 'wikiquote', 'wikisource',
                            'wikivoyage'):
            code, fam = None, None

        if code or fam:
            return pywikibot.Site(code, fam, self.username(),
                                  interface=DryDataSite)
        return None

    def login(self, *args, cookie_only=False, **kwargs):
        """Overwrite login which is called when a site is initialized.

        .. versionadded:: 8.0.4

        :param cookie_only: if true, return silently without logging in
        :raises Exception: a real login was attempted
        """
        if cookie_only:
            return
        raise Exception(f'Attempting to login with {type(self).__name__}')
class DryDataSite(DrySite, pywikibot.site.DataSite):

    """Dummy class to use instead of :py:obj:`pywikibot.site.DataSite`."""

    def _build_namespaces(self):
        """Return the namespaces of :class:`DrySite` adapted for Wikibase.

        Namespace 0 gets the ``wikibase-item`` content model and a
        ``Property`` namespace with id 120 is added.
        """
        namespaces = super()._build_namespaces()
        namespaces[0].defaultcontentmodel = 'wikibase-item'
        namespaces[120] = Namespace(id=120,
                                    case='first-letter',
                                    canonical_name='Property',
                                    defaultcontentmodel='wikibase-property')
        return namespaces
class DryPage(pywikibot.Page):

    """Dummy class that acts like a Page but avoids network activity."""

    # Preset class-level values; presumably these are the attributes a
    # real Page would otherwise load from the site -- TODO confirm.
    _pageid = 1
    _disambig = False
    _isredir = False

    def isDisambig(self):  # noqa: N802
        """Return disambig status stored in _disambig."""
        return self._disambig
class FakeLoginManager(pywikibot.login.ClientLoginManager):

    """Loads a fake password.

    The ``password`` property always yields ``'foo'`` and assignments to
    it are discarded.
    """

    @property
    def password(self):
        """Get the fake password."""
        return 'foo'

    @password.setter
    def password(self, value):
        """Ignore password changes."""
def execute(command: list[str], *, data_in=None, timeout=None):
    """Run a command in a subprocess and collect its results.

    .. versionchanged:: 8.2
       *error* parameter was removed.
    .. versionchanged:: 9.1
       parameters except *command* are keyword only.

    :param command: executable to run and arguments to use
    :param data_in: optional text sent to the process' stdin
    :param timeout: optional timeout for the process; it is killed when
        the timeout expires
    :return: dict with the keys ``exit_code``, ``stdout`` and ``stderr``
    """
    env = os.environ.copy()
    env.update({
        # Prevent output by test package; e.g. 'max_retries reduced from x
        # to y'
        'PYWIKIBOT_TEST_QUIET': '1',
        # sys.path may have been modified by the test runner to load
        # dependencies.
        'PYTHONPATH': os.pathsep.join(sys.path),
        'PYTHONIOENCODING': config.console_encoding,
    })

    # PYWIKIBOT_USERINTERFACE_LANG will be assigned to
    # config.userinterface_lang
    ui_lang = pywikibot.config.userinterface_lang
    if ui_lang:
        env['PYWIKIBOT_USERINTERFACE_LANG'] = ui_lang

    # Set EDITOR to an executable that ignores all arguments and does
    # nothing.
    env['EDITOR'] = 'break' if OSWIN32 else 'true'

    stdin = None if data_in is None else PIPE
    process = Popen(command, env=env, stdout=PIPE, stderr=PIPE, stdin=stdin)

    payload = data_in
    if payload is not None:
        payload = payload.encode(config.console_encoding)

    try:
        out, err = process.communicate(input=payload, timeout=timeout)
    except TimeoutExpired:
        # stop the process, then collect whatever output it produced
        process.kill()
        out, err = process.communicate()

    return {
        'exit_code': process.returncode,
        'stdout': out.decode(config.console_encoding),
        'stderr': err.decode(config.console_encoding),
    }
def execute_pwb(args: list[str], *,
                data_in: Sequence[str] | None = None,
                timeout: int | float | None = None,
                overrides: dict[str, str] | None = None) -> dict[str, Any]:
    """Run the pwb wrapper script and collect its results.

    Without *overrides* the script file is started directly. With
    *overrides* Python is started with ``-c`` and a statement that
    assigns the replacements before calling ``pwb.main()``.

    .. versionchanged:: 8.2
       the *error* parameter was removed.
    .. versionchanged:: 9.1
       parameters except *args* are keyword only.

    :param args: list of arguments for the pwb wrapper script
    :param data_in: passed on to :func:`execute`
    :param timeout: passed on to :func:`execute`
    :param overrides: mapping of pywikibot symbols to test replacements
    :return: the result dict of :func:`execute`
    """
    if not overrides:
        command = [sys.executable, _pwb_py]
    else:
        assignments = '; '.join(
            f'{key} = {value}' for key, value in overrides.items())
        command = [
            sys.executable,
            '-c',
            f'import pwb; import pywikibot; {assignments}; pwb.main()',
        ]

    return execute(command=command + args, data_in=data_in, timeout=timeout)
@contextmanager
def empty_sites():
    """Empty pywikibot _sites and _code_fam_from_url cache on entry point.

    ``pywikibot._sites`` is rebound to a new empty dict and the cache of
    ``pywikibot._code_fam_from_url`` is cleared before the body of the
    ``with`` statement runs. Nothing is restored when the context is left.
    """
    pywikibot._sites = {}
    pywikibot._code_fam_from_url.cache_clear()
    yield
@contextmanager
def skipping(*exceptions: type[BaseException], msg: str | None = None):
    """Context manager to skip test on specified exceptions.

    For example Eventstreams raises ``NotImplementedError`` if no
    ``streams`` parameter was given. Skip the following tests in that
    case::

        with skipping(NotImplementedError):
            self.es = comms.eventstreams.EventStreams(streams=None)
        self.assertIsInstance(self.es, tools.collections.GeneratorWrapper)

    The exception message is used for the ``SkipTest`` reason. To use a
    custom message, add a ``msg`` parameter::

        with skipping(AssertionError, msg='T304786'):
            self.assertEqual(self.get_mainpage().oldest_revision.text, text)

    Multiple context expressions may also be used::

        with (
            skipping(OtherPageSaveError),
            self.assertRaisesRegex(SpamblacklistError, 'badsite.com'),
        ):
            page.save()

    .. note:: The last sample uses Python 3.10 syntax.

    .. versionadded:: 6.2

    :param msg: Optional skipping reason
    :param exceptions: Exception classes to let test skip
    :raises unittest.SkipTest: one of *exceptions* was raised inside the
        context; the original exception is chained as its cause
    """
    try:
        yield
    except exceptions as e:
        # the caught exception itself serves as reason if none was given
        reason = e if msg is None else msg
        raise unittest.SkipTest(reason) from e