diff --git a/ConfigTools.py b/ConfigTools.py
new file mode 100644
index 0000000..2eabcb6
--- /dev/null
+++ b/ConfigTools.py
@@ -0,0 +1,39 @@
+import os
+
+# Canonical import path; the vendored lib/ tree in this repo contains only
+# jaraco.*, not cryptography, so 'lib.cryptography.fernet' cannot resolve.
+from cryptography.fernet import Fernet
+
+
+class CredentialManager:
+    """Manage a Fernet key stored on disk and encrypt/decrypt passwords with it."""
+
+    def __init__(self, key_file="secret_key"):
+        self.key_file = key_file
+        self.key = self._load_or_generate_key()
+        self.cipher = Fernet(self.key)
+
+    def _load_or_generate_key(self):
+        """Return the stored key bytes, generating the file on first use."""
+        if not os.path.exists(self.key_file):
+            key = Fernet.generate_key()
+            with open(self.key_file, "wb") as file:
+                file.write(key)
+            # Owner-only permissions: this key protects every stored password.
+            os.chmod(self.key_file, 0o600)
+            return key
+        with open(self.key_file, "rb") as file:
+            return file.read()
+
+    def encrypt_password(self, password):
+        """Return *password* encrypted, as a URL-safe base64 text token."""
+        return self.cipher.encrypt(password.encode()).decode()
+
+    def decrypt_password(self, encrypted_password):
+        """Reverse encrypt_password(); raises InvalidToken for a bad token/key."""
+        return self.cipher.decrypt(encrypted_password.encode()).decode()
+
+
+if __name__ == "__main__":
+    manager = CredentialManager()
+    encrypted = manager.encrypt_password("test1")
+    print(f"Encrypted: {encrypted}")
+    print(f"Decrypted: {manager.decrypt_password(encrypted)}")
diff --git a/PKGBUILD b/PKGBUILD
new file mode 100644
index 0000000..e4d15c7
--- /dev/null
+++ b/PKGBUILD
@@ -0,0 +1,22 @@
+pkgname=ServerSync
+pkgver=1.0.0
+pkgrel=1
+pkgdesc="A tool to simply manage & Sync files and directories to remotes with configs!"
+arch=('any')
+url="https://github.com/youruser/your-tool"
+license=('MIT')
+depends=('bash' 'python3') # Add runtime dependencies here (one quoted name per element)
+makedepends=('git' 'gcc') # Add build-time dependencies here
+source=("https://github.com/youruser/$pkgname/archive/v$pkgver.tar.gz")
+sha256sums=('SKIP') # We will fix this in the next step
+
+build() {
+  cd "$pkgname-$pkgver"
+  make # Or your specific build command
+}
+
+package() {
+  cd "$pkgname-$pkgver"
+  # This installs the binary to /usr/bin inside the package
+  install -Dm755 your-binary-name "$pkgdir/usr/bin/your-binary-name"
+}
diff --git a/lib/jaraco.classes-3.4.0.dist-info/INSTALLER b/lib/jaraco.classes-3.4.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/jaraco.classes-3.4.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/jaraco.classes-3.4.0.dist-info/LICENSE b/lib/jaraco.classes-3.4.0.dist-info/LICENSE
new file mode 100644
index 0000000..1bb5a44
--- /dev/null
+++ b/lib/jaraco.classes-3.4.0.dist-info/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to
+deal in the Software without restriction, including without limitation the
+rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+sell copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. diff --git a/lib/jaraco.classes-3.4.0.dist-info/METADATA b/lib/jaraco.classes-3.4.0.dist-info/METADATA new file mode 100644 index 0000000..6b11499 --- /dev/null +++ b/lib/jaraco.classes-3.4.0.dist-info/METADATA @@ -0,0 +1,60 @@ +Metadata-Version: 2.1 +Name: jaraco.classes +Version: 3.4.0 +Summary: Utility functions for Python class constructs +Home-page: https://github.com/jaraco/jaraco.classes +Author: Jason R. Coombs +Author-email: jaraco@jaraco.com +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Requires-Python: >=3.8 +License-File: LICENSE +Requires-Dist: more-itertools +Provides-Extra: docs +Requires-Dist: sphinx >=3.5 ; extra == 'docs' +Requires-Dist: jaraco.packaging >=9.3 ; extra == 'docs' +Requires-Dist: rst.linker >=1.9 ; extra == 'docs' +Requires-Dist: furo ; extra == 'docs' +Requires-Dist: sphinx-lint ; extra == 'docs' +Requires-Dist: jaraco.tidelift >=1.4 ; extra == 'docs' +Provides-Extra: testing +Requires-Dist: pytest >=6 ; extra == 'testing' +Requires-Dist: pytest-checkdocs >=2.4 ; extra == 'testing' +Requires-Dist: pytest-cov ; extra == 'testing' +Requires-Dist: pytest-mypy ; extra == 'testing' +Requires-Dist: pytest-enabler >=2.2 ; extra == 'testing' +Requires-Dist: pytest-ruff >=0.2.1 ; extra == 'testing' + +.. image:: https://img.shields.io/pypi/v/jaraco.classes.svg + :target: https://pypi.org/project/jaraco.classes + +.. image:: https://img.shields.io/pypi/pyversions/jaraco.classes.svg + +.. 
image:: https://github.com/jaraco/jaraco.classes/actions/workflows/main.yml/badge.svg + :target: https://github.com/jaraco/jaraco.classes/actions?query=workflow%3A%22tests%22 + :alt: tests + +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff + +.. image:: https://readthedocs.org/projects/jaracoclasses/badge/?version=latest + :target: https://jaracoclasses.readthedocs.io/en/latest/?badge=latest + +.. image:: https://img.shields.io/badge/skeleton-2024-informational + :target: https://blog.jaraco.com/skeleton + +.. image:: https://tidelift.com/badges/package/pypi/jaraco.classes + :target: https://tidelift.com/subscription/pkg/pypi-jaraco.classes?utm_source=pypi-jaraco.classes&utm_medium=readme + +For Enterprise +============== + +Available as part of the Tidelift Subscription. + +This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use. + +`Learn more `_. 
diff --git a/lib/jaraco.classes-3.4.0.dist-info/RECORD b/lib/jaraco.classes-3.4.0.dist-info/RECORD new file mode 100644 index 0000000..4d09383 --- /dev/null +++ b/lib/jaraco.classes-3.4.0.dist-info/RECORD @@ -0,0 +1,15 @@ +jaraco.classes-3.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +jaraco.classes-3.4.0.dist-info/LICENSE,sha256=htoPAa6uRjSKPD1GUZXcHOzN55956HdppkuNoEsqR0E,1023 +jaraco.classes-3.4.0.dist-info/METADATA,sha256=LmsQIjLt1Frhu4prQJH9QM8yAaa7b9S8l8XozXZaRLg,2623 +jaraco.classes-3.4.0.dist-info/RECORD,, +jaraco.classes-3.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92 +jaraco.classes-3.4.0.dist-info/top_level.txt,sha256=0JnN3LfXH4LIRfXL-QFOGCJzQWZO3ELx4R1d_louoQM,7 +jaraco/classes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jaraco/classes/__pycache__/__init__.cpython-314.pyc,, +jaraco/classes/__pycache__/ancestry.cpython-314.pyc,, +jaraco/classes/__pycache__/meta.cpython-314.pyc,, +jaraco/classes/__pycache__/properties.cpython-314.pyc,, +jaraco/classes/ancestry.py,sha256=FkU7kyOO-TOMgwR3obcpqB93Ht-f0yxjGnTxcvfBLB0,1787 +jaraco/classes/meta.py,sha256=uz1zmtse_0n7cs2M2hfz8iIqoe2_2vZI-_JiFvQuDwE,2198 +jaraco/classes/properties.py,sha256=f-88KCSBeeCliwxfXOwe7Uqk9_elEmi9ZSwOh6_yBq4,6191 +jaraco/classes/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/lib/jaraco.classes-3.4.0.dist-info/WHEEL b/lib/jaraco.classes-3.4.0.dist-info/WHEEL new file mode 100644 index 0000000..bab98d6 --- /dev/null +++ b/lib/jaraco.classes-3.4.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.43.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/lib/jaraco.classes-3.4.0.dist-info/top_level.txt b/lib/jaraco.classes-3.4.0.dist-info/top_level.txt new file mode 100644 index 0000000..f6205a5 --- /dev/null +++ b/lib/jaraco.classes-3.4.0.dist-info/top_level.txt @@ -0,0 +1 @@ +jaraco diff --git a/lib/jaraco/classes/__init__.py 
b/lib/jaraco/classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/jaraco/classes/__pycache__/__init__.cpython-314.pyc b/lib/jaraco/classes/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..53e4f6a Binary files /dev/null and b/lib/jaraco/classes/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jaraco/classes/__pycache__/ancestry.cpython-314.pyc b/lib/jaraco/classes/__pycache__/ancestry.cpython-314.pyc new file mode 100644 index 0000000..8611bc7 Binary files /dev/null and b/lib/jaraco/classes/__pycache__/ancestry.cpython-314.pyc differ diff --git a/lib/jaraco/classes/__pycache__/meta.cpython-314.pyc b/lib/jaraco/classes/__pycache__/meta.cpython-314.pyc new file mode 100644 index 0000000..b4e63fb Binary files /dev/null and b/lib/jaraco/classes/__pycache__/meta.cpython-314.pyc differ diff --git a/lib/jaraco/classes/__pycache__/properties.cpython-314.pyc b/lib/jaraco/classes/__pycache__/properties.cpython-314.pyc new file mode 100644 index 0000000..d344ae3 Binary files /dev/null and b/lib/jaraco/classes/__pycache__/properties.cpython-314.pyc differ diff --git a/lib/jaraco/classes/ancestry.py b/lib/jaraco/classes/ancestry.py new file mode 100644 index 0000000..5c8c5de --- /dev/null +++ b/lib/jaraco/classes/ancestry.py @@ -0,0 +1,76 @@ +""" +Routines for obtaining the class names +of an object and its parent classes. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from more_itertools import unique_everseen + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any + + +def all_bases(c: type[object]) -> list[type[Any]]: + """ + return a tuple of all base classes the class c has as a parent. 
+ >>> object in all_bases(list) + True + """ + return c.mro()[1:] + + +def all_classes(c: type[object]) -> list[type[Any]]: + """ + return a tuple of all classes to which c belongs + >>> list in all_classes(list) + True + """ + return c.mro() + + +# borrowed from +# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ + + +def iter_subclasses(cls: type[object]) -> Iterator[type[Any]]: + """ + Generator over all subclasses of a given class, in depth-first order. + + >>> bool in list(iter_subclasses(int)) + True + >>> class A(object): pass + >>> class B(A): pass + >>> class C(A): pass + >>> class D(B,C): pass + >>> class E(D): pass + >>> + >>> for cls in iter_subclasses(A): + ... print(cls.__name__) + B + D + E + C + >>> # get ALL classes currently defined + >>> res = [cls.__name__ for cls in iter_subclasses(object)] + >>> 'type' in res + True + >>> 'tuple' in res + True + >>> len(res) > 100 + True + """ + return unique_everseen(_iter_all_subclasses(cls)) + + +def _iter_all_subclasses(cls: type[object]) -> Iterator[type[Any]]: + try: + subs = cls.__subclasses__() + except TypeError: # fails only when cls is type + subs = cast('type[type]', cls).__subclasses__(cls) + for sub in subs: + yield sub + yield from iter_subclasses(sub) diff --git a/lib/jaraco/classes/meta.py b/lib/jaraco/classes/meta.py new file mode 100644 index 0000000..27d03a0 --- /dev/null +++ b/lib/jaraco/classes/meta.py @@ -0,0 +1,85 @@ +""" +meta.py + +Some useful metaclasses. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + + +class LeafClassesMeta(type): + """ + A metaclass for classes that keeps track of all of them that + aren't base classes. 
+ + >>> Parent = LeafClassesMeta('MyParentClass', (), {}) + >>> Parent in Parent._leaf_classes + True + >>> Child = LeafClassesMeta('MyChildClass', (Parent,), {}) + >>> Child in Parent._leaf_classes + True + >>> Parent in Parent._leaf_classes + False + + >>> Other = LeafClassesMeta('OtherClass', (), {}) + >>> Parent in Other._leaf_classes + False + >>> len(Other._leaf_classes) + 1 + """ + + _leaf_classes: set[type[Any]] + + def __init__( + cls, + name: str, + bases: tuple[type[object], ...], + attrs: dict[str, object], + ) -> None: + if not hasattr(cls, '_leaf_classes'): + cls._leaf_classes = set() + leaf_classes = getattr(cls, '_leaf_classes') + leaf_classes.add(cls) + # remove any base classes + leaf_classes -= set(bases) + + +class TagRegistered(type): + """ + As classes of this metaclass are created, they keep a registry in the + base class of all classes by a class attribute, indicated by attr_name. + + >>> FooObject = TagRegistered('FooObject', (), dict(tag='foo')) + >>> FooObject._registry['foo'] is FooObject + True + >>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar')) + >>> FooObject._registry is BarObject._registry + True + >>> len(FooObject._registry) + 2 + + '...' 
below should be 'jaraco.classes' but for pytest-dev/pytest#3396 + >>> FooObject._registry['bar'] + + """ + + attr_name = 'tag' + + def __init__( + cls, + name: str, + bases: tuple[type[object], ...], + namespace: dict[str, object], + ) -> None: + super(TagRegistered, cls).__init__(name, bases, namespace) + if not hasattr(cls, '_registry'): + cls._registry = {} + meta = cls.__class__ + attr = getattr(cls, meta.attr_name, None) + if attr: + cls._registry[attr] = cls diff --git a/lib/jaraco/classes/properties.py b/lib/jaraco/classes/properties.py new file mode 100644 index 0000000..2447041 --- /dev/null +++ b/lib/jaraco/classes/properties.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar, cast, overload + +_T = TypeVar('_T') +_U = TypeVar('_U') + +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any, Protocol + + from typing_extensions import Self, TypeAlias + + # TODO(coherent-oss/granary#4): Migrate to PEP 695 by 2027-10. + _GetterCallable: TypeAlias = Callable[..., _T] + _GetterClassMethod: TypeAlias = classmethod[Any, [], _T] + + _SetterCallable: TypeAlias = Callable[[type[Any], _T], None] + _SetterClassMethod: TypeAlias = classmethod[Any, [_T], None] + + class _ClassPropertyAttribute(Protocol[_T]): + def __get__(self, obj: object, objtype: type[Any] | None = None) -> _T: ... + + def __set__(self, obj: object, value: _T) -> None: ... + + +class NonDataProperty(Generic[_T, _U]): + """Much like the property builtin, but only implements __get__, + making it a non-data property, and can be subsequently reset. + + See http://users.rcn.com/python/download/Descriptor.htm for more + information. + + >>> class X(object): + ... @NonDataProperty + ... def foo(self): + ... return 3 + >>> x = X() + >>> x.foo + 3 + >>> x.foo = 4 + >>> x.foo + 4 + + '...' 
below should be 'jaraco.classes' but for pytest-dev/pytest#3396 + >>> X.foo + <....properties.NonDataProperty object at ...> + """ + + def __init__(self, fget: Callable[[_T], _U]) -> None: + assert fget is not None, "fget cannot be none" + assert callable(fget), "fget must be callable" + self.fget = fget + + @overload + def __get__( + self, + obj: None, + objtype: None, + ) -> Self: ... + + @overload + def __get__( + self, + obj: _T, + objtype: type[_T] | None = None, + ) -> _U: ... + + def __get__( + self, + obj: _T | None, + objtype: type[_T] | None = None, + ) -> Self | _U: + if obj is None: + return self + return self.fget(obj) + + +class classproperty(Generic[_T]): + """ + Like @property but applies at the class level. + + + >>> class X(metaclass=classproperty.Meta): + ... val = None + ... @classproperty + ... def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 + + Setting the property on an instance affects the class. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo + 5 + >>> vars(x) + {} + >>> X().foo + 5 + + Attempting to set an attribute where no setter was defined + results in an AttributeError: + + >>> class GetOnly(metaclass=classproperty.Meta): + ... @classproperty + ... def foo(cls): + ... return 'bar' + >>> GetOnly.foo = 3 + Traceback (most recent call last): + ... + AttributeError: can't set attribute + + It is also possible to wrap a classmethod or staticmethod in + a classproperty. + + >>> class Static(metaclass=classproperty.Meta): + ... @classproperty + ... @classmethod + ... def foo(cls): + ... return 'foo' + ... @classproperty + ... @staticmethod + ... def bar(): + ... return 'bar' + >>> Static.foo + 'foo' + >>> Static.bar + 'bar' + + *Legacy* + + For compatibility, if the metaclass isn't specified, the + legacy behavior will be invoked. + + >>> class X: + ... val = None + ... @classproperty + ... 
def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 + + Note, because the metaclass was not specified, setting + a value on an instance does not have the intended effect. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo # should be 5 + 4 + >>> vars(x) # should be empty + {'foo': 5} + >>> X().foo # should be 5 + 4 + """ + + fget: _ClassPropertyAttribute[_GetterClassMethod[_T]] + fset: _ClassPropertyAttribute[_SetterClassMethod[_T] | None] + + class Meta(type): + def __setattr__(self, key: str, value: object) -> None: + obj = self.__dict__.get(key, None) + if type(obj) is classproperty: + return obj.__set__(self, value) + return super().__setattr__(key, value) + + def __init__( + self, + fget: _GetterCallable[_T] | _GetterClassMethod[_T], + fset: _SetterCallable[_T] | _SetterClassMethod[_T] | None = None, + ) -> None: + self.fget = self._ensure_method(fget) + self.fset = fset # type: ignore[assignment] # Corrected in the next line. + fset and self.setter(fset) + + def __get__(self, instance: object, owner: type[object] | None = None) -> _T: + return self.fget.__get__(None, owner)() + + def __set__(self, owner: object, value: _T) -> None: + if not self.fset: + raise AttributeError("can't set attribute") + if type(owner) is not classproperty.Meta: + owner = type(owner) + return self.fset.__get__(None, cast('type[object]', owner))(value) + + def setter(self, fset: _SetterCallable[_T] | _SetterClassMethod[_T]) -> Self: + self.fset = self._ensure_method(fset) + return self + + @overload + @classmethod + def _ensure_method( + cls, + fn: _GetterCallable[_T] | _GetterClassMethod[_T], + ) -> _GetterClassMethod[_T]: ... + + @overload + @classmethod + def _ensure_method( + cls, + fn: _SetterCallable[_T] | _SetterClassMethod[_T], + ) -> _SetterClassMethod[_T]: ... 
+ + @classmethod + def _ensure_method( + cls, + fn: _GetterCallable[_T] + | _GetterClassMethod[_T] + | _SetterCallable[_T] + | _SetterClassMethod[_T], + ) -> _GetterClassMethod[_T] | _SetterClassMethod[_T]: + """ + Ensure fn is a classmethod or staticmethod. + """ + needs_method = not isinstance(fn, (classmethod, staticmethod)) + return classmethod(fn) if needs_method else fn # type: ignore[arg-type,return-value] diff --git a/lib/jaraco/classes/py.typed b/lib/jaraco/classes/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/jaraco/context/__init__.py b/lib/jaraco/context/__init__.py new file mode 100644 index 0000000..41ad609 --- /dev/null +++ b/lib/jaraco/context/__init__.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import contextlib +import errno +import functools +import operator +import os +import platform +import shutil +import stat +import subprocess +import sys +import tempfile +import urllib.request +from collections.abc import Iterator + +if sys.version_info < (3, 12): + from backports import tarfile +else: + import tarfile + + +@contextlib.contextmanager +def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]: + """ + >>> tmp_path = getfixture('tmp_path') + >>> with pushd(tmp_path): + ... assert os.getcwd() == os.fspath(tmp_path) + >>> assert os.getcwd() != os.fspath(tmp_path) + """ + + orig = os.getcwd() + os.chdir(dir) + try: + yield dir + finally: + os.chdir(orig) + + +@contextlib.contextmanager +def tarball( + url, target_dir: str | os.PathLike | None = None +) -> Iterator[str | os.PathLike]: + """ + Get a URL to a tarball, download, extract, yield, then clean up. + + Assumes everything in the tarball is prefixed with a common + directory. That common path is stripped and the contents + are extracted to ``target_dir``, similar to passing + ``-C {target} --strip-components 1`` to the ``tar`` command. 
+ + Uses the streaming protocol to extract the contents from a + stream in a single pass without loading the whole file into + memory. + + >>> import urllib.request + >>> url = getfixture('tarfile_served') + >>> target = getfixture('tmp_path') / 'out' + >>> tb = tarball(url, target_dir=target) + >>> import pathlib + >>> with tb as extracted: + ... contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8') + >>> assert not os.path.exists(extracted) + + If the target is not specified, contents are extracted to a + directory relative to the current working directory named after + the name of the file as extracted from the URL. + + >>> target = getfixture('tmp_path') + >>> with pushd(target), tarball(url): + ... target.joinpath('served').is_dir() + True + """ + if target_dir is None: + target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') + os.mkdir(target_dir) + try: + req = urllib.request.urlopen(url) + with tarfile.open(fileobj=req, mode='r|*') as tf: + tf.extractall(path=target_dir, filter=_default_filter) + yield target_dir + finally: + shutil.rmtree(target_dir) + + +def _compose_tarfile_filters(*filters): + def compose_two(f1, f2): + return lambda member, path: f1(f2(member, path), path) + + return functools.reduce(compose_two, filters, lambda member, path: member) + + +def strip_first_component( + member: tarfile.TarInfo, + path, +) -> tarfile.TarInfo: + _, member.name = member.name.split('/', 1) + return member + + +_default_filter = _compose_tarfile_filters(tarfile.data_filter, strip_first_component) + + +def _compose(*cmgrs): + """ + Compose any number of dependent context managers into a single one. + + The last, innermost context manager may take arbitrary arguments, but + each successive context manager should accept the result from the + previous as a single parameter. 
+ + Like :func:`jaraco.functools.compose`, behavior works from right to + left, so the context manager should be indicated from outermost to + innermost. + + Example, to create a context manager to change to a temporary + directory: + + >>> temp_dir_as_cwd = _compose(pushd, temp_dir) + >>> with temp_dir_as_cwd() as dir: + ... assert os.path.samefile(os.getcwd(), dir) + """ + + def compose_two(inner, outer): + def composed(*args, **kwargs): + with inner(*args, **kwargs) as saved, outer(saved) as res: + yield res + + return contextlib.contextmanager(composed) + + return functools.reduce(compose_two, reversed(cmgrs)) + + +tarball_cwd = _compose(pushd, tarball) +""" +A tarball context with the current working directory pointing to the contents. +""" + + +def remove_readonly(func, path, exc_info): + """ + Add support for removing read-only files on Windows. + """ + _, exc, _ = exc_info + if func in (os.rmdir, os.remove, os.unlink) and exc.errno == errno.EACCES: + # change the file to be readable,writable,executable: 0777 + os.chmod(path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + # retry + func(path) + else: + raise + + +def robust_remover(): + return ( + functools.partial(shutil.rmtree, onerror=remove_readonly) + if platform.system() == 'Windows' + else shutil.rmtree + ) + + +@contextlib.contextmanager +def temp_dir(remover=shutil.rmtree): + """ + Create a temporary directory context. Pass a custom remover + to override the removal behavior. + + >>> import pathlib + >>> with temp_dir() as the_dir: + ... assert os.path.isdir(the_dir) + >>> assert not os.path.exists(the_dir) + """ + temp_dir = tempfile.mkdtemp() + try: + yield temp_dir + finally: + remover(temp_dir) + + +robust_temp_dir = functools.partial(temp_dir, remover=robust_remover()) + + +@contextlib.contextmanager +def repo_context( + url, branch: str | None = None, quiet: bool = True, dest_ctx=robust_temp_dir +): + """ + Check out the repo indicated by url. 
+ + If dest_ctx is supplied, it should be a context manager + to yield the target directory for the check out. + + >>> getfixture('ensure_git') + >>> getfixture('needs_internet') + >>> repo = repo_context('https://github.com/jaraco/jaraco.context') + >>> with repo as dest: + ... listing = os.listdir(dest) + >>> 'README.rst' in listing + True + """ + exe = 'git' if 'git' in url else 'hg' + with dest_ctx() as repo_dir: + cmd = [exe, 'clone', url, repo_dir] + cmd.extend(['--branch', branch] * bool(branch)) + stream = subprocess.DEVNULL if quiet else None + subprocess.check_call(cmd, stdout=stream, stderr=stream) + yield repo_dir + + +class ExceptionTrap: + """ + A context manager that will catch certain exceptions and provide an + indication they occurred. + + >>> with ExceptionTrap() as trap: + ... raise Exception() + >>> bool(trap) + True + + >>> with ExceptionTrap() as trap: + ... pass + >>> bool(trap) + False + + >>> with ExceptionTrap(ValueError) as trap: + ... raise ValueError("1 + 1 is not 3") + >>> bool(trap) + True + >>> trap.value + ValueError('1 + 1 is not 3') + >>> trap.tb + + + >>> with ExceptionTrap(ValueError) as trap: + ... raise Exception() + Traceback (most recent call last): + ... + Exception + + >>> bool(trap) + False + """ + + exc_info = None, None, None + + def __init__(self, exceptions=(Exception,)): + self.exceptions = exceptions + + def __enter__(self): + return self + + @property + def type(self): + return self.exc_info[0] + + @property + def value(self): + return self.exc_info[1] + + @property + def tb(self): + return self.exc_info[2] + + def __exit__(self, *exc_info): + type = exc_info[0] + matches = type and issubclass(type, self.exceptions) + if matches: + self.exc_info = exc_info + return matches + + def __bool__(self): + return bool(self.type) + + def raises(self, func, *, _test=bool): + """ + Wrap func and replace the result with the truth + value of the trap (True if an exception occurred). 
+ + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> raises = ExceptionTrap(ValueError).raises + + Now decorate a function that always fails. + + >>> @raises + ... def fail(): + ... raise ValueError('failed') + >>> fail() + True + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with ExceptionTrap(self.exceptions) as trap: + func(*args, **kwargs) + return _test(trap) + + return wrapper + + def passes(self, func): + """ + Wrap func and replace the result with the truth + value of the trap (True if no exception). + + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> passes = ExceptionTrap(ValueError).passes + + Now decorate a function that always fails. + + >>> @passes + ... def fail(): + ... raise ValueError('failed') + + >>> fail() + False + """ + return self.raises(func, _test=operator.not_) + + +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """ + A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ + + +class on_interrupt(contextlib.ContextDecorator): + """ + Replace a KeyboardInterrupt with SystemExit(1). + + Useful in conjunction with console entry point functions. + + >>> def do_interrupt(): + ... raise KeyboardInterrupt() + >>> on_interrupt('error')(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 1 + >>> on_interrupt('error', code=255)(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 255 + >>> on_interrupt('suppress')(do_interrupt)() + >>> with __import__('pytest').raises(KeyboardInterrupt): + ... 
on_interrupt('ignore')(do_interrupt)() + """ + + def __init__(self, action='error', /, code=1): + self.action = action + self.code = code + + def __enter__(self): + return self + + def __exit__(self, exctype, excinst, exctb): + if exctype is not KeyboardInterrupt or self.action == 'ignore': + return + elif self.action == 'error': + raise SystemExit(self.code) from excinst + return self.action == 'suppress' diff --git a/lib/jaraco/context/__pycache__/__init__.cpython-314.pyc b/lib/jaraco/context/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..0f51726 Binary files /dev/null and b/lib/jaraco/context/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jaraco/context/py.typed b/lib/jaraco/context/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/jaraco/functools/__init__.py b/lib/jaraco/functools/__init__.py new file mode 100644 index 0000000..df32e2e --- /dev/null +++ b/lib/jaraco/functools/__init__.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +import collections.abc +import functools +import inspect +import itertools +import operator +import time +import types +import warnings +from typing import Callable, TypeVar + +import more_itertools + + +def compose(*funcs): + """ + Compose any number of unary functions into a single unary function. + + Comparable to + `function composition `_ + in mathematics: + + ``h = g ∘ f`` implies ``h(x) = g(f(x))``. + + In Python, ``h = compose(g, f)``. + + >>> import textwrap + >>> expected = str.strip(textwrap.dedent(compose.__doc__)) + >>> strip_and_dedent = compose(str.strip, textwrap.dedent) + >>> strip_and_dedent(compose.__doc__) == expected + True + + Compose also allows the innermost function to take arbitrary arguments. 
+ + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ + + def compose_two(f1, f2): + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) + + return functools.reduce(compose_two, funcs) + + +def once(func): + """ + Decorate func so it's only ever called the first time. + + This decorator can ensure that an expensive or non-idempotent function + will not be expensive on subsequent calls and is idempotent. + + >>> add_three = once(lambda a: a+3) + >>> add_three(3) + 6 + >>> add_three(9) + 6 + >>> add_three('12') + 6 + + To reset the stored value, simply clear the property ``saved_result``. + + >>> del add_three.saved_result + >>> add_three(9) + 12 + >>> add_three(8) + 12 + + Or invoke 'reset()' on it. + + >>> add_three.reset() + >>> add_three(-3) + 0 + >>> add_three(0) + 0 + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not hasattr(wrapper, 'saved_result'): + wrapper.saved_result = func(*args, **kwargs) + return wrapper.saved_result + + wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') + return wrapper + + +def method_cache(method, cache_wrapper=functools.lru_cache()): + """ + Wrap lru_cache to support storing the cache data in the object instances. + + Abstracts the common paradigm where the method explicitly saves an + underscore-prefixed protected property on first call and returns that + subsequently. + + >>> class MyClass: + ... calls = 0 + ... + ... @method_cache + ... def method(self, value): + ... self.calls += 1 + ... return value + + >>> a = MyClass() + >>> a.method(3) + 3 + >>> for x in range(75): + ... 
res = a.method(x) + >>> a.calls + 75 + + Note that the apparent behavior will be exactly like that of lru_cache + except that the cache is stored on each instance, so values in one + instance will not flush values from another, and when an instance is + deleted, so are the cached values for that instance. + + >>> b = MyClass() + >>> for x in range(35): + ... res = b.method(x) + >>> b.calls + 35 + >>> a.method(0) + 0 + >>> a.calls + 75 + + Note that if method had been decorated with ``functools.lru_cache()``, + a.calls would have been 76 (due to the cached value of 0 having been + flushed by the 'b' instance). + + Clear the cache with ``.cache_clear()`` + + >>> a.method.cache_clear() + + Same for a method that hasn't yet been called. + + >>> c = MyClass() + >>> c.method.cache_clear() + + Another cache wrapper may be supplied: + + >>> cache = functools.lru_cache(maxsize=2) + >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) + >>> a = MyClass() + >>> a.method2() + 3 + + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. + + See also + http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ + for another implementation and additional justification. + """ + + def wrapper(self, *args, **kwargs): + # it's the first call, replace the method with a cached, bound method + bound_method = types.MethodType(method, self) + cached_method = cache_wrapper(bound_method) + setattr(self, method.__name__, cached_method) + return cached_method(*args, **kwargs) + + # Support cache clear even before cache has been created. + wrapper.cache_clear = lambda: None + + return _special_method_cache(method, cache_wrapper) or wrapper + + +def _special_method_cache(method, cache_wrapper): + """ + Because Python treats special methods differently, it's not + possible to use instance attributes to implement the cached + methods. 
+ + Instead, install the wrapper method under a different name + and return a simple proxy to that wrapper. + + https://github.com/jaraco/jaraco.functools/issues/5 + """ + name = method.__name__ + special_names = '__getattr__', '__getitem__' + + if name not in special_names: + return None + + wrapper_name = '__cached' + name + + def proxy(self, /, *args, **kwargs): + if wrapper_name not in vars(self): + bound = types.MethodType(method, self) + cache = cache_wrapper(bound) + setattr(self, wrapper_name, cache) + else: + cache = getattr(self, wrapper_name) + return cache(*args, **kwargs) + + return proxy + + +def apply(transform): + """ + Decorate a function with a transform function that is + invoked on results returned from the decorated function. + + >>> @apply(reversed) + ... def get_numbers(start): + ... "doc for get_numbers" + ... return range(start, start+3) + >>> list(get_numbers(4)) + [6, 5, 4] + >>> get_numbers.__doc__ + 'doc for get_numbers' + """ + + def wrap(func): + return functools.wraps(func)(compose(transform, func)) + + return wrap + + +def result_invoke(action): + r""" + Decorate a function with an action function that is + invoked on the results returned from the decorated + function (for its side effect), then return the original + result. + + >>> @result_invoke(print) + ... def add_two(a, b): + ... return a + b + >>> x = add_two(2, 3) + 5 + >>> x + 5 + """ + + def wrap(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + action(result) + return result + + return wrapper + + return wrap + + +def invoke(f, /, *args, **kwargs): + """ + Call a function for its side effect after initialization. + + The benefit of using the decorator instead of simply invoking a function + after defining it is that it makes explicit the author's intent for the + function to be called immediately. Whereas if one simply calls the + function immediately, it's less obvious if that was intentional or + incidental. 
It also avoids repeating the name - the two actions, defining + the function and calling it immediately are modeled separately, but linked + by the decorator construct. + + The benefit of having a function construct (opposed to just invoking some + behavior inline) is to serve as a scope in which the behavior occurs. It + avoids polluting the global namespace with local variables, provides an + anchor on which to attach documentation (docstring), keeps the behavior + logically separated (instead of conceptually separated or not separated at + all), and provides potential to re-use the behavior for testing or other + purposes. + + This function is named as a pithy way to communicate, "call this function + primarily for its side effect", or "while defining this function, also + take it aside and call it". It exists because there's no Python construct + for "define and call" (nor should there be, as decorators serve this need + just fine). The behavior happens immediately and synchronously. + + >>> @invoke + ... def func(): print("called") + called + >>> func() + called + + Use functools.partial to pass parameters to the initial call + + >>> @functools.partial(invoke, name='bingo') + ... def func(name): print('called with', name) + called with bingo + """ + f(*args, **kwargs) + return f + + +_T = TypeVar('_T') + + +def passthrough(func: Callable[..., object]) -> Callable[[_T], _T]: + """ + Wrap the function to always return the first parameter. 
+ + >>> passthrough(print)('3') + 3 + '3' + """ + + @functools.wraps(func) + def wrapper(first: _T, *args, **kwargs) -> _T: + func(first, *args, **kwargs) + return first + + return wrapper + + +class Throttler: + """Rate-limit a function (or other callable).""" + + def __init__(self, func, max_rate=float('Inf')): + if isinstance(func, Throttler): + func = func.func + self.func = func + self.max_rate = max_rate + self.reset() + + def reset(self): + self.last_called = 0 + + def __call__(self, *args, **kwargs): + self._wait() + return self.func(*args, **kwargs) + + def _wait(self): + """Ensure at least 1/max_rate seconds from last call.""" + elapsed = time.time() - self.last_called + must_wait = 1 / self.max_rate - elapsed + time.sleep(max(0, must_wait)) + self.last_called = time.time() + + def __get__(self, obj, owner=None): + return first_invoke(self._wait, functools.partial(self.func, obj)) + + +def first_invoke(func1, func2): + """ + Return a function that when invoked will invoke func1 without + any parameters (for its side effect) and then invoke func2 + with whatever parameters were passed, returning its result. + """ + + def wrapper(*args, **kwargs): + func1() + return func2(*args, **kwargs) + + return wrapper + + +method_caller = first_invoke( + lambda: warnings.warn( + '`jaraco.functools.method_caller` is deprecated, ' + 'use `operator.methodcaller` instead', + DeprecationWarning, + stacklevel=3, + ), + operator.methodcaller, +) + + +def retry_call(func, cleanup=lambda: None, retries=0, trap=()): + """ + Given a callable func, trap the indicated exceptions + for up to 'retries' times, invoking cleanup on the + exception. On the final attempt, allow any exceptions + to propagate. + """ + attempts = itertools.count() if retries == float('inf') else range(retries) + for _ in attempts: + try: + return func() + except trap: + cleanup() + + return func() + + +def retry(*r_args, **r_kwargs): + """ + Decorator wrapper for retry_call. 
Accepts arguments to retry_call + except func and then returns a decorator for the decorated function. + + Ex: + + >>> @retry(retries=3) + ... def my_func(a, b): + ... "this is my funk" + ... print(a, b) + >>> my_func.__doc__ + 'this is my funk' + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*f_args, **f_kwargs): + bound = functools.partial(func, *f_args, **f_kwargs) + return retry_call(bound, *r_args, **r_kwargs) + + return wrapper + + return decorate + + +def print_yielded(func): + """ + Convert a generator into a function that prints all yielded elements. + + >>> @print_yielded + ... def x(): + ... yield 3; yield None + >>> x() + 3 + None + """ + print_all = functools.partial(map, print) + print_results = compose(more_itertools.consume, print_all, func) + return functools.wraps(func)(print_results) + + +def pass_none(func): + """ + Wrap func so it's not called if its first param is None. + + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + + @functools.wraps(func) + def wrapper(param, /, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + return None + + return wrapper + + +def none_as(value, replacement=None): + """ + >>> none_as(None, 'foo') + 'foo' + >>> none_as('bar', 'foo') + 'bar' + """ + return replacement if value is None else value + + +def assign_params(func, namespace): + """ + Assign parameters from namespace where func solicits. + + >>> def func(x, y=3): + ... print(x, y) + >>> assigned = assign_params(func, dict(x=2, z=4)) + >>> assigned() + 2 3 + + The usual errors are raised if a function doesn't receive + its required parameters: + + >>> assigned = assign_params(func, dict(y=3, z=4)) + >>> assigned() + Traceback (most recent call last): + TypeError: func() ...argument... + + It even works on methods: + + >>> class Handler: + ... def meth(self, arg): + ... 
print(arg) + >>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))() + crystal + """ + sig = inspect.signature(func) + params = sig.parameters.keys() + call_ns = {k: namespace[k] for k in params if k in namespace} + return functools.partial(func, **call_ns) + + +def save_method_args(method): + """ + Wrap a method such that when it is called, the args and kwargs are + saved on the method. + + >>> class MyClass: + ... @save_method_args + ... def method(self, a, b): + ... print(a, b) + >>> my_ob = MyClass() + >>> my_ob.method(1, 2) + 1 2 + >>> my_ob._saved_method.args + (1, 2) + >>> my_ob._saved_method.kwargs + {} + >>> my_ob.method(a=3, b='foo') + 3 foo + >>> my_ob._saved_method.args + () + >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') + True + + The arguments are stored on the instance, allowing for + different instance to save different args. + + >>> your_ob = MyClass() + >>> your_ob.method({str('x'): 3}, b=[4]) + {'x': 3} [4] + >>> your_ob._saved_method.args + ({'x': 3},) + >>> my_ob._saved_method.args + () + """ + args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') # noqa: PYI024 # Internal; stubs used for typing + + @functools.wraps(method) + def wrapper(self, /, *args, **kwargs): + attr_name = '_saved_' + method.__name__ + attr = args_and_kwargs(args, kwargs) + setattr(self, attr_name, attr) + return method(self, *args, **kwargs) + + return wrapper + + +def except_(*exceptions, replace=None, use=None): + """ + Replace the indicated exceptions, if raised, with the indicated + literal replacement or evaluated expression (if present). + + >>> safe_int = except_(ValueError)(int) + >>> safe_int('five') + >>> safe_int('5') + 5 + + Specify a literal replacement with ``replace``. + + >>> safe_int_r = except_(ValueError, replace=0)(int) + >>> safe_int_r('five') + 0 + + Provide an expression to ``use`` to pass through particular parameters. 
+ + >>> safe_int_pt = except_(ValueError, use='args[0]')(int) + >>> safe_int_pt('five') + 'five' + + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions: + try: + return eval(use) + except TypeError: + return replace + + return wrapper + + return decorate + + +def identity(x): + """ + Return the argument. + + >>> o = object() + >>> identity(o) is o + True + """ + return x + + +def bypass_when(check, *, _op=identity): + """ + Decorate a function to return its parameter when ``check``. + + >>> bypassed = [] # False + + >>> @bypass_when(bypassed) + ... def double(x): + ... return x * 2 + >>> double(2) + 4 + >>> bypassed[:] = [object()] # True + >>> double(2) + 2 + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(param, /): + return param if _op(check) else func(param) + + return wrapper + + return decorate + + +def bypass_unless(check): + """ + Decorate a function to return its parameter unless ``check``. + + >>> enabled = [object()] # True + + >>> @bypass_unless(enabled) + ... def double(x): + ... return x * 2 + >>> double(2) + 4 + >>> del enabled[:] # False + >>> double(2) + 2 + """ + return bypass_when(check, _op=operator.not_) + + +@functools.singledispatch +def _splat_inner(args, func): + """Splat args to func.""" + return func(*args) + + +@_splat_inner.register +def _(args: collections.abc.Mapping, func): + """Splat kargs to func as kwargs.""" + return func(**args) + + +def splat(func): + """ + Wrap func to expect its parameters to be passed positionally in a tuple. + + Has a similar effect to that of ``itertools.starmap`` over + simple ``map``. + + >>> pairs = [(-1, 1), (0, 2)] + >>> more_itertools.consume(itertools.starmap(print, pairs)) + -1 1 + 0 2 + >>> more_itertools.consume(map(splat(print), pairs)) + -1 1 + 0 2 + + The approach generalizes to other iterators that don't have a "star" + equivalent, such as a "starfilter". 
+ + >>> list(filter(splat(operator.add), pairs)) + [(0, 2)] + + Splat also accepts a mapping argument. + + >>> def is_nice(msg, code): + ... return "smile" in msg or code == 0 + >>> msgs = [ + ... dict(msg='smile!', code=20), + ... dict(msg='error :(', code=1), + ... dict(msg='unknown', code=0), + ... ] + >>> for msg in filter(splat(is_nice), msgs): + ... print(msg) + {'msg': 'smile!', 'code': 20} + {'msg': 'unknown', 'code': 0} + """ + return functools.wraps(func)(functools.partial(_splat_inner, func=func)) + + +_T = TypeVar('_T') + + +def chainable(method: Callable[[_T, ...], None]) -> Callable[[_T, ...], _T]: + """ + Wrap an instance method to always return self. + + + >>> class Dingus: + ... @chainable + ... def set_attr(self, name, val): + ... setattr(self, name, val) + >>> d = Dingus().set_attr('a', 'eh!') + >>> d.a + 'eh!' + >>> d2 = Dingus().set_attr('a', 'eh!').set_attr('b', 'bee!') + >>> d2.a + d2.b + 'eh!bee!' + + Enforces that the return value is null. + + >>> class BorkedDingus: + ... @chainable + ... def set_attr(self, name, val): + ... setattr(self, name, val) + ... return len(name) + >>> BorkedDingus().set_attr('a', 'eh!') + Traceback (most recent call last): + ... + AssertionError + """ + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + assert method(self, *args, **kwargs) is None + return self + + return wrapper + + +def noop(*args, **kwargs): + """ + A no-operation function that does nothing. 
+ + >>> noop(1, 2, three=3) + """ diff --git a/lib/jaraco/functools/__init__.pyi b/lib/jaraco/functools/__init__.pyi new file mode 100644 index 0000000..6f834bf --- /dev/null +++ b/lib/jaraco/functools/__init__.pyi @@ -0,0 +1,123 @@ +from collections.abc import Callable, Hashable, Iterator +from functools import partial +from operator import methodcaller +from typing import ( + Any, + Generic, + Protocol, + TypeVar, + overload, +) + +from typing_extensions import Concatenate, ParamSpec, TypeVarTuple, Unpack + +_P = ParamSpec('_P') +_R = TypeVar('_R') +_T = TypeVar('_T') +_Ts = TypeVarTuple('_Ts') +_R1 = TypeVar('_R1') +_R2 = TypeVar('_R2') +_V = TypeVar('_V') +_S = TypeVar('_S') +_R_co = TypeVar('_R_co', covariant=True) + +class _OnceCallable(Protocol[_P, _R]): + saved_result: _R + reset: Callable[[], None] + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... + +class _ProxyMethodCacheWrapper(Protocol[_R_co]): + cache_clear: Callable[[], None] + def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ... + +class _MethodCacheWrapper(Protocol[_R_co]): + def cache_clear(self) -> None: ... + def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ... + +# `compose()` overloads below will cover most use cases. + +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[_P, _R], + /, +) -> Callable[_P, _T]: ... +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[[_R1], _R], + __func3: Callable[_P, _R1], + /, +) -> Callable[_P, _T]: ... +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[[_R2], _R], + __func3: Callable[[_R1], _R2], + __func4: Callable[_P, _R1], + /, +) -> Callable[_P, _T]: ... +def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ... +def method_cache( + method: Callable[..., _R], + cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ..., +) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ... 
+def apply( + transform: Callable[[_R], _T], +) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ... +def result_invoke( + action: Callable[[_R], Any], +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ... +def invoke( + f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs +) -> Callable[_P, _R]: ... + +class Throttler(Generic[_R]): + last_called: float + func: Callable[..., _R] + max_rate: float + def __init__( + self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ... + ) -> None: ... + def reset(self) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> _R: ... + def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ... + +def first_invoke( + func1: Callable[..., Any], func2: Callable[_P, _R] +) -> Callable[_P, _R]: ... + +method_caller: Callable[..., methodcaller] + +def retry_call( + func: Callable[..., _R], + cleanup: Callable[..., None] = ..., + retries: float = ..., + trap: type[BaseException] | tuple[type[BaseException], ...] = ..., +) -> _R: ... +def retry( + cleanup: Callable[..., None] = ..., + retries: float = ..., + trap: type[BaseException] | tuple[type[BaseException], ...] = ..., +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ... +def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ... +def pass_none( + func: Callable[Concatenate[_T, _P], _R], +) -> Callable[Concatenate[_T, _P], _R]: ... +def assign_params( + func: Callable[..., _R], namespace: dict[str, Any] +) -> partial[_R]: ... +def save_method_args( + method: Callable[Concatenate[_S, _P], _R], +) -> Callable[Concatenate[_S, _P], _R]: ... +def except_( + *exceptions: type[BaseException], replace: Any = ..., use: Any = ... +) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ... +def identity(x: _T) -> _T: ... +def bypass_when( + check: _V, *, _op: Callable[[_V], Any] = ... +) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ... 
+def bypass_unless( + check: Any, +) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ... +def splat(func: Callable[[Unpack[_Ts]], _R]) -> Callable[[tuple[Unpack[_Ts]]], _R]: ... diff --git a/lib/jaraco/functools/__pycache__/__init__.cpython-314.pyc b/lib/jaraco/functools/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..c8a9443 Binary files /dev/null and b/lib/jaraco/functools/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jaraco/functools/py.typed b/lib/jaraco/functools/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/jaraco_context-6.1.0.dist-info/INSTALLER b/lib/jaraco_context-6.1.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/jaraco_context-6.1.0.dist-info/METADATA b/lib/jaraco_context-6.1.0.dist-info/METADATA new file mode 100644 index 0000000..8fb5e53 --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/METADATA @@ -0,0 +1,82 @@ +Metadata-Version: 2.4 +Name: jaraco.context +Version: 6.1.0 +Summary: Useful decorators and context managers +Author-email: "Jason R. 
Coombs" +License-Expression: MIT +Project-URL: Source, https://github.com/jaraco/jaraco.context +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: backports.tarfile; python_version < "3.12" +Provides-Extra: test +Requires-Dist: pytest!=8.1.*,>=6; extra == "test" +Requires-Dist: jaraco.test>=5.6.0; extra == "test" +Requires-Dist: portend; extra == "test" +Provides-Extra: doc +Requires-Dist: sphinx>=3.5; extra == "doc" +Requires-Dist: jaraco.packaging>=9.3; extra == "doc" +Requires-Dist: rst.linker>=1.9; extra == "doc" +Requires-Dist: furo; extra == "doc" +Requires-Dist: sphinx-lint; extra == "doc" +Requires-Dist: jaraco.tidelift>=1.4; extra == "doc" +Provides-Extra: check +Requires-Dist: pytest-checkdocs>=2.4; extra == "check" +Requires-Dist: pytest-ruff>=0.2.1; sys_platform != "cygwin" and extra == "check" +Provides-Extra: cover +Requires-Dist: pytest-cov; extra == "cover" +Provides-Extra: enabler +Requires-Dist: pytest-enabler>=3.4; extra == "enabler" +Provides-Extra: type +Requires-Dist: pytest-mypy>=1.0.1; extra == "type" +Requires-Dist: mypy<1.19; platform_python_implementation == "PyPy" and extra == "type" +Dynamic: license-file + +.. image:: https://img.shields.io/pypi/v/jaraco.context.svg + :target: https://pypi.org/project/jaraco.context + +.. image:: https://img.shields.io/pypi/pyversions/jaraco.context.svg + +.. image:: https://github.com/jaraco/jaraco.context/actions/workflows/main.yml/badge.svg + :target: https://github.com/jaraco/jaraco.context/actions?query=workflow%3A%22tests%22 + :alt: tests + +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff + +.. 
image:: https://readthedocs.org/projects/jaracocontext/badge/?version=latest + :target: https://jaracocontext.readthedocs.io/en/latest/?badge=latest + +.. image:: https://img.shields.io/badge/skeleton-2025-informational + :target: https://blog.jaraco.com/skeleton + +.. image:: https://tidelift.com/badges/package/pypi/jaraco.context + :target: https://tidelift.com/subscription/pkg/pypi-jaraco.context?utm_source=pypi-jaraco.context&utm_medium=readme + + +Highlights +========== + +See the docs linked from the badge above for the full details, but here are some features that may be of interest. + +- ``ExceptionTrap`` provides a general-purpose wrapper for trapping exceptions and then acting on the outcome. Includes ``passes`` and ``raises`` decorators to replace the result of a wrapped function by a boolean indicating the outcome of the exception trap. See `this keyring commit `_ for an example of it in production. +- ``suppress`` simply enables ``contextlib.suppress`` as a decorator. +- ``on_interrupt`` is a decorator used by CLI entry points to affect the handling of a ``KeyboardInterrupt``. Inspired by `Lucretiel/autocommand#18 `_. +- ``pushd`` is similar to pytest's ``monkeypatch.chdir`` or path's `default context `_, changes the current working directory for the duration of the context. +- ``tarball`` will download a tarball, extract it, change directory, yield, then clean up after. Convenient when working with web assets. +- ``null`` is there for those times when one code branch needs a context and the other doesn't; this null context provides symmetry across those branches. + + +For Enterprise +============== + +Available as part of the Tidelift Subscription. + +This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use. + +`Learn more `_. 
diff --git a/lib/jaraco_context-6.1.0.dist-info/RECORD b/lib/jaraco_context-6.1.0.dist-info/RECORD new file mode 100644 index 0000000..c82da66 --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/RECORD @@ -0,0 +1,9 @@ +jaraco/context/__init__.py,sha256=br1ydYGo1Xr_Pu1anuEdd-QrjUiz_EY5L_5E4C03L4w,9809 +jaraco/context/__pycache__/__init__.cpython-314.pyc,, +jaraco/context/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jaraco_context-6.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +jaraco_context-6.1.0.dist-info/METADATA,sha256=BDXr_FIFXFqZdO0gwXG2RUOD6vnbsVCIFLp62XxZ1xI,4270 +jaraco_context-6.1.0.dist-info/RECORD,, +jaraco_context-6.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 +jaraco_context-6.1.0.dist-info/licenses/LICENSE,sha256=l1WhhRlmbl8PTK49qtPXASvK5IpgCzEjfXXp_hNOZoM,1076 +jaraco_context-6.1.0.dist-info/top_level.txt,sha256=0JnN3LfXH4LIRfXL-QFOGCJzQWZO3ELx4R1d_louoQM,7 diff --git a/lib/jaraco_context-6.1.0.dist-info/WHEEL b/lib/jaraco_context-6.1.0.dist-info/WHEEL new file mode 100644 index 0000000..e7fa31b --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/lib/jaraco_context-6.1.0.dist-info/licenses/LICENSE b/lib/jaraco_context-6.1.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..c891f41 --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/licenses/LICENSE @@ -0,0 +1,18 @@ +MIT License + +Copyright (c) 2026 + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +following 
conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/lib/jaraco_context-6.1.0.dist-info/top_level.txt b/lib/jaraco_context-6.1.0.dist-info/top_level.txt new file mode 100644 index 0000000..f6205a5 --- /dev/null +++ b/lib/jaraco_context-6.1.0.dist-info/top_level.txt @@ -0,0 +1 @@ +jaraco diff --git a/lib/jaraco_functools-4.4.0.dist-info/INSTALLER b/lib/jaraco_functools-4.4.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/jaraco_functools-4.4.0.dist-info/METADATA b/lib/jaraco_functools-4.4.0.dist-info/METADATA new file mode 100644 index 0000000..f2150dd --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/METADATA @@ -0,0 +1,69 @@ +Metadata-Version: 2.4 +Name: jaraco.functools +Version: 4.4.0 +Summary: Functools like those found in stdlib +Author-email: "Jason R. 
Coombs" +License-Expression: MIT +Project-URL: Source, https://github.com/jaraco/jaraco.functools +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: more_itertools +Provides-Extra: test +Requires-Dist: pytest!=8.1.*,>=6; extra == "test" +Requires-Dist: jaraco.classes; extra == "test" +Provides-Extra: doc +Requires-Dist: sphinx>=3.5; extra == "doc" +Requires-Dist: jaraco.packaging>=9.3; extra == "doc" +Requires-Dist: rst.linker>=1.9; extra == "doc" +Requires-Dist: furo; extra == "doc" +Requires-Dist: sphinx-lint; extra == "doc" +Requires-Dist: jaraco.tidelift>=1.4; extra == "doc" +Provides-Extra: check +Requires-Dist: pytest-checkdocs>=2.4; extra == "check" +Requires-Dist: pytest-ruff>=0.2.1; sys_platform != "cygwin" and extra == "check" +Provides-Extra: cover +Requires-Dist: pytest-cov; extra == "cover" +Provides-Extra: enabler +Requires-Dist: pytest-enabler>=3.4; extra == "enabler" +Provides-Extra: type +Requires-Dist: pytest-mypy>=1.0.1; extra == "type" +Requires-Dist: mypy<1.19; platform_python_implementation == "PyPy" and extra == "type" +Dynamic: license-file + +.. image:: https://img.shields.io/pypi/v/jaraco.functools.svg + :target: https://pypi.org/project/jaraco.functools + +.. image:: https://img.shields.io/pypi/pyversions/jaraco.functools.svg + +.. image:: https://github.com/jaraco/jaraco.functools/actions/workflows/main.yml/badge.svg + :target: https://github.com/jaraco/jaraco.functools/actions?query=workflow%3A%22tests%22 + :alt: tests + +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff + +.. 
image:: https://readthedocs.org/projects/jaracofunctools/badge/?version=latest + :target: https://jaracofunctools.readthedocs.io/en/latest/?badge=latest + +.. image:: https://img.shields.io/badge/skeleton-2025-informational + :target: https://blog.jaraco.com/skeleton + +.. image:: https://tidelift.com/badges/package/pypi/jaraco.functools + :target: https://tidelift.com/subscription/pkg/pypi-jaraco.functools?utm_source=pypi-jaraco.functools&utm_medium=readme + +Additional functools in the spirit of stdlib's functools. + +For Enterprise +============== + +Available as part of the Tidelift Subscription. + +This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use. + +`Learn more `_. diff --git a/lib/jaraco_functools-4.4.0.dist-info/RECORD b/lib/jaraco_functools-4.4.0.dist-info/RECORD new file mode 100644 index 0000000..2b53e4d --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/RECORD @@ -0,0 +1,10 @@ +jaraco/functools/__init__.py,sha256=ZJx9cMs2Nvk2xGUl8OjVGkpjdOaNlSzJrN4dGglgX2g,18599 +jaraco/functools/__init__.pyi,sha256=K4DcbnYIHE5QlMxqf9-cVp-WhycrhuTao4J7O7TMq4Y,3907 +jaraco/functools/__pycache__/__init__.cpython-314.pyc,, +jaraco/functools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jaraco_functools-4.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +jaraco_functools-4.4.0.dist-info/METADATA,sha256=LnnajcNGmSSr46yLIqP-tWkqeb-fR7vIa2U11hhkGEk,2960 +jaraco_functools-4.4.0.dist-info/RECORD,, +jaraco_functools-4.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 +jaraco_functools-4.4.0.dist-info/licenses/LICENSE,sha256=WlfLTbheKi3YjCkGKJCK3VfjRRRJ4KmnH9-zh3b9dZ0,1076 +jaraco_functools-4.4.0.dist-info/top_level.txt,sha256=0JnN3LfXH4LIRfXL-QFOGCJzQWZO3ELx4R1d_louoQM,7 diff --git a/lib/jaraco_functools-4.4.0.dist-info/WHEEL b/lib/jaraco_functools-4.4.0.dist-info/WHEEL new file mode 
100644 index 0000000..e7fa31b --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/lib/jaraco_functools-4.4.0.dist-info/licenses/LICENSE b/lib/jaraco_functools-4.4.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..f60bd57 --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/licenses/LICENSE @@ -0,0 +1,18 @@ +MIT License + +Copyright (c) 2025 + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/lib/jaraco_functools-4.4.0.dist-info/top_level.txt b/lib/jaraco_functools-4.4.0.dist-info/top_level.txt new file mode 100644 index 0000000..f6205a5 --- /dev/null +++ b/lib/jaraco_functools-4.4.0.dist-info/top_level.txt @@ -0,0 +1 @@ +jaraco diff --git a/lib/jeepney-0.9.0.dist-info/INSTALLER b/lib/jeepney-0.9.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/jeepney-0.9.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/jeepney-0.9.0.dist-info/METADATA b/lib/jeepney-0.9.0.dist-info/METADATA new file mode 100644 index 0000000..c9a9e6c --- /dev/null +++ b/lib/jeepney-0.9.0.dist-info/METADATA @@ -0,0 +1,35 @@ +Metadata-Version: 2.4 +Name: jeepney +Version: 0.9.0 +Summary: Low-level, pure Python DBus protocol wrapper. +Author-email: Thomas Kluyver +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-Expression: MIT +Classifier: Programming Language :: Python :: 3 +Classifier: Topic :: Desktop Environment +License-File: LICENSE +Requires-Dist: pytest ; extra == "test" +Requires-Dist: pytest-trio ; extra == "test" +Requires-Dist: pytest-asyncio >=0.17 ; extra == "test" +Requires-Dist: testpath ; extra == "test" +Requires-Dist: trio ; extra == "test" +Requires-Dist: async-timeout ; extra == "test" and ( python_version < '3.11') +Requires-Dist: trio ; extra == "trio" +Project-URL: Documentation, https://jeepney.readthedocs.io/en/latest/ +Project-URL: Source, https://gitlab.com/takluyver/jeepney +Provides-Extra: test +Provides-Extra: trio + +Jeepney is a pure Python implementation of D-Bus messaging. It has an `I/O-free +`__ core, and integration modules for different +event loops. + +D-Bus is an inter-process communication system, mainly used in Linux. 
+ +To install Jeepney:: + + pip install jeepney + +`Jeepney docs on Readthedocs `__ + diff --git a/lib/jeepney-0.9.0.dist-info/RECORD b/lib/jeepney-0.9.0.dist-info/RECORD new file mode 100644 index 0000000..74db428 --- /dev/null +++ b/lib/jeepney-0.9.0.dist-info/RECORD @@ -0,0 +1,64 @@ +jeepney-0.9.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +jeepney-0.9.0.dist-info/METADATA,sha256=uObDU-mq7Q7QFEApVWQX_aI7ZPNE5xgzGQcFS6BbjGI,1230 +jeepney-0.9.0.dist-info/RECORD,, +jeepney-0.9.0.dist-info/WHEEL,sha256=_2ozNFCLWc93bK4WKHCO-eDUENDlo-dgc9cU3qokYO4,82 +jeepney-0.9.0.dist-info/licenses/LICENSE,sha256=GyKwSbUmfW38I6Z79KhNjsBLn9-xpR02DkK0NCyLQVQ,1081 +jeepney/__init__.py,sha256=ULhIr444tY81PUkEHRWNlTXlqRbFmqMTbWBsktKgmoI,408 +jeepney/__pycache__/__init__.cpython-314.pyc,, +jeepney/__pycache__/auth.cpython-314.pyc,, +jeepney/__pycache__/bindgen.cpython-314.pyc,, +jeepney/__pycache__/bus.cpython-314.pyc,, +jeepney/__pycache__/bus_messages.cpython-314.pyc,, +jeepney/__pycache__/fds.cpython-314.pyc,, +jeepney/__pycache__/low_level.cpython-314.pyc,, +jeepney/__pycache__/wrappers.cpython-314.pyc,, +jeepney/auth.py,sha256=ZW0HMX6Vfwx28P-jNrzVVgEn1ipjO-KJrNJ2SG90V3U,5409 +jeepney/bindgen.py,sha256=yPDJFt_WjKoFUp08r-_upsqu0L8Rmv8gNKr-MA4T4bI,6085 +jeepney/bus.py,sha256=KUiSr3ECzdbe-S9tNKm6kvf3oZi4RYnJWkZUXK7tE2k,1817 +jeepney/bus_messages.py,sha256=uUCc_1Xllzth4F95aghpDLmlv5Gz0are2FpKg7D_gqc,8239 +jeepney/fds.py,sha256=ZYzN_c_7rkBT0wU7dYUmQRijpSzCv-DATCYEklpXxUU,5056 +jeepney/io/__init__.py,sha256=inJI_1U-ATymLcAVYs-LD2aUwgl-tihW8-oVFUxYRgA,33 +jeepney/io/__pycache__/__init__.cpython-314.pyc,, +jeepney/io/__pycache__/asyncio.cpython-314.pyc,, +jeepney/io/__pycache__/blocking.cpython-314.pyc,, +jeepney/io/__pycache__/common.cpython-314.pyc,, +jeepney/io/__pycache__/threading.cpython-314.pyc,, +jeepney/io/__pycache__/trio.cpython-314.pyc,, +jeepney/io/asyncio.py,sha256=qfWi_1pWCXSP1LNRafHBuvrxHx4tX96b52KBa4sUFMc,7622 
+jeepney/io/blocking.py,sha256=I_rw90IY_EesBZmkfUqk7UniyVkQAngz7jyQmzju680,11940 +jeepney/io/common.py,sha256=l8lbFUgQmBxfqSC-hqHYmPUYCVFMKbOGB1k5ZWPKXfs,2696 +jeepney/io/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jeepney/io/tests/__pycache__/__init__.cpython-314.pyc,, +jeepney/io/tests/__pycache__/conftest.cpython-314.pyc,, +jeepney/io/tests/__pycache__/test_asyncio.cpython-314.pyc,, +jeepney/io/tests/__pycache__/test_blocking.cpython-314.pyc,, +jeepney/io/tests/__pycache__/test_threading.cpython-314.pyc,, +jeepney/io/tests/__pycache__/test_trio.cpython-314.pyc,, +jeepney/io/tests/__pycache__/utils.cpython-314.pyc,, +jeepney/io/tests/conftest.py,sha256=o7JrYypYE-0jNFUndsQ4Ek5dNYM0ofh1sYcIVeCZMj0,2730 +jeepney/io/tests/test_asyncio.py,sha256=JJtnX5HiRRZjjuGIDoI8LvzfbaSNg-ljiX95yUvd9xk,2720 +jeepney/io/tests/test_blocking.py,sha256=ETLnoivenN8Dzp0JB4wPOb9PNbpSuiocuP_IDeNRlI4,2804 +jeepney/io/tests/test_threading.py,sha256=RALwy-aI64TBoFmBnSU63HLcwRnStLVtnewOtoaBl3o,2699 +jeepney/io/tests/test_trio.py,sha256=DPY1V_K2qLTyBTrbrxZeLTA5dmca3Ye3e6pz08UxbO8,3892 +jeepney/io/tests/utils.py,sha256=i7VJYT-axefzS8mWcvv-9DeHEB6LdP9M82H3Hx6fyC4,79 +jeepney/io/threading.py,sha256=mwGCNlun_baX8Y4eienCGDKdZD4SKdTMvBTkIE0EMKo,9391 +jeepney/io/trio.py,sha256=IdZIJnQcPjVOBA9KooFn0nTBEz3BuBDkz56qLYhGR1M,15088 +jeepney/low_level.py,sha256=m4wGY-quPnzylgKlBdBccmkuOXF_hQ1gbtT25qPX2GM,19949 +jeepney/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jeepney/tests/__pycache__/__init__.cpython-314.pyc,, +jeepney/tests/__pycache__/test_auth.cpython-314.pyc,, +jeepney/tests/__pycache__/test_bindgen.cpython-314.pyc,, +jeepney/tests/__pycache__/test_bus.cpython-314.pyc,, +jeepney/tests/__pycache__/test_bus_messages.cpython-314.pyc,, +jeepney/tests/__pycache__/test_fds.cpython-314.pyc,, +jeepney/tests/__pycache__/test_low_level.cpython-314.pyc,, +jeepney/tests/__pycache__/test_wrappers.cpython-314.pyc,, 
+jeepney/tests/secrets_introspect.xml,sha256=9cfNs1aGLtIAykcQVsycwIwCLmEeorKkFjqJCLAknRQ,4575 +jeepney/tests/test_auth.py,sha256=Ee79vsedCwveukudAZTwqYTXHWV3PYnXkmMl0MBMZEE,611 +jeepney/tests/test_bindgen.py,sha256=Ez99zr9TIV3mlZdH-2A_dz4LbvxCqzWDIadhOCbbaoc,1098 +jeepney/tests/test_bus.py,sha256=ApOxd3AcYQB14G1XsiFGBYtQ4xSKw52y9YvmPz700gc,847 +jeepney/tests/test_bus_messages.py,sha256=elwS7odY9RDsjg9jL4tN0O7uCxUqSYHsWShWXn_WPOQ,3338 +jeepney/tests/test_fds.py,sha256=-gyvQpfsXtPaIEeqbwhrNPOcIAN0DsrQ7MXZu4nMvvQ,1821 +jeepney/tests/test_low_level.py,sha256=2SC-wKKGr0yfEguswfHzCojSTwsYlTVLPyuzQbGS3L4,3000 +jeepney/tests/test_wrappers.py,sha256=NSY6LblWeU2kToISjpi9YHgrd_Y6PVyFwXqnbY93ygU,2202 +jeepney/wrappers.py,sha256=5zM_v1jFqEGDSaPh0f06SDxCF6JmWVhyXjfYR6KHum4,9605 diff --git a/lib/jeepney-0.9.0.dist-info/WHEEL b/lib/jeepney-0.9.0.dist-info/WHEEL new file mode 100644 index 0000000..23d2d7e --- /dev/null +++ b/lib/jeepney-0.9.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: flit 3.11.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/lib/jeepney-0.9.0.dist-info/licenses/LICENSE b/lib/jeepney-0.9.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..b0ae9db --- /dev/null +++ b/lib/jeepney-0.9.0.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Thomas Kluyver + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/lib/jeepney/__init__.py b/lib/jeepney/__init__.py new file mode 100644 index 0000000..b314820 --- /dev/null +++ b/lib/jeepney/__init__.py @@ -0,0 +1,13 @@ +"""Low-level, pure Python DBus protocol wrapper. +""" +from .auth import AuthenticationError, FDNegotiationError +from .low_level import ( + Endianness, Header, HeaderFields, Message, MessageFlag, MessageType, + Parser, SizeLimitError, +) +from .bus import find_session_bus, find_system_bus +from .bus_messages import * +from .fds import FileDescriptor, NoFDError +from .wrappers import * + +__version__ = '0.9.0' diff --git a/lib/jeepney/__pycache__/__init__.cpython-314.pyc b/lib/jeepney/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..1440bcd Binary files /dev/null and b/lib/jeepney/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/auth.cpython-314.pyc b/lib/jeepney/__pycache__/auth.cpython-314.pyc new file mode 100644 index 0000000..f6ee81c Binary files /dev/null and b/lib/jeepney/__pycache__/auth.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/bindgen.cpython-314.pyc b/lib/jeepney/__pycache__/bindgen.cpython-314.pyc new file mode 100644 index 0000000..07af7a3 Binary files /dev/null and b/lib/jeepney/__pycache__/bindgen.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/bus.cpython-314.pyc b/lib/jeepney/__pycache__/bus.cpython-314.pyc new file mode 100644 index 0000000..7ae5ff8 Binary files /dev/null and b/lib/jeepney/__pycache__/bus.cpython-314.pyc differ diff 
--git a/lib/jeepney/__pycache__/bus_messages.cpython-314.pyc b/lib/jeepney/__pycache__/bus_messages.cpython-314.pyc new file mode 100644 index 0000000..5a9d27b Binary files /dev/null and b/lib/jeepney/__pycache__/bus_messages.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/fds.cpython-314.pyc b/lib/jeepney/__pycache__/fds.cpython-314.pyc new file mode 100644 index 0000000..ac65671 Binary files /dev/null and b/lib/jeepney/__pycache__/fds.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/low_level.cpython-314.pyc b/lib/jeepney/__pycache__/low_level.cpython-314.pyc new file mode 100644 index 0000000..385239c Binary files /dev/null and b/lib/jeepney/__pycache__/low_level.cpython-314.pyc differ diff --git a/lib/jeepney/__pycache__/wrappers.cpython-314.pyc b/lib/jeepney/__pycache__/wrappers.cpython-314.pyc new file mode 100644 index 0000000..8661918 Binary files /dev/null and b/lib/jeepney/__pycache__/wrappers.cpython-314.pyc differ diff --git a/lib/jeepney/auth.py b/lib/jeepney/auth.py new file mode 100644 index 0000000..2c4153f --- /dev/null +++ b/lib/jeepney/auth.py @@ -0,0 +1,144 @@ +from binascii import hexlify +from enum import Enum +import os +from typing import Optional + +def make_auth_external() -> bytes: + """Prepare an AUTH command line with the current effective user ID. + + This is the preferred authentication method for typical D-Bus connections + over a Unix domain socket. + """ + hex_uid = hexlify(str(os.geteuid()).encode('ascii')) + return b'AUTH EXTERNAL %b\r\n' % hex_uid + +def make_auth_anonymous() -> bytes: + """Format an AUTH command line for the ANONYMOUS mechanism + + Jeepney's higher-level wrappers don't currently use this mechanism, + but third-party code may choose to. + + See for details. + """ + from . 
import __version__ + trace = hexlify(('Jeepney %s' % __version__).encode('ascii')) + return b'AUTH ANONYMOUS %s\r\n' % trace + +BEGIN = b'BEGIN\r\n' +NEGOTIATE_UNIX_FD = b'NEGOTIATE_UNIX_FD\r\n' + +class ClientState(Enum): + # States from the D-Bus spec (plus 'Success'). Not all used in Jeepney. + WaitingForData = 1 + WaitingForOk = 2 + WaitingForReject = 3 + WaitingForAgreeUnixFD = 4 + Success = 5 + +class AuthenticationError(ValueError): + """Raised when DBus authentication fails""" + def __init__(self, data, msg="Authentication failed"): + self.msg = msg + self.data = data + + def __str__(self): + return f"{self.msg}. Bus sent: {self.data!r}" + +class FDNegotiationError(AuthenticationError): + """Raised when file descriptor support is requested but not available""" + def __init__(self, data): + super().__init__(data, msg="File descriptor support not available") + + +class Authenticator: + """Process data for the SASL authentication conversation + + If enable_fds is True, this includes negotiating support for passing + file descriptors. If inc_null_byte is True, sends the '\0' byte + at the beginning of the negotiations, which was the past behavior, + but which prevents sending the SCM_CREDS ancillary data over the socket, + breaking authentication on *BSD; the caller should rather send that + null byte and ancillary data and pass inc_null_byte=False to prevent + it being done here. 
+ """ + def __init__(self, enable_fds=False, inc_null_byte=True): + self.enable_fds = enable_fds + self.buffer = bytearray() + if inc_null_byte: + self._to_send = b'\0' + make_auth_external() + else: + self._to_send = make_auth_external() + self.state = ClientState.WaitingForOk + self.error = None + + @property + def authenticated(self): + return self.state is ClientState.Success + + def __iter__(self): + return iter(self.data_to_send, None) + + def data_to_send(self) -> Optional[bytes]: + """Get a line of data to send to the server + + The data returned should be sent before waiting to receive data. + Returns empty bytes if waiting for more data from the server, and None + if authentication is finished (success or error). + + Iterating over the Authenticator object will also yield these lines; + :meth:`feed` should be called with received data inside the loop. + """ + if self.authenticated or self.error: + return None + self._to_send, to_send = b'', self._to_send + return to_send + + def process_line(self, line): + if self.state is ClientState.WaitingForOk: + if line.startswith(b'OK '): + if self.enable_fds: + return NEGOTIATE_UNIX_FD, ClientState.WaitingForAgreeUnixFD + else: + return BEGIN, ClientState.Success + # We only support EXTERNAL authentication, but if we allow others, + # 'REJECTED ' would tell us to try another one. + + elif self.state is ClientState.WaitingForAgreeUnixFD: + if line.startswith(b'AGREE_UNIX_FD'): + return BEGIN, ClientState.Success + # The protocol allows us to continue if FD passing is rejected, + # but Jeepney assumes that if you enable FD support you need it, + # so we fail rather + self.error = line + raise FDNegotiationError(line) + + self.error = line + raise AuthenticationError(line) + + def feed(self, data: bytes): + """Process received data + + Raises AuthenticationError if the incoming data is not as expected for + successful authentication. The connection should then be abandoned. 
+ """ + self.buffer += data + if b'\r\n' in self.buffer: + line, self.buffer = self.buffer.split(b'\r\n', 1) + if self.buffer: + # We only expect one line before we reply + raise AuthenticationError(self.buffer, "Unexpected data received") + + self._to_send, self.state = self.process_line(line) + + # Avoid consuming lots of memory if the server is not sending what we + # expect. There doesn't appear to be a specified maximum line length, + # but 8192 bytes leaves a sizeable margin over all the examples in the + # spec (all < 100 bytes per line). + elif len(self.buffer) > 8192: + raise AuthenticationError( + self.buffer, "Too much data received without line ending" + ) + + +# Old name (behaviour on errors has changed, but should work for standard case) +SASLParser = Authenticator diff --git a/lib/jeepney/bindgen.py b/lib/jeepney/bindgen.py new file mode 100644 index 0000000..4695eb6 --- /dev/null +++ b/lib/jeepney/bindgen.py @@ -0,0 +1,170 @@ +"""Generate a wrapper class from DBus introspection data""" +import argparse +import os.path +import sys +import xml.etree.ElementTree as ET +from textwrap import indent + +from jeepney.wrappers import Introspectable +from jeepney.io.blocking import open_dbus_connection, Proxy +from jeepney import __version__ + +class Method: + def __init__(self, xml_node): + self.name = xml_node.attrib['name'] + self.in_args = [] + self.signature = [] + for arg in xml_node.findall("arg[@direction='in']"): + try: + name = arg.attrib['name'] + except KeyError: + name = 'arg{}'.format(len(self.in_args)) + self.in_args.append(name) + self.signature.append(arg.attrib['type']) + + def _make_code_noargs(self): + return ("def {name}(self):\n" + " return new_method_call(self, '{name}')\n").format( + name=self.name) + + def make_code(self): + if not self.in_args: + return self._make_code_noargs() + + args = ', '.join(self.in_args) + signature = ''.join(self.signature) + tuple = ('({},)' if len(self.in_args) == 1 else '({})').format(args) + return ("def 
{name}(self, {args}):\n" + " return new_method_call(self, '{name}', '{signature}',\n" + " {tuple})\n").format( + name=self.name, args=args, signature=signature, tuple=tuple + ) + +INTERFACE_CLASS_TEMPLATE = """ +class {cls_name}(MessageGenerator): + interface = {interface!r} + + def __init__(self, object_path{path_default}, + bus_name{name_default}): + super().__init__(object_path=object_path, bus_name=bus_name) +""" + +class Interface: + def __init__(self, xml_node, path, bus_name): + self.name = xml_node.attrib['name'] + self.path = path + self.bus_name = bus_name + self.methods = [Method(node) for node in xml_node.findall('method')] + + def make_code(self): + cls_name = self.name.split('.')[-1] + chunks = [INTERFACE_CLASS_TEMPLATE.format( + cls_name=cls_name, + interface=self.name, + path_default='' if self.path is None else f'={self.path!r}', + name_default='' if self.bus_name is None else f'={self.bus_name!r}' + )] + for method in self.methods: + chunks.append(indent(method.make_code(), ' ' * 4)) + return '\n'.join(chunks) + +MODULE_TEMPLATE = '''\ +"""Auto-generated DBus bindings + +Generated by jeepney version {version} + +Object path: {path} +Bus name : {bus_name} +""" + +from jeepney.wrappers import MessageGenerator, new_method_call + +''' + +# Jeepney already includes bindings for these common interfaces +IGNORE_INTERFACES = { + 'org.freedesktop.DBus.Introspectable', + 'org.freedesktop.DBus.Properties', + 'org.freedesktop.DBus.Peer', +} + +def code_from_xml(xml, path, bus_name, fh): + if isinstance(fh, (bytes, str)): + with open(fh, 'w') as f: + return code_from_xml(xml, path, bus_name, f) + + root = ET.fromstring(xml) + fh.write(MODULE_TEMPLATE.format(version=__version__, path=path, + bus_name=bus_name)) + + i = 0 + for interface_node in root.findall('interface'): + if interface_node.attrib['name'] in IGNORE_INTERFACES: + continue + fh.write(Interface(interface_node, path, bus_name).make_code()) + i += 1 + + return i + +def 
generate_from_introspection(path, name, output_file, bus='SESSION'): + # Many D-Bus services have a main object at a predictable name, e.g. + # org.freedesktop.Notifications -> /org/freedesktop/Notifications + if not path: + path = '/' + name.replace('.', '/') + + conn = open_dbus_connection(bus) + introspectable = Proxy(Introspectable(path, name), conn) + xml, = introspectable.Introspect() + # print(xml) + + n_interfaces = code_from_xml(xml, path, name, output_file) + print("Written {} interface wrappers to {}".format(n_interfaces, output_file)) + +def generate_from_file(input_file, path, name, output_file): + with open(input_file, encoding='utf-8') as f: + xml = f.read() + + n_interfaces = code_from_xml(xml, path, name, output_file) + print("Written {} interface wrappers to {}".format(n_interfaces, output_file)) + +def main(): + ap = argparse.ArgumentParser( + description="Generate a simple wrapper module to call D-Bus methods.", + epilog="If you don't use --file, this will connect to D-Bus and introspect the " + "given name and path. --name and --path can also be used with --file, " + "to give defaults for the generated class." + ) + ap.add_argument('-n', '--name', + help='Bus name to introspect, required unless using file') + ap.add_argument('-p', '--path', + help='Object path to introspect. If not specified, a path matching ' + 'the name will be used, e.g. /org/freedesktop/Notifications for org.freedesktop.Notifications') + ap.add_argument('--bus', default='SESSION', + help='Bus to connect to for introspection (SESSION/SYSTEM), default SESSION') + ap.add_argument('-f', '--file', + help='XML file to use instead of D-Bus introspection') + ap.add_argument('-o', '--output', + help='Output filename') + args = ap.parse_args() + + if not (args.file or args.name): + sys.exit("Either --name or --file is required") + + # If no --output, guess a (hopefully) reasonable name. 
+ if args.output: + output = args.output + elif args.file: + output = os.path.splitext(os.path.basename(args.file))[0] + '.py' + elif args.path and len(args.path) > 1: + output = args.path[1:].replace('/', '_') + '.py' + else: # e.g. path is '/' + output = args.name.replace('.', '_') + '.py' + + if args.file: + generate_from_file(args.file, args.path, args.name, output) + else: + generate_from_introspection(args.path, args.name, output, args.bus) + + +if __name__ == '__main__': + main() diff --git a/lib/jeepney/bus.py b/lib/jeepney/bus.py new file mode 100644 index 0000000..dfc10ee --- /dev/null +++ b/lib/jeepney/bus.py @@ -0,0 +1,62 @@ +import os +import re + +_escape_pat = re.compile(r'%([0-9A-Fa-f]{2})') +def unescape(v): + def repl(match): + n = int(match.group(1), base=16) + return chr(n) + return _escape_pat.sub(repl, v) + +def parse_addresses(s): + for addr in s.split(';'): + transport, info = addr.split(':', 1) + kv = {} + for x in info.split(','): + k, v = x.split('=', 1) + kv[k] = unescape(v) + yield (transport, kv) + +SUPPORTED_TRANSPORTS = ('unix',) + +def get_connectable_addresses(addr): + unsupported_transports = set() + found = False + for transport, kv in parse_addresses(addr): + if transport not in SUPPORTED_TRANSPORTS: + unsupported_transports.add(transport) + + elif transport == 'unix': + if 'abstract' in kv: + yield '\0' + kv['abstract'] + found = True + elif 'path' in kv: + yield kv['path'] + found = True + + if not found: + raise RuntimeError("DBus transports ({}) not supported. 
Supported: {}" + .format(unsupported_transports, SUPPORTED_TRANSPORTS)) + +def find_session_bus(): + addr = os.environ['DBUS_SESSION_BUS_ADDRESS'] + return next(get_connectable_addresses(addr)) + # TODO: fallbacks to X, filesystem + +def find_system_bus(): + addr = os.environ.get('DBUS_SYSTEM_BUS_ADDRESS', '') \ + or 'unix:path=/var/run/dbus/system_bus_socket' + return next(get_connectable_addresses(addr)) + +def get_bus(addr): + if addr == 'SESSION': + return find_session_bus() + elif addr == 'SYSTEM': + return find_system_bus() + else: + return next(get_connectable_addresses(addr)) + + +if __name__ == '__main__': + print('System bus at:', find_system_bus()) + print('Session bus at:', find_session_bus()) diff --git a/lib/jeepney/bus_messages.py b/lib/jeepney/bus_messages.py new file mode 100644 index 0000000..67fdf7a --- /dev/null +++ b/lib/jeepney/bus_messages.py @@ -0,0 +1,238 @@ +"""Messages for talking to the DBus daemon itself + +Generated by jeepney.bindgen and modified by hand. +""" +from .low_level import Message, MessageType, HeaderFields +from .wrappers import MessageGenerator, new_method_call + +__all__ = [ + 'DBusNameFlags', + 'DBus', + 'message_bus', + 'Monitoring', + 'Stats', + 'MatchRule', +] + +class DBusNameFlags: + allow_replacement = 1 + replace_existing = 2 + do_not_queue = 4 + +class DBus(MessageGenerator): + """Messages to talk to the message bus + """ + interface = 'org.freedesktop.DBus' + + def __init__(self, object_path='/org/freedesktop/DBus', + bus_name='org.freedesktop.DBus'): + super().__init__(object_path=object_path, bus_name=bus_name) + + def Hello(self): + return new_method_call(self, 'Hello') + + def RequestName(self, name, flags=0): + return new_method_call(self, 'RequestName', 'su', (name, flags)) + + def ReleaseName(self, name): + return new_method_call(self, 'ReleaseName', 's', (name,)) + + def StartServiceByName(self, name): + return new_method_call(self, 'StartServiceByName', 'su', + (name, 0)) + + def 
UpdateActivationEnvironment(self, env): + return new_method_call(self, 'UpdateActivationEnvironment', 'a{ss}', + (env,)) + + def NameHasOwner(self, name): + return new_method_call(self, 'NameHasOwner', 's', (name,)) + + def ListNames(self): + return new_method_call(self, 'ListNames') + + def ListActivatableNames(self): + return new_method_call(self, 'ListActivatableNames') + + def AddMatch(self, rule): + """*rule* can be a str or a :class:`MatchRule` instance""" + if isinstance(rule, MatchRule): + rule = rule.serialise() + return new_method_call(self, 'AddMatch', 's', (rule,)) + + def RemoveMatch(self, rule): + if isinstance(rule, MatchRule): + rule = rule.serialise() + return new_method_call(self, 'RemoveMatch', 's', (rule,)) + + def GetNameOwner(self, name): + return new_method_call(self, 'GetNameOwner', 's', (name,)) + + def ListQueuedOwners(self, name): + return new_method_call(self, 'ListQueuedOwners', 's', (name,)) + + def GetConnectionUnixUser(self, name): + return new_method_call(self, 'GetConnectionUnixUser', 's', (name,)) + + def GetConnectionUnixProcessID(self, name): + return new_method_call(self, 'GetConnectionUnixProcessID', 's', (name,)) + + def GetAdtAuditSessionData(self, name): + return new_method_call(self, 'GetAdtAuditSessionData', 's', (name,)) + + def GetConnectionSELinuxSecurityContext(self, name): + return new_method_call(self, 'GetConnectionSELinuxSecurityContext', 's', + (name,)) + + def ReloadConfig(self): + return new_method_call(self, 'ReloadConfig') + + def GetId(self): + return new_method_call(self, 'GetId') + + def GetConnectionCredentials(self, name): + return new_method_call(self, 'GetConnectionCredentials', 's', (name,)) + +message_bus = DBus() + +class Monitoring(MessageGenerator): + interface = 'org.freedesktop.DBus.Monitoring' + + def __init__(self, object_path='/org/freedesktop/DBus', + bus_name='org.freedesktop.DBus'): + super().__init__(object_path=object_path, bus_name=bus_name) + + def BecomeMonitor(self, rules): + 
"""Convert this connection to a monitor connection (advanced)""" + return new_method_call(self, 'BecomeMonitor', 'asu', (rules, 0)) + +class Stats(MessageGenerator): + interface = 'org.freedesktop.DBus.Debug.Stats' + + def __init__(self, object_path='/org/freedesktop/DBus', + bus_name='org.freedesktop.DBus'): + super().__init__(object_path=object_path, bus_name=bus_name) + + def GetStats(self): + return new_method_call(self, 'GetStats') + + def GetConnectionStats(self, arg0): + return new_method_call(self, 'GetConnectionStats', 's', + (arg0,)) + + def GetAllMatchRules(self): + return new_method_call(self, 'GetAllMatchRules') + + +class MatchRule: + """Construct a match rule to subscribe to DBus messages. + + e.g.:: + + mr = MatchRule( + interface='org.freedesktop.DBus', + member='NameOwnerChanged', + type='signal' + ) + msg = message_bus.AddMatch(mr) + # Send this message to subscribe to the signal + """ + def __init__(self, *, type=None, sender=None, interface=None, member=None, + path=None, path_namespace=None, destination=None, + eavesdrop=False): + if isinstance(type, str): + type = MessageType[type] + self.message_type = type + fields = { + 'sender': sender, + 'interface': interface, + 'member': member, + 'path': path, + 'destination': destination, + } + self.header_fields = { + k: v for (k, v) in fields.items() if (v is not None) + } + self.path_namespace = path_namespace + self.eavesdrop = eavesdrop + self.arg_conditions = {} + + def add_arg_condition(self, argno: int, value: str, kind='string'): + """Add a condition for a particular argument + + argno: int, 0-63 + kind: 'string', 'path', 'namespace' + """ + if kind not in {'string', 'path', 'namespace'}: + raise ValueError("kind={!r}".format(kind)) + if kind == 'namespace' and argno != 0: + raise ValueError("argno must be 0 for kind='namespace'") + self.arg_conditions[argno] = (value, kind) + + def serialise(self) -> str: + """Convert to a string to use in an AddMatch call to the message bus""" + pairs = 
list(self.header_fields.items()) + + if self.message_type: + pairs.append(('type', self.message_type.name)) + + if self.path_namespace: + pairs.append(('path_namespace', self.path_namespace)) + + if self.eavesdrop: + pairs.append(('eavesdrop', 'true')) + + for argno, (val, kind) in self.arg_conditions.items(): + if kind == 'string': + kind = '' + pairs.append((f'arg{argno}{kind}', val)) + + # Quoting rules: single quotes ('') needed if the value contains a comma. + # A literal ' can only be represented outside single quotes, by + # backslash-escaping it. No escaping inside the quotes. + # The simplest way to handle this is to use '' around every value, and + # use '\'' (end quote, escaped ', restart quote) for literal ' . + return ','.join( + "{}='{}'".format(k, v.replace("'", r"'\''")) for (k, v) in pairs + ) + + def matches(self, msg: Message) -> bool: + """Returns True if msg matches this rule""" + h = msg.header + if (self.message_type is not None) and h.message_type != self.message_type: + return False + + for field, expected in self.header_fields.items(): + if h.fields.get(HeaderFields[field], None) != expected: + return False + + if self.path_namespace is not None: + path = h.fields.get(HeaderFields.path, '\0') + path_ns = self.path_namespace.rstrip('/') + if not ((path == path_ns) or path.startswith(path_ns + '/')): + return False + + for argno, (expected, kind) in self.arg_conditions.items(): + if argno >= len(msg.body): + return False + arg = msg.body[argno] + if not isinstance(arg, str): + return False + if kind == 'string': + if arg != expected: + return False + elif kind == 'path': + if not ( + (arg == expected) + or (expected.endswith('/') and arg.startswith(expected)) + or (arg.endswith('/') and expected.startswith(arg)) + ): + return False + elif kind == 'namespace': + if not ( + (arg == expected) + or arg.startswith(expected + '.') + ): + return False + + return True diff --git a/lib/jeepney/fds.py b/lib/jeepney/fds.py new file mode 100644 index 
0000000..233c3aa --- /dev/null +++ b/lib/jeepney/fds.py @@ -0,0 +1,158 @@ +import array +import os +import socket +from warnings import warn + + +class NoFDError(RuntimeError): + """Raised by :class:`FileDescriptor` methods if it was already closed/converted + """ + pass + + +class FileDescriptor: + """A file descriptor received in a D-Bus message + + This wrapper helps ensure that the file descriptor is closed exactly once. + If you don't explicitly convert or close the FileDescriptor object, it will + close its file descriptor when it goes out of scope, and emit a + ResourceWarning. + """ + __slots__ = ('_fd',) + _CLOSED = -1 + _CONVERTED = -2 + + def __init__(self, fd): + self._fd = fd + + def __repr__(self): + detail = self._fd + if self._fd == self._CLOSED: + detail = 'closed' + elif self._fd == self._CONVERTED: + detail = 'converted' + return f"" + + def close(self): + """Close the file descriptor + + This can safely be called multiple times, but will raise RuntimeError + if called after converting it with one of the ``to_*`` methods. + + This object can also be used in a ``with`` block, to close it on + leaving the block. + """ + if self._fd == self._CLOSED: + pass + elif self._fd == self._CONVERTED: + raise NoFDError("Can't close FileDescriptor after converting it") + else: + self._fd, fd = self._CLOSED, self._fd + os.close(fd) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __del__(self): + if self._fd >= 0: + warn( + f'FileDescriptor ({self._fd}) was neither closed nor converted', + ResourceWarning, stacklevel=2, source=self + ) + self.close() + + def _check(self): + if self._fd < 0: + detail = 'closed' if self._fd == self._CLOSED else 'converted' + raise NoFDError(f'FileDescriptor object was already {detail}') + + def fileno(self): + """Get the integer file descriptor + + This does not change the state of the :class:`FileDescriptor` object, + unlike the ``to_*`` methods. 
+ """ + self._check() + return self._fd + + def to_raw_fd(self): + """Convert to the low-level integer file descriptor:: + + raw_fd = fd.to_raw_fd() + os.write(raw_fd, b'xyz') + os.close(raw_fd) + + The :class:`FileDescriptor` can't be used after calling this. The caller + is responsible for closing the file descriptor. + """ + self._check() + self._fd, fd = self._CONVERTED, self._fd + return fd + + def to_file(self, mode, buffering=-1, encoding=None, errors=None, newline=None): + """Convert to a Python file object:: + + with fd.to_file('w') as f: + f.write('xyz') + + The arguments are the same as for the builtin :func:`open` function. + + The :class:`FileDescriptor` can't be used after calling this. Closing + the file object will also close the file descriptor. + """ + self._check() + f = open( + self._fd, mode, buffering=buffering, + encoding=encoding, errors=errors, newline=newline + ) + self._fd = self._CONVERTED + return f + + def to_socket(self): + """Convert to a socket object + + This returns a standard library :func:`socket.socket` object:: + + with fd.to_socket() as sock: + b = sock.sendall(b'xyz') + + The wrapper object can't be used after calling this. Closing the socket + object will also close the file descriptor. + """ + from socket import socket + + self._check() + s = socket(fileno=self._fd) + self._fd = self._CONVERTED + return s + + @classmethod + def from_ancdata(cls, ancdata) -> ['FileDescriptor']: + """Make a list of FileDescriptor from received file descriptors + + ancdata is a list of ancillary data tuples as returned by socket.recvmsg() + """ + fds = array.array("i") # Array of ints + for cmsg_level, cmsg_type, data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + # Append data, ignoring any truncated integers at the end. 
+ fds.frombytes(data[:len(data) - (len(data) % fds.itemsize)]) + return [cls(i) for i in fds] + + +_fds_buf_size_cache = None + +def fds_buf_size(): + # If there may be file descriptors, we try to read 1 message at a time. + # The reference implementation of D-Bus defaults to allowing 16 FDs per + # message, and the Linux kernel currently allows 253 FDs per sendmsg() + # call. So hopefully allowing 256 FDs per recvmsg() will always suffice. + global _fds_buf_size_cache + if _fds_buf_size_cache is None: + maxfds = 256 + fd_size = array.array('i').itemsize + _fds_buf_size_cache = socket.CMSG_SPACE(maxfds * fd_size) + return _fds_buf_size_cache diff --git a/lib/jeepney/io/__init__.py b/lib/jeepney/io/__init__.py new file mode 100644 index 0000000..d346b6c --- /dev/null +++ b/lib/jeepney/io/__init__.py @@ -0,0 +1 @@ +from .common import RouterClosed diff --git a/lib/jeepney/io/__pycache__/__init__.cpython-314.pyc b/lib/jeepney/io/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..84b2312 Binary files /dev/null and b/lib/jeepney/io/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jeepney/io/__pycache__/asyncio.cpython-314.pyc b/lib/jeepney/io/__pycache__/asyncio.cpython-314.pyc new file mode 100644 index 0000000..a5c7a09 Binary files /dev/null and b/lib/jeepney/io/__pycache__/asyncio.cpython-314.pyc differ diff --git a/lib/jeepney/io/__pycache__/blocking.cpython-314.pyc b/lib/jeepney/io/__pycache__/blocking.cpython-314.pyc new file mode 100644 index 0000000..2b3de04 Binary files /dev/null and b/lib/jeepney/io/__pycache__/blocking.cpython-314.pyc differ diff --git a/lib/jeepney/io/__pycache__/common.cpython-314.pyc b/lib/jeepney/io/__pycache__/common.cpython-314.pyc new file mode 100644 index 0000000..be2e623 Binary files /dev/null and b/lib/jeepney/io/__pycache__/common.cpython-314.pyc differ diff --git a/lib/jeepney/io/__pycache__/threading.cpython-314.pyc b/lib/jeepney/io/__pycache__/threading.cpython-314.pyc new file mode 100644 
index 0000000..7b70a13 Binary files /dev/null and b/lib/jeepney/io/__pycache__/threading.cpython-314.pyc differ diff --git a/lib/jeepney/io/__pycache__/trio.cpython-314.pyc b/lib/jeepney/io/__pycache__/trio.cpython-314.pyc new file mode 100644 index 0000000..d1a2538 Binary files /dev/null and b/lib/jeepney/io/__pycache__/trio.cpython-314.pyc differ diff --git a/lib/jeepney/io/asyncio.py b/lib/jeepney/io/asyncio.py new file mode 100644 index 0000000..2c6ade6 --- /dev/null +++ b/lib/jeepney/io/asyncio.py @@ -0,0 +1,233 @@ +import asyncio +import contextlib +from itertools import count +from typing import Optional + +from jeepney.auth import Authenticator, BEGIN +from jeepney.bus import get_bus +from jeepney import Message, MessageType, Parser +from jeepney.wrappers import ProxyBase, unwrap_msg +from jeepney.bus_messages import message_bus +from .common import ( + MessageFilters, FilterHandle, ReplyMatcher, RouterClosed, check_replyable, +) + + +class DBusConnection: + """A plain D-Bus connection with no matching of replies. + + This doesn't run any separate tasks: sending and receiving are done in + the task that calls those methods. It's suitable for implementing servers: + several worker tasks can receive requests and send replies. + For a typical client pattern, see :class:`DBusRouter`. 
+ """ + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self.reader = reader + self.writer = writer + self.parser = Parser() + self.outgoing_serial = count(start=1) + self.unique_name = None + self.send_lock = asyncio.Lock() + + async def send(self, message: Message, *, serial=None): + """Serialise and send a :class:`~.Message` object""" + async with self.send_lock: + if serial is None: + serial = next(self.outgoing_serial) + self.writer.write(message.serialise(serial)) + await self.writer.drain() + + async def receive(self) -> Message: + """Return the next available message from the connection""" + while True: + msg = self.parser.get_next_message() + if msg is not None: + return msg + + b = await self.reader.read(4096) + if not b: + raise EOFError + self.parser.add_data(b) + + async def close(self): + """Close the D-Bus connection""" + self.writer.close() + await self.writer.wait_closed() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + +async def open_dbus_connection(bus='SESSION'): + """Open a plain D-Bus connection + + :return: :class:`DBusConnection` + """ + bus_addr = get_bus(bus) + reader, writer = await asyncio.open_unix_connection(bus_addr) + + # Authentication flow + authr = Authenticator() + for req_data in authr: + writer.write(req_data) + await writer.drain() + b = await reader.read(1024) + if not b: + raise EOFError("Socket closed before authentication") + authr.feed(b) + + writer.write(BEGIN) + await writer.drain() + # Authentication finished + + conn = DBusConnection(reader, writer) + + # Say *Hello* to the message bus - this must be the first message, and the + # reply gives us our unique name. 
+ async with DBusRouter(conn) as router: + reply_body = await asyncio.wait_for(Proxy(message_bus, router).Hello(), 10) + conn.unique_name = reply_body[0] + + return conn + +class DBusRouter: + """A 'client' D-Bus connection which can wait for a specific reply. + + This runs a background receiver task, and makes it possible to send a + request and wait for the relevant reply. + """ + _nursery_mgr = None + _send_cancel_scope = None + _rcv_cancel_scope = None + + def __init__(self, conn: DBusConnection): + self._conn = conn + self._replies = ReplyMatcher() + self._filters = MessageFilters() + self._rcv_task = asyncio.create_task(self._receiver()) + + @property + def unique_name(self): + return self._conn.unique_name + + async def send(self, message, *, serial=None): + """Send a message, don't wait for a reply""" + await self._conn.send(message, serial=serial) + + async def send_and_get_reply(self, message) -> Message: + """Send a method call message and wait for the reply + + Returns the reply message (method return or error message type). 
+ """ + check_replyable(message) + if self._rcv_task.done(): + raise RouterClosed("This DBusRouter has stopped") + + serial = next(self._conn.outgoing_serial) + + with self._replies.catch(serial, asyncio.Future()) as reply_fut: + await self.send(message, serial=serial) + return (await reply_fut) + + def filter(self, rule, *, queue: Optional[asyncio.Queue] =None, bufsize=1): + """Create a filter for incoming messages + + Usage:: + + with router.filter(rule) as queue: + matching_msg = await queue.get() + + :param MatchRule rule: Catch messages matching this rule + :param asyncio.Queue queue: Send matching messages here + :param int bufsize: If no queue is passed in, create one with this size + """ + return FilterHandle(self._filters, rule, queue or asyncio.Queue(bufsize)) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._rcv_task.done(): + self._rcv_task.result() # Throw exception if receive task failed + else: + self._rcv_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._rcv_task + return False + + # Code to run in receiver task ------------------------------------ + + def _dispatch(self, msg: Message): + """Handle one received message""" + if self._replies.dispatch(msg): + return + + for filter in list(self._filters.matches(msg)): + try: + filter.queue.put_nowait(msg) + except asyncio.QueueFull: + pass + + async def _receiver(self): + """Receiver loop - runs in a separate task""" + try: + while True: + msg = await self._conn.receive() + self._dispatch(msg) + finally: + # Send errors to any tasks still waiting for a message. + self._replies.drop_all() + +class open_dbus_router: + """Open a D-Bus 'router' to send and receive messages + + Use as an async context manager:: + + async with open_dbus_router() as router: + ... 
+ """ + conn = None + req_ctx = None + + def __init__(self, bus='SESSION'): + self.bus = bus + + async def __aenter__(self): + self.conn = await open_dbus_connection(self.bus) + self.req_ctx = DBusRouter(self.conn) + return await self.req_ctx.__aenter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.req_ctx.__aexit__(exc_type, exc_val, exc_tb) + await self.conn.close() + + +class Proxy(ProxyBase): + """An asyncio proxy for calling D-Bus methods + + You can call methods on the proxy object, such as ``await bus_proxy.Hello()`` + to make a method call over D-Bus and wait for a reply. It will either + return a tuple of returned data, or raise :exc:`.DBusErrorResponse`. + The methods available are defined by the message generator you wrap. + + :param msggen: A message generator object. + :param ~asyncio.DBusRouter router: Router to send and receive messages. + """ + def __init__(self, msggen, router): + super().__init__(msggen) + self._router = router + + def __repr__(self): + return 'Proxy({}, {})'.format(self._msggen, self._router) + + def _method_call(self, make_msg): + async def inner(*args, **kwargs): + msg = make_msg(*args, **kwargs) + assert msg.header.message_type is MessageType.method_call + reply = await self._router.send_and_get_reply(msg) + return unwrap_msg(reply) + + return inner diff --git a/lib/jeepney/io/blocking.py b/lib/jeepney/io/blocking.py new file mode 100644 index 0000000..d2d9b54 --- /dev/null +++ b/lib/jeepney/io/blocking.py @@ -0,0 +1,337 @@ +"""Synchronous IO wrappers around jeepney +""" +import array +from collections import deque +from errno import ECONNRESET +import functools +from itertools import count +import os +from selectors import DefaultSelector, EVENT_READ +import socket +import time +from typing import Optional + +from jeepney import Parser, Message, MessageType, HeaderFields +from jeepney.auth import Authenticator, BEGIN +from jeepney.bus import get_bus +from jeepney.fds import FileDescriptor, 
fds_buf_size +from jeepney.wrappers import ProxyBase, unwrap_msg +from jeepney.bus_messages import message_bus +from .common import MessageFilters, FilterHandle, check_replyable + +__all__ = [ + 'open_dbus_connection', + 'DBusConnection', + 'Proxy', +] + + +class _Future: + def __init__(self): + self._result = None + + def done(self): + return bool(self._result) + + def set_exception(self, exception): + self._result = (False, exception) + + def set_result(self, result): + self._result = (True, result) + + def result(self): + success, value = self._result + if success: + return value + raise value + + +def timeout_to_deadline(timeout): + if timeout is not None: + return time.monotonic() + timeout + return None + +def deadline_to_timeout(deadline): + if deadline is not None: + return max(deadline - time.monotonic(), 0.) + return None + + +class DBusConnectionBase: + """Connection machinery shared by this module and threading""" + def __init__(self, sock: socket.socket, enable_fds=False): + self.sock = sock + self.enable_fds = enable_fds + self.parser = Parser() + self.outgoing_serial = count(start=1) + self.selector = DefaultSelector() + self.select_key = self.selector.register(sock, EVENT_READ) + self.unique_name = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def _serialise(self, message: Message, serial) -> (bytes, Optional[array.array]): + if serial is None: + serial = next(self.outgoing_serial) + fds = array.array('i') if self.enable_fds else None + data = message.serialise(serial=serial, fds=fds) + return data, fds + + def _send_with_fds(self, data, fds): + bytes_sent = self.sock.sendmsg( + [data], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)] + ) + # If sendmsg succeeds, I think ancillary data has been sent atomically? + # So now we just need to send any leftover normal data. 
+ if bytes_sent < len(data): + self.sock.sendall(data[bytes_sent:]) + + def _receive(self, deadline): + while True: + msg = self.parser.get_next_message() + if msg is not None: + return msg + + b, fds = self._read_some_data(timeout=deadline_to_timeout(deadline)) + self.parser.add_data(b, fds=fds) + + def _read_some_data(self, timeout=None): + for key, ev in self.selector.select(timeout): + if key == self.select_key: + if self.enable_fds: + return self._read_with_fds() + else: + return unwrap_read(self.sock.recv(4096)), [] + + raise TimeoutError + + def _read_with_fds(self): + nbytes = self.parser.bytes_desired() + data, ancdata, flags, _ = self.sock.recvmsg(nbytes, fds_buf_size()) + if flags & getattr(socket, 'MSG_CTRUNC', 0): + self.close() + raise RuntimeError("Unable to receive all file descriptors") + return unwrap_read(data), FileDescriptor.from_ancdata(ancdata) + + def close(self): + """Close the connection""" + self.selector.close() + self.sock.close() + + +class DBusConnection(DBusConnectionBase): + def __init__(self, sock: socket.socket, enable_fds=False): + super().__init__(sock, enable_fds) + + # Message routing machinery + self._filters = MessageFilters() + + # Say Hello, get our unique name + self.bus_proxy = Proxy(message_bus, self) + hello_reply = self.bus_proxy.Hello() + self.unique_name = hello_reply[0] + + def send(self, message: Message, serial=None): + """Serialise and send a :class:`~.Message` object""" + data, fds = self._serialise(message, serial) + if fds: + self._send_with_fds(data, fds) + else: + self.sock.sendall(data) + + send_message = send # Backwards compatibility + + def receive(self, *, timeout=None) -> Message: + """Return the next available message from the connection + + If the data is ready, this will return immediately, even if timeout<=0. + Otherwise, it will wait for up to timeout seconds, or indefinitely if + timeout is None. If no message comes in time, it raises TimeoutError. 
+ """ + return self._receive(timeout_to_deadline(timeout)) + + def recv_messages(self, *, timeout=None): + """Receive one message and apply filters + + See :meth:`filter`. Returns nothing. + """ + msg = self.receive(timeout=timeout) + for filter in self._filters.matches(msg): + filter.queue.append(msg) + + def send_and_get_reply(self, message, *, timeout=None): + """Send a message, wait for the reply and return it + + Filters are applied to other messages received before the reply - + see :meth:`add_filter`. + """ + check_replyable(message) + deadline = timeout_to_deadline(timeout) + + serial = next(self.outgoing_serial) + self.send_message(message, serial=serial) + while True: + msg_in = self.receive(timeout=deadline_to_timeout(deadline)) + reply_to = msg_in.header.fields.get(HeaderFields.reply_serial, -1) + if reply_to == serial: + return msg_in + + # Not the reply + for filter in self._filters.matches(msg_in): + filter.queue.append(msg_in) + + def filter(self, rule, *, queue: Optional[deque] =None, bufsize=1): + """Create a filter for incoming messages + + Usage:: + + with conn.filter(rule) as matches: + # matches is a deque containing matched messages + matching_msg = conn.recv_until_filtered(matches) + + :param jeepney.MatchRule rule: Catch messages matching this rule + :param collections.deque queue: Matched messages will be added to this + :param int bufsize: If no deque is passed in, create one with this size + """ + if queue is None: + queue = deque(maxlen=bufsize) + return FilterHandle(self._filters, rule, queue) + + def recv_until_filtered(self, queue, *, timeout=None) -> Message: + """Process incoming messages until one is filtered into queue + + Pops the message from queue and returns it, or raises TimeoutError if + the optional timeout expires. 
Without a timeout, this is equivalent to:: + + while len(queue) == 0: + conn.recv_messages() + return queue.popleft() + + In the other I/O modules, there is no need for this, because messages + are placed in queues by a separate task. + + :param collections.deque queue: A deque connected by :meth:`filter` + :param float timeout: Maximum time to wait in seconds + """ + deadline = timeout_to_deadline(timeout) + while len(queue) == 0: + self.recv_messages(timeout=deadline_to_timeout(deadline)) + return queue.popleft() + + +class Proxy(ProxyBase): + """A blocking proxy for calling D-Bus methods + + You can call methods on the proxy object, such as ``bus_proxy.Hello()`` + to make a method call over D-Bus and wait for a reply. It will either + return a tuple of returned data, or raise :exc:`.DBusErrorResponse`. + The methods available are defined by the message generator you wrap. + + You can set a time limit on a call by passing ``_timeout=`` in the method + call, or set a default when creating the proxy. The ``_timeout`` argument + is not passed to the message generator. + All timeouts are in seconds, and :exc:`TimeoutError` is raised if it + expires before a reply arrives.
+ + :param msggen: A message generator object + :param ~blocking.DBusConnection connection: Connection to send and receive messages + :param float timeout: Default seconds to wait for a reply, or None for no limit + """ + def __init__(self, msggen, connection, *, timeout=None): + super().__init__(msggen) + self._connection = connection + self._timeout = timeout + + def __repr__(self): + extra = '' if (self._timeout is None) else f', timeout={self._timeout}' + return f"Proxy({self._msggen}, {self._connection}{extra})" + + def _method_call(self, make_msg): + @functools.wraps(make_msg) + def inner(*args, **kwargs): + timeout = kwargs.pop('_timeout', self._timeout) + msg = make_msg(*args, **kwargs) + assert msg.header.message_type is MessageType.method_call + return unwrap_msg(self._connection.send_and_get_reply( + msg, timeout=timeout + )) + + return inner + + +def unwrap_read(b): + """Raise ConnectionResetError from an empty read. + + Sometimes the socket raises an error itself, sometimes it gives no data. + I haven't worked out when it behaves each way. + """ + if not b: + raise ConnectionResetError(ECONNRESET, os.strerror(ECONNRESET)) + return b + + +def prep_socket(addr, enable_fds=False, timeout=2.0) -> socket.socket: + """Create a socket and authenticate ready to send D-Bus messages""" + sock = socket.socket(family=socket.AF_UNIX) + + # To impose the overall auth timeout, we'll update the timeout on the socket + # before each send/receive. This is ugly, but we can't use the socket for + # anything else until this has succeeded, so this should be safe. 
+ deadline = timeout_to_deadline(timeout) + def with_sock_deadline(meth, *args): + sock.settimeout(deadline_to_timeout(deadline)) + return meth(*args) + + try: + with_sock_deadline(sock.connect, addr) + authr = Authenticator(enable_fds=enable_fds, inc_null_byte=False) + if hasattr(socket, 'SCM_CREDS'): + # BSD: send credentials message to authenticate (kernel fills in data) + sock.sendmsg([b'\0'], [(socket.SOL_SOCKET, socket.SCM_CREDS, bytes(512))]) + else: + # Linux: no ancillary data needed, bus checks with SO_PEERCRED + sock.send(b'\0') + for req_data in authr: + with_sock_deadline(sock.sendall, req_data) + authr.feed(unwrap_read(with_sock_deadline(sock.recv, 1024))) + with_sock_deadline(sock.sendall, BEGIN) + except socket.timeout as e: + sock.close() + raise TimeoutError(f"Did not authenticate in {timeout} seconds") from e + except: + sock.close() + raise + + sock.settimeout(None) # Put the socket back in blocking mode + return sock + + +def open_dbus_connection( + bus='SESSION', enable_fds=False, auth_timeout=1., +) -> DBusConnection: + """Connect to a D-Bus message bus + + Pass ``enable_fds=True`` to allow sending & receiving file descriptors. + An error will be raised if the bus does not allow this. For simplicity, + it's advisable to leave this disabled unless you need it. + + D-Bus has an authentication step before sending or receiving messages. + This takes < 1 ms in normal operation, but there is a timeout so that client + code won't get stuck if the server doesn't reply. *auth_timeout* configures + this timeout in seconds. 
+ """ + bus_addr = get_bus(bus) + sock = prep_socket(bus_addr, enable_fds, timeout=auth_timeout) + + conn = DBusConnection(sock, enable_fds) + return conn + + +if __name__ == '__main__': + conn = open_dbus_connection() + print("Unique name:", conn.unique_name) diff --git a/lib/jeepney/io/common.py b/lib/jeepney/io/common.py new file mode 100644 index 0000000..f74d460 --- /dev/null +++ b/lib/jeepney/io/common.py @@ -0,0 +1,88 @@ +from contextlib import contextmanager +from itertools import count + +from jeepney import HeaderFields, Message, MessageFlag, MessageType + +class MessageFilters: + def __init__(self): + self.filters = {} + self.filter_ids = count() + + def matches(self, message): + for handle in self.filters.values(): + if handle.rule.matches(message): + yield handle + + +class FilterHandle: + def __init__(self, filters: MessageFilters, rule, queue): + self._filters = filters + self._filter_id = next(filters.filter_ids) + self.rule = rule + self.queue = queue + + self._filters.filters[self._filter_id] = self + + def close(self): + del self._filters.filters[self._filter_id] + + def __enter__(self): + return self.queue + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + +class ReplyMatcher: + def __init__(self): + self._futures = {} + + @contextmanager + def catch(self, serial, future): + """Context manager to capture a reply for the given serial number""" + self._futures[serial] = future + + try: + yield future + finally: + del self._futures[serial] + + def dispatch(self, msg): + """Dispatch an incoming message which may be a reply + + Returns True if a task was waiting for it, otherwise False. 
+ """ + rep_serial = msg.header.fields.get(HeaderFields.reply_serial, -1) + if rep_serial in self._futures: + self._futures[rep_serial].set_result(msg) + return True + else: + return False + + def drop_all(self, exc: Exception = None): + """Throw an error in any task still waiting for a reply""" + if exc is None: + exc = RouterClosed("D-Bus router closed before reply arrived") + futures, self._futures = self._futures, {} + for fut in futures.values(): + fut.set_exception(exc) + + +class RouterClosed(Exception): + """Raised in tasks waiting for a reply when the router is closed + + This will also be raised if the receiver task crashes, so tasks are not + stuck waiting for a reply that can never come. The router object will not + be usable after this is raised. + """ + pass + + +def check_replyable(msg: Message): + """Raise an error if we wouldn't expect a reply for msg""" + if msg.header.message_type != MessageType.method_call: + raise TypeError("Only method call messages have replies " + f"(not {msg.header.message_type})") + if MessageFlag.no_reply_expected & msg.header.flags: + raise ValueError("This message has the no_reply_expected flag set") diff --git a/lib/jeepney/io/tests/__init__.py b/lib/jeepney/io/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/jeepney/io/tests/__pycache__/__init__.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..0b0c690 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/conftest.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/conftest.cpython-314.pyc new file mode 100644 index 0000000..e9c3d80 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/conftest.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/test_asyncio.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/test_asyncio.cpython-314.pyc new file mode 100644 index 
0000000..88062b7 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/test_asyncio.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/test_blocking.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/test_blocking.cpython-314.pyc new file mode 100644 index 0000000..0495445 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/test_blocking.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/test_threading.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/test_threading.cpython-314.pyc new file mode 100644 index 0000000..2ba2ab9 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/test_threading.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/test_trio.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/test_trio.cpython-314.pyc new file mode 100644 index 0000000..6d492af Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/test_trio.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/__pycache__/utils.cpython-314.pyc b/lib/jeepney/io/tests/__pycache__/utils.cpython-314.pyc new file mode 100644 index 0000000..493eec8 Binary files /dev/null and b/lib/jeepney/io/tests/__pycache__/utils.cpython-314.pyc differ diff --git a/lib/jeepney/io/tests/conftest.py b/lib/jeepney/io/tests/conftest.py new file mode 100644 index 0000000..1087467 --- /dev/null +++ b/lib/jeepney/io/tests/conftest.py @@ -0,0 +1,81 @@ +from tempfile import TemporaryFile +import threading + +import pytest + +from jeepney import ( + DBusAddress, HeaderFields, message_bus, MessageType, new_error, + new_method_return, +) +from jeepney.io.threading import open_dbus_connection, DBusRouter, Proxy + +@pytest.fixture() +def respond_with_fd(): + name = "io.gitlab.takluyver.jeepney.tests.respond_with_fd" + addr = DBusAddress(bus_name=name, object_path='/') + + with open_dbus_connection(bus='SESSION', enable_fds=True) as conn: + with DBusRouter(conn) as router: + status, = Proxy(message_bus, router).RequestName(name) + 
assert status == 1 # DBUS_REQUEST_NAME_REPLY_PRIMARY_OWNER + + def _reply_once(): + while True: + msg = conn.receive() + if msg.header.message_type is MessageType.method_call: + if msg.header.fields[HeaderFields.member] == 'GetFD': + with TemporaryFile('w+') as tf: + tf.write('readme') + tf.seek(0) + rep = new_method_return(msg, 'h', (tf,)) + conn.send(rep) + return + else: + conn.send(new_error(msg, 'NoMethod')) + + reply_thread = threading.Thread(target=_reply_once, daemon=True) + reply_thread.start() + yield addr + + reply_thread.join() + + +@pytest.fixture() +def read_from_fd(): + name = "io.gitlab.takluyver.jeepney.tests.read_from_fd" + addr = DBusAddress(bus_name=name, object_path='/') + + with open_dbus_connection(bus='SESSION', enable_fds=True) as conn: + with DBusRouter(conn) as router: + status, = Proxy(message_bus, router).RequestName(name) + assert status == 1 # DBUS_REQUEST_NAME_REPLY_PRIMARY_OWNER + + def _reply_once(): + while True: + msg = conn.receive() + if msg.header.message_type is MessageType.method_call: + if msg.header.fields[HeaderFields.member] == 'ReadFD': + with msg.body[0].to_file('rb') as f: + f.seek(0) + b = f.read() + conn.send(new_method_return(msg, 'ay', (b,))) + return + else: + conn.send(new_error(msg, 'NoMethod')) + + reply_thread = threading.Thread(target=_reply_once, daemon=True) + reply_thread.start() + yield addr + + reply_thread.join() + + +@pytest.fixture() +def temp_file_and_contents(): + data = b'abc123' + with TemporaryFile('w+b') as tf: + tf.write(data) + tf.flush() + tf.seek(0) + yield tf, data + diff --git a/lib/jeepney/io/tests/test_asyncio.py b/lib/jeepney/io/tests/test_asyncio.py new file mode 100644 index 0000000..c738105 --- /dev/null +++ b/lib/jeepney/io/tests/test_asyncio.py @@ -0,0 +1,95 @@ +import asyncio +import sys + +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout +import pytest +import pytest_asyncio + +from jeepney import DBusAddress, 
new_method_call +from jeepney.bus_messages import message_bus, MatchRule +from jeepney.io.asyncio import ( + open_dbus_connection, open_dbus_router, Proxy +) +from .utils import have_session_bus + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.skipif( + not have_session_bus, reason="Tests require DBus session bus" + ), +] + +bus_peer = DBusAddress( + bus_name='org.freedesktop.DBus', + object_path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus.Peer' +) + + +@pytest_asyncio.fixture() +async def connection(): + async with (await open_dbus_connection(bus='SESSION')) as conn: + yield conn + +async def test_connect(connection): + assert connection.unique_name.startswith(':') + +@pytest_asyncio.fixture() +async def router(): + async with open_dbus_router(bus='SESSION') as router: + yield router + +async def test_send_and_get_reply(router): + ping_call = new_method_call(bus_peer, 'Ping') + reply = await asyncio.wait_for( + router.send_and_get_reply(ping_call), timeout=5 + ) + assert reply.body == () + +async def test_proxy(router): + proxy = Proxy(message_bus, router) + name = "io.gitlab.takluyver.jeepney.examples.Server" + res = await proxy.RequestName(name) + assert res in {(1,), (2,)} # 1: got the name, 2: queued + + has_owner, = await proxy.NameHasOwner(name) + assert has_owner is True + +async def test_filter(router): + bus = Proxy(message_bus, router) + name = "io.gitlab.takluyver.jeepney.tests.asyncio_test_filter" + + match_rule = MatchRule( + type="signal", + sender=message_bus.bus_name, + interface=message_bus.interface, + member="NameOwnerChanged", + path=message_bus.object_path, + ) + match_rule.add_arg_condition(0, name) + + # Ask the message bus to subscribe us to this signal + await bus.AddMatch(match_rule) + + with router.filter(match_rule) as queue: + res, = await bus.RequestName(name) + assert res == 1 # 1: got the name + + signal_msg = await asyncio.wait_for(queue.get(), timeout=2.0) + assert signal_msg.body == (name, '', 
router.unique_name) + +async def test_recv_after_connect(): + # Can't use here: + # 1. 'connection' fixture + # 2. asyncio.wait_for() + # If (1) and/or (2) is used, the error won't be triggered. + conn = await open_dbus_connection(bus='SESSION') + try: + with pytest.raises(asyncio.TimeoutError): + async with timeout(0): + await conn.receive() + finally: + await conn.close() diff --git a/lib/jeepney/io/tests/test_blocking.py b/lib/jeepney/io/tests/test_blocking.py new file mode 100644 index 0000000..fedd95e --- /dev/null +++ b/lib/jeepney/io/tests/test_blocking.py @@ -0,0 +1,84 @@ +import pytest + +from jeepney import new_method_call, MessageType, DBusAddress +from jeepney.bus_messages import message_bus, MatchRule +from jeepney.io.blocking import open_dbus_connection, Proxy +from .utils import have_session_bus + +pytestmark = pytest.mark.skipif( + not have_session_bus, reason="Tests require DBus session bus" +) + +@pytest.fixture +def session_conn(): + with open_dbus_connection(bus='SESSION') as conn: + yield conn + + +def test_connect(session_conn): + assert session_conn.unique_name.startswith(':') + +bus_peer = DBusAddress( + bus_name='org.freedesktop.DBus', + object_path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus.Peer' +) + +def test_send_and_get_reply(session_conn): + ping_call = new_method_call(bus_peer, 'Ping') + reply = session_conn.send_and_get_reply(ping_call, timeout=5) + assert reply.header.message_type == MessageType.method_return + assert reply.body == () + +def test_proxy(session_conn): + proxy = Proxy(message_bus, session_conn, timeout=5) + name = "io.gitlab.takluyver.jeepney.examples.Server" + res = proxy.RequestName(name) + assert res in {(1,), (2,)} # 1: got the name, 2: queued + + has_owner, = proxy.NameHasOwner(name, _timeout=3) + assert has_owner is True + +def test_filter(session_conn): + bus = Proxy(message_bus, session_conn) + name = "io.gitlab.takluyver.jeepney.tests.blocking_test_filter" + + match_rule = MatchRule( + 
type="signal", + sender=message_bus.bus_name, + interface=message_bus.interface, + member="NameOwnerChanged", + path=message_bus.object_path, + ) + match_rule.add_arg_condition(0, name) + + # Ask the message bus to subscribe us to this signal + bus.AddMatch(match_rule) + + with session_conn.filter(match_rule) as matches: + res, = bus.RequestName(name) + assert res == 1 # 1: got the name + + signal_msg = session_conn.recv_until_filtered(matches, timeout=2) + + assert signal_msg.body == (name, '', session_conn.unique_name) + + +def test_recv_fd(respond_with_fd): + getfd_call = new_method_call(respond_with_fd, 'GetFD') + with open_dbus_connection(bus='SESSION', enable_fds=True) as conn: + reply = conn.send_and_get_reply(getfd_call, timeout=5) + + assert reply.header.message_type is MessageType.method_return + with reply.body[0].to_file('w+') as f: + assert f.read() == 'readme' + + +def test_send_fd(temp_file_and_contents, read_from_fd): + temp_file, data = temp_file_and_contents + readfd_call = new_method_call(read_from_fd, 'ReadFD', 'h', (temp_file,)) + with open_dbus_connection(bus='SESSION', enable_fds=True) as conn: + reply = conn.send_and_get_reply(readfd_call, timeout=5) + + assert reply.header.message_type is MessageType.method_return + assert reply.body[0] == data diff --git a/lib/jeepney/io/tests/test_threading.py b/lib/jeepney/io/tests/test_threading.py new file mode 100644 index 0000000..d408497 --- /dev/null +++ b/lib/jeepney/io/tests/test_threading.py @@ -0,0 +1,83 @@ +import pytest + +from jeepney import new_method_call, MessageType, DBusAddress +from jeepney.bus_messages import message_bus, MatchRule +from jeepney.io.threading import open_dbus_router, Proxy +from .utils import have_session_bus + +pytestmark = pytest.mark.skipif( + not have_session_bus, reason="Tests require DBus session bus" +) + +@pytest.fixture +def router(): + with open_dbus_router(bus='SESSION') as conn: + yield conn + + +def test_connect(router): + assert 
router.unique_name.startswith(':') + +bus_peer = DBusAddress( + bus_name='org.freedesktop.DBus', + object_path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus.Peer' +) + +def test_send_and_get_reply(router): + ping_call = new_method_call(bus_peer, 'Ping') + reply = router.send_and_get_reply(ping_call, timeout=5) + assert reply.header.message_type == MessageType.method_return + assert reply.body == () + +def test_proxy(router): + proxy = Proxy(message_bus, router, timeout=5) + name = "io.gitlab.takluyver.jeepney.examples.Server" + res = proxy.RequestName(name) + assert res in {(1,), (2,)} # 1: got the name, 2: queued + + has_owner, = proxy.NameHasOwner(name, _timeout=3) + assert has_owner is True + +def test_filter(router): + bus = Proxy(message_bus, router) + name = "io.gitlab.takluyver.jeepney.tests.threading_test_filter" + + match_rule = MatchRule( + type="signal", + sender=message_bus.bus_name, + interface=message_bus.interface, + member="NameOwnerChanged", + path=message_bus.object_path, + ) + match_rule.add_arg_condition(0, name) + + # Ask the message bus to subscribe us to this signal + bus.AddMatch(match_rule) + + with router.filter(match_rule) as queue: + res, = bus.RequestName(name) + assert res == 1 # 1: got the name + + signal_msg = queue.get(timeout=2.0) + assert signal_msg.body == (name, '', router.unique_name) + + +def test_recv_fd(respond_with_fd): + getfd_call = new_method_call(respond_with_fd, 'GetFD') + with open_dbus_router(bus='SESSION', enable_fds=True) as router: + reply = router.send_and_get_reply(getfd_call, timeout=5) + + assert reply.header.message_type is MessageType.method_return + with reply.body[0].to_file('w+') as f: + assert f.read() == 'readme' + + +def test_send_fd(temp_file_and_contents, read_from_fd): + temp_file, data = temp_file_and_contents + readfd_call = new_method_call(read_from_fd, 'ReadFD', 'h', (temp_file,)) + with open_dbus_router(bus='SESSION', enable_fds=True) as router: + reply = 
router.send_and_get_reply(readfd_call, timeout=5) + + assert reply.header.message_type is MessageType.method_return + assert reply.body[0] == data diff --git a/lib/jeepney/io/tests/test_trio.py b/lib/jeepney/io/tests/test_trio.py new file mode 100644 index 0000000..d426993 --- /dev/null +++ b/lib/jeepney/io/tests/test_trio.py @@ -0,0 +1,114 @@ +import trio +import pytest + +from jeepney import DBusAddress, DBusErrorResponse, MessageType, new_method_call +from jeepney.bus_messages import message_bus, MatchRule +from jeepney.io.trio import ( + open_dbus_connection, open_dbus_router, Proxy, +) +from .utils import have_session_bus + +pytestmark = [ + pytest.mark.trio, + pytest.mark.skipif( + not have_session_bus, reason="Tests require DBus session bus" + ), +] + +# Can't use any async fixtures here, because pytest-asyncio tries to handle +# all of them: https://github.com/pytest-dev/pytest-asyncio/issues/124 + +async def test_connect(): + conn = await open_dbus_connection(bus='SESSION') + async with conn: + assert conn.unique_name.startswith(':') + +bus_peer = DBusAddress( + bus_name='org.freedesktop.DBus', + object_path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus.Peer' +) + +async def test_send_and_get_reply(): + ping_call = new_method_call(bus_peer, 'Ping') + async with open_dbus_router(bus='SESSION') as req: + with trio.fail_after(5): + reply = await req.send_and_get_reply(ping_call) + + assert reply.header.message_type == MessageType.method_return + assert reply.body == () + + +async def test_send_and_get_reply_error(): + ping_call = new_method_call(bus_peer, 'Snart') # No such method + async with open_dbus_router(bus='SESSION') as req: + with trio.fail_after(5): + reply = await req.send_and_get_reply(ping_call) + + assert reply.header.message_type == MessageType.error + + +async def test_proxy(): + async with open_dbus_router(bus='SESSION') as req: + proxy = Proxy(message_bus, req) + name = "io.gitlab.takluyver.jeepney.examples.Server" + res = await 
proxy.RequestName(name) + assert res in {(1,), (2,)} # 1: got the name, 2: queued + + has_owner, = await proxy.NameHasOwner(name) + assert has_owner is True + + +async def test_proxy_error(): + async with open_dbus_router(bus='SESSION') as req: + proxy = Proxy(message_bus, req) + with pytest.raises(DBusErrorResponse): + await proxy.RequestName(":123") # Invalid name + + +async def test_filter(): + name = "io.gitlab.takluyver.jeepney.tests.trio_test_filter" + async with open_dbus_router(bus='SESSION') as router: + bus = Proxy(message_bus, router) + + match_rule = MatchRule( + type="signal", + sender=message_bus.bus_name, + interface=message_bus.interface, + member="NameOwnerChanged", + path=message_bus.object_path, + ) + match_rule.add_arg_condition(0, name) + + # Ask the message bus to subscribe us to this signal + await bus.AddMatch(match_rule) + + async with router.filter(match_rule) as chan: + res, = await bus.RequestName(name) + assert res == 1 # 1: got the name + + with trio.fail_after(2.0): + signal_msg = await chan.receive() + assert signal_msg.body == (name, '', router.unique_name) + + +async def test_recv_fd(respond_with_fd): + getfd_call = new_method_call(respond_with_fd, 'GetFD') + with trio.fail_after(5): + async with open_dbus_router(bus='SESSION', enable_fds=True) as router: + reply = await router.send_and_get_reply(getfd_call) + + assert reply.header.message_type is MessageType.method_return + with reply.body[0].to_file('w+') as f: + assert f.read() == 'readme' + + +async def test_send_fd(temp_file_and_contents, read_from_fd): + temp_file, data = temp_file_and_contents + readfd_call = new_method_call(read_from_fd, 'ReadFD', 'h', (temp_file,)) + with trio.fail_after(5): + async with open_dbus_router(bus='SESSION', enable_fds=True) as router: + reply = await router.send_and_get_reply(readfd_call) + + assert reply.header.message_type is MessageType.method_return + assert reply.body[0] == data diff --git a/lib/jeepney/io/tests/utils.py 
b/lib/jeepney/io/tests/utils.py new file mode 100644 index 0000000..6db0f86 --- /dev/null +++ b/lib/jeepney/io/tests/utils.py @@ -0,0 +1,3 @@ +import os + +have_session_bus = bool(os.environ.get('DBUS_SESSION_BUS_ADDRESS')) diff --git a/lib/jeepney/io/threading.py b/lib/jeepney/io/threading.py new file mode 100644 index 0000000..5649299 --- /dev/null +++ b/lib/jeepney/io/threading.py @@ -0,0 +1,273 @@ +"""Synchronous IO wrappers with thread safety +""" +from concurrent.futures import Future +from contextlib import contextmanager +import functools +import os +from selectors import EVENT_READ +import socket +from queue import Queue, Full as QueueFull +from threading import Lock, Thread +from typing import Optional + +from jeepney import Message, MessageType +from jeepney.bus import get_bus +from jeepney.bus_messages import message_bus +from jeepney.wrappers import ProxyBase, unwrap_msg +from .blocking import ( + unwrap_read, prep_socket, DBusConnectionBase, timeout_to_deadline, +) +from .common import ( + MessageFilters, FilterHandle, ReplyMatcher, RouterClosed, check_replyable, +) + +__all__ = [ + 'open_dbus_connection', + 'open_dbus_router', + 'DBusConnection', + 'DBusRouter', + 'Proxy', + 'ReceiveStopped', +] + + +class ReceiveStopped(Exception): + pass + + +class DBusConnection(DBusConnectionBase): + def __init__(self, sock: socket.socket, enable_fds=False): + super().__init__(sock, enable_fds=enable_fds) + self._stop_r, self._stop_w = os.pipe() + self.stop_key = self.selector.register(self._stop_r, EVENT_READ) + self.send_lock = Lock() + self.rcv_lock = Lock() + + def send(self, message: Message, serial=None): + """Serialise and send a :class:`~.Message` object""" + data, fds = self._serialise(message, serial) + with self.send_lock: + if fds: + self._send_with_fds(data, fds) + else: + self.sock.sendall(data) + + def receive(self, *, timeout=None) -> Message: + """Return the next available message from the connection + + If the data is ready, this will return 
immediately, even if timeout<=0. + Otherwise, it will wait for up to timeout seconds, or indefinitely if + timeout is None. If no message comes in time, it raises TimeoutError. + + If the connection is closed from another thread, this will raise + ReceiveStopped. + """ + deadline = timeout_to_deadline(timeout) + + if not self.rcv_lock.acquire(timeout=(timeout or -1)): + raise TimeoutError(f"Did not get receive lock in {timeout} seconds") + try: + return self._receive(deadline) + finally: + self.rcv_lock.release() + + def _read_some_data(self, timeout=None): + # Wait for data or a signal on the stop pipe + for key, ev in self.selector.select(timeout): + if key == self.select_key: + if self.enable_fds: + return self._read_with_fds() + else: + return unwrap_read(self.sock.recv(4096)), [] + elif key == self.stop_key: + raise ReceiveStopped("DBus receive stopped from another thread") + + raise TimeoutError + + def interrupt(self): + """Make any threads waiting for a message raise ReceiveStopped""" + os.write(self._stop_w, b'a') + + def reset_interrupt(self): + """Allow calls to .receive() again after .interrupt() + + To avoid race conditions, you should typically wait for threads to + respond (e.g. by joining them) between interrupting and resetting. + """ + # Clear any data on the stop pipe + while (self.stop_key, EVENT_READ) in self.selector.select(timeout=0): + os.read(self._stop_r, 1024) + + def close(self): + """Close the connection""" + self.interrupt() + super().close() + + +def open_dbus_connection(bus='SESSION', enable_fds=False, auth_timeout=1.): + """Open a plain D-Bus connection + + D-Bus has an authentication step before sending or receiving messages. + This takes < 1 ms in normal operation, but there is a timeout so that client + code won't get stuck if the server doesn't reply. *auth_timeout* configures + this timeout in seconds. 
+ + :return: :class:`DBusConnection` + """ + bus_addr = get_bus(bus) + sock = prep_socket(bus_addr, enable_fds, timeout=auth_timeout) + + conn = DBusConnection(sock, enable_fds) + + with DBusRouter(conn) as router: + reply_body = Proxy(message_bus, router, timeout=10).Hello() + conn.unique_name = reply_body[0] + + return conn + + +class DBusRouter: + """A client D-Bus connection which can wait for replies. + + This runs a separate receiver thread and dispatches received messages. + + It's possible to wrap a :class:`DBusConnection` in a router temporarily. + Using the connection directly while it is wrapped is not supported, + but you can use it again after the router is closed. + """ + def __init__(self, conn: DBusConnection): + self.conn = conn + self._replies = ReplyMatcher() + self._filters = MessageFilters() + self._rcv_thread = Thread(target=self._receiver, daemon=True) + self._rcv_thread.start() + + @property + def unique_name(self): + return self.conn.unique_name + + def send(self, message, *, serial=None): + """Serialise and send a :class:`~.Message` object""" + self.conn.send(message, serial=serial) + + def send_and_get_reply(self, msg: Message, *, timeout=None) -> Message: + """Send a method call message, wait for and return a reply""" + check_replyable(msg) + if not self._rcv_thread.is_alive(): + raise RouterClosed("This D-Bus router has stopped") + + serial = next(self.conn.outgoing_serial) + + with self._replies.catch(serial, Future()) as reply_fut: + self.conn.send(msg, serial=serial) + return reply_fut.result(timeout=timeout) + + def close(self): + """Close this router + + This does not close the underlying connection. 
+ """ + self.conn.interrupt() + self._rcv_thread.join(timeout=10) + self.conn.reset_interrupt() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def filter(self, rule, *, queue: Optional[Queue] =None, bufsize=1): + """Create a filter for incoming messages + + Usage:: + + with router.filter(rule) as queue: + matching_msg = queue.get() + + :param jeepney.MatchRule rule: Catch messages matching this rule + :param queue.Queue queue: Matched messages will be added to this + :param int bufsize: If no queue is passed in, create one with this size + """ + return FilterHandle(self._filters, rule, queue or Queue(maxsize=bufsize)) + + # Code to run in receiver thread ------------------------------------ + + def _dispatch(self, msg: Message): + if self._replies.dispatch(msg): + return + + for filter in self._filters.matches(msg): + try: + filter.queue.put_nowait(msg) + except QueueFull: + pass + + def _receiver(self): + try: + while True: + msg = self.conn.receive() + self._dispatch(msg) + except ReceiveStopped: + pass + finally: + # Send errors to any tasks still waiting for a message. + self._replies.drop_all() + +class Proxy(ProxyBase): + """A blocking proxy for calling D-Bus methods via a :class:`DBusRouter`. + + You can call methods on the proxy object, such as ``bus_proxy.Hello()`` + to make a method call over D-Bus and wait for a reply. It will either + return a tuple of returned data, or raise :exc:`.DBusErrorResponse`. + The methods available are defined by the message generator you wrap. + + You can set a time limit on a call by passing ``_timeout=`` in the method + call, or set a default when creating the proxy. The ``_timeout`` argument + is not passed to the message generator. + All timeouts are in seconds, and :exc:`TimeoutErrror` is raised if it + expires before a reply arrives. 
+ + :param msggen: A message generator object + :param ~threading.DBusRouter router: Router to send and receive messages + :param float timeout: Default seconds to wait for a reply, or None for no limit + """ + def __init__(self, msggen, router, *, timeout=None): + super().__init__(msggen) + self._router = router + self._timeout = timeout + + def __repr__(self): + extra = '' if (self._timeout is None) else f', timeout={self._timeout}' + return f"Proxy({self._msggen}, {self._router}{extra})" + + def _method_call(self, make_msg): + @functools.wraps(make_msg) + def inner(*args, **kwargs): + timeout = kwargs.pop('_timeout', self._timeout) + msg = make_msg(*args, **kwargs) + assert msg.header.message_type is MessageType.method_call + reply = self._router.send_and_get_reply(msg, timeout=timeout) + return unwrap_msg(reply) + + return inner + +@contextmanager +def open_dbus_router(bus='SESSION', enable_fds=False): + """Open a D-Bus 'router' to send and receive messages. + + Use as a context manager:: + + with open_dbus_router() as router: + ... + + On leaving the ``with`` block, the connection will be closed. + + :param str bus: 'SESSION' or 'SYSTEM' or a supported address. + :param bool enable_fds: Whether to enable passing file descriptors. 
+ :return: :class:`DBusRouter` + """ + with open_dbus_connection(bus=bus, enable_fds=enable_fds) as conn: + with DBusRouter(conn) as router: + yield router diff --git a/lib/jeepney/io/trio.py b/lib/jeepney/io/trio.py new file mode 100644 index 0000000..99acb91 --- /dev/null +++ b/lib/jeepney/io/trio.py @@ -0,0 +1,424 @@ +import array +import errno +import logging +import socket +from contextlib import asynccontextmanager, contextmanager +from itertools import count +from typing import Optional + +from outcome import Value, Error +import trio +from trio.abc import Channel + +from jeepney.auth import Authenticator, BEGIN +from jeepney.bus import get_bus +from jeepney.fds import FileDescriptor, fds_buf_size +from jeepney.low_level import Parser, MessageType, Message +from jeepney.wrappers import ProxyBase, unwrap_msg +from jeepney.bus_messages import message_bus +from .common import ( + MessageFilters, FilterHandle, ReplyMatcher, RouterClosed, check_replyable, +) + +log = logging.getLogger(__name__) + +__all__ = [ + 'open_dbus_connection', + 'open_dbus_router', + 'Proxy', +] + + +# The function below is copied from trio, which is under the MIT license: + +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. 
+@contextmanager +def _translate_socket_errors_to_stream_errors(): + try: + yield + except OSError as exc: + if exc.errno in {errno.EBADF, errno.ENOTSOCK}: + # EBADF on Unix, ENOTSOCK on Windows + raise trio.ClosedResourceError("this socket was already closed") from None + else: + raise trio.BrokenResourceError( + "socket connection broken: {}".format(exc) + ) from exc + + + +class DBusConnection(Channel): + """A plain D-Bus connection with no matching of replies. + + This doesn't run any separate tasks: sending and receiving are done in + the task that calls those methods. It's suitable for implementing servers: + several worker tasks can receive requests and send replies. + For a typical client pattern, see :class:`DBusRouter`. + + Implements trio's channel interface for Message objects. + """ + def __init__(self, socket, enable_fds=False): + self.socket = socket + self.enable_fds = enable_fds + self.parser = Parser() + self.outgoing_serial = count(start=1) + self.unique_name = None + self.send_lock = trio.Lock() + self.recv_lock = trio.Lock() + self._leftover_to_send = None # type: Optional[memoryview] + + async def send(self, message: Message, *, serial=None): + """Serialise and send a :class:`~.Message` object""" + async with self.send_lock: + if serial is None: + serial = next(self.outgoing_serial) + fds = array.array('i') if self.enable_fds else None + data = message.serialise(serial, fds=fds) + await self._send_data(data, fds) + + # _send_data is copied & modified from trio's SocketStream.send_all() . + # See above for the MIT license. + async def _send_data(self, data: bytes, fds): + if self.socket.did_shutdown_SHUT_WR: + raise trio.ClosedResourceError("can't send data after sending EOF") + + with _translate_socket_errors_to_stream_errors(): + if self._leftover_to_send: + # A previous message was partly sent - finish sending it now. 
+ await self._send_remainder(self._leftover_to_send) + + with memoryview(data) as data: + if fds: + sent = await self.socket.sendmsg([data], [( + trio.socket.SOL_SOCKET, trio.socket.SCM_RIGHTS, fds + )]) + else: + sent = await self.socket.send(data) + + await self._send_remainder(data, sent) + + async def _send_remainder(self, data: memoryview, already_sent=0): + try: + while already_sent < len(data): + with data[already_sent:] as remaining: + sent = await self.socket.send(remaining) + already_sent += sent + self._leftover_to_send = None + except trio.Cancelled: + # Sending cancelled mid-message. Keep track of the remaining data + # so it can be sent before the next message, otherwise the next + # message won't be recognised. + self._leftover_to_send = data[already_sent:] + raise + + async def receive(self) -> Message: + """Return the next available message from the connection""" + async with self.recv_lock: + while True: + msg = self.parser.get_next_message() + if msg is not None: + return msg + + # Once data is read, it must be given to the parser with no + # checkpoints (where the task could be cancelled). 
+ b, fds = await self._read_data() + if not b: + raise trio.EndOfChannel("Socket closed at the other end") + self.parser.add_data(b, fds) + + async def _read_data(self): + if self.enable_fds: + nbytes = self.parser.bytes_desired() + with _translate_socket_errors_to_stream_errors(): + data, ancdata, flags, _ = await self.socket.recvmsg( + nbytes, fds_buf_size() + ) + if flags & getattr(trio.socket, 'MSG_CTRUNC', 0): + self._close() + raise RuntimeError("Unable to receive all file descriptors") + return data, FileDescriptor.from_ancdata(ancdata) + + else: # not self.enable_fds + with _translate_socket_errors_to_stream_errors(): + data = await self.socket.recv(4096) + return data, [] + + def _close(self): + self.socket.close() + self._leftover_to_send = None + + # Our closing is currently sync, but AsyncResource objects must have aclose + async def aclose(self): + """Close the D-Bus connection""" + self._close() + + @asynccontextmanager + async def router(self): + """Temporarily wrap this connection as a :class:`DBusRouter` + + To be used like:: + + async with conn.router() as req: + reply = await req.send_and_get_reply(msg) + + While the router is running, you shouldn't use :meth:`receive`. + Once the router is closed, you can use the plain connection again. 
+ """ + async with trio.open_nursery() as nursery: + router = DBusRouter(self) + await router.start(nursery) + try: + yield router + finally: + await router.aclose() + + +async def open_dbus_connection(bus='SESSION', *, enable_fds=False) -> DBusConnection: + """Open a plain D-Bus connection + + :return: :class:`DBusConnection` + """ + bus_addr = get_bus(bus) + sock : trio.SocketStream = await trio.open_unix_socket(bus_addr) + + # Authentication + authr = Authenticator(enable_fds=enable_fds, inc_null_byte=False) + if hasattr(socket, 'SCM_CREDS'): + # BSD: send credentials message to authenticate (kernel fills in data) + await sock.socket.sendmsg( + [b'\0'], [(socket.SOL_SOCKET, socket.SCM_CREDS, bytes(512))] + ) + else: + # Linux: no ancillary data needed, bus checks with SO_PEERCRED + await sock.send_all(b'\0') + for req_data in authr: + await sock.send_all(req_data) + authr.feed(await sock.receive_some()) + await sock.send_all(BEGIN) + + conn = DBusConnection(sock.socket, enable_fds=enable_fds) + + # Say *Hello* to the message bus - this must be the first message, and the + # reply gives us our unique name. 
+ async with conn.router() as router: + reply = await router.send_and_get_reply(message_bus.Hello()) + conn.unique_name = reply.body[0] + + return conn + + +class TrioFilterHandle(FilterHandle): + def __init__(self, filters: MessageFilters, rule, send_chn, recv_chn): + super().__init__(filters, rule, recv_chn) + self.send_channel = send_chn + + @property + def receive_channel(self): + return self.queue + + async def aclose(self): + self.close() + await self.send_channel.aclose() + + async def __aenter__(self): + return self.queue + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + +class Future: + """A very simple Future for trio based on `trio.Event`.""" + def __init__(self): + self._outcome = None + self._event = trio.Event() + + def set_result(self, result): + self._outcome = Value(result) + self._event.set() + + def set_exception(self, exc): + self._outcome = Error(exc) + self._event.set() + + async def get(self): + await self._event.wait() + return self._outcome.unwrap() + + +class DBusRouter: + """A client D-Bus connection which can wait for replies. + + This runs a separate receiver task and dispatches received messages. + """ + _nursery_mgr = None + _rcv_cancel_scope = None + + def __init__(self, conn: DBusConnection): + self._conn = conn + self._replies = ReplyMatcher() + self._filters = MessageFilters() + + @property + def unique_name(self): + return self._conn.unique_name + + async def send(self, message, *, serial=None): + """Send a message, don't wait for a reply + """ + await self._conn.send(message, serial=serial) + + async def send_and_get_reply(self, message) -> Message: + """Send a method call message and wait for the reply + + Returns the reply message (method return or error message type). 
+ """ + check_replyable(message) + if self._rcv_cancel_scope is None: + raise RouterClosed("This DBusRouter has stopped") + + serial = next(self._conn.outgoing_serial) + + with self._replies.catch(serial, Future()) as reply_fut: + await self.send(message, serial=serial) + return (await reply_fut.get()) + + def filter(self, rule, *, channel: Optional[trio.MemorySendChannel]=None, bufsize=1): + """Create a filter for incoming messages + + Usage:: + + async with router.filter(rule) as receive_channel: + matching_msg = await receive_channel.receive() + + # OR: + send_chan, recv_chan = trio.open_memory_channel(1) + async with router.filter(rule, channel=send_chan): + matching_msg = await recv_chan.receive() + + If the channel fills up, + The sending end of the channel is closed when leaving the ``async with`` + block, whether or not it was passed in. + + :param jeepney.MatchRule rule: Catch messages matching this rule + :param trio.MemorySendChannel channel: Send matching messages here + :param int bufsize: If no channel is passed in, create one with this size + """ + if channel is None: + channel, recv_channel = trio.open_memory_channel(bufsize) + else: + recv_channel = None + return TrioFilterHandle(self._filters, rule, channel, recv_channel) + + # Task management ------------------------------------------- + + async def start(self, nursery: trio.Nursery): + if self._rcv_cancel_scope is not None: + raise RuntimeError("DBusRouter receiver task is already running") + self._rcv_cancel_scope = await nursery.start(self._receiver) + + async def aclose(self): + """Stop the sender & receiver tasks""" + # It doesn't matter if we receive a partial message - the connection + # should ensure that whatever is received is fed to the parser. 
+ if self._rcv_cancel_scope is not None: + self._rcv_cancel_scope.cancel() + self._rcv_cancel_scope = None + + # Ensure trio checkpoint + await trio.sleep(0) + + # Code to run in receiver task ------------------------------------ + + def _dispatch(self, msg: Message): + """Handle one received message""" + if self._replies.dispatch(msg): + return + + for filter in self._filters.matches(msg): + try: + filter.send_channel.send_nowait(msg) + except trio.WouldBlock: + pass + + async def _receiver(self, task_status=trio.TASK_STATUS_IGNORED): + """Receiver loop - runs in a separate task""" + with trio.CancelScope() as cscope: + self.is_running = True + task_status.started(cscope) + try: + while True: + msg = await self._conn.receive() + self._dispatch(msg) + finally: + self.is_running = False + # Send errors to any tasks still waiting for a message. + self._replies.drop_all() + + # Closing a memory channel can't block, but it only has an + # async close method, so we need to shield it from cancellation. + with trio.move_on_after(3) as cleanup_scope: + for filter in self._filters.filters.values(): + cleanup_scope.shield = True + await filter.send_channel.aclose() + + +class Proxy(ProxyBase): + """A trio proxy for calling D-Bus methods + + You can call methods on the proxy object, such as ``await bus_proxy.Hello()`` + to make a method call over D-Bus and wait for a reply. It will either + return a tuple of returned data, or raise :exc:`.DBusErrorResponse`. + The methods available are defined by the message generator you wrap. + + :param msggen: A message generator object. + :param ~trio.DBusRouter router: Router to send and receive messages. 
+ """ + def __init__(self, msggen, router): + super().__init__(msggen) + if not isinstance(router, DBusRouter): + raise TypeError("Proxy can only be used with DBusRequester") + self._router = router + + def _method_call(self, make_msg): + async def inner(*args, **kwargs): + msg = make_msg(*args, **kwargs) + assert msg.header.message_type is MessageType.method_call + reply = await self._router.send_and_get_reply(msg) + return unwrap_msg(reply) + + return inner + + +@asynccontextmanager +async def open_dbus_router(bus='SESSION', *, enable_fds=False): + """Open a D-Bus 'router' to send and receive messages. + + Use as an async context manager:: + + async with open_dbus_router() as req: + ... + + :param str bus: 'SESSION' or 'SYSTEM' or a supported address. + :return: :class:`DBusRouter` + + This is a shortcut for:: + + conn = await open_dbus_connection() + async with conn: + async with conn.router() as req: + ... + """ + conn = await open_dbus_connection(bus, enable_fds=enable_fds) + async with conn: + async with conn.router() as rtr: + yield rtr diff --git a/lib/jeepney/low_level.py b/lib/jeepney/low_level.py new file mode 100644 index 0000000..1b1463d --- /dev/null +++ b/lib/jeepney/low_level.py @@ -0,0 +1,608 @@ +import string +import struct +from collections import deque +from enum import Enum, IntEnum, IntFlag +from typing import Optional + +class SizeLimitError(ValueError): + """Raised when trying to (de-)serialise data exceeding D-Bus' size limit. + + This is currently only implemented for arrays, where the maximum size is + 64 MiB. 
+ """ + pass + +class Endianness(Enum): + little = 1 + big = 2 + + def struct_code(self): + return '<' if (self is Endianness.little) else '>' + + def dbus_code(self): + return b'l' if (self is Endianness.little) else b'B' + + +endian_map = {b'l': Endianness.little, b'B': Endianness.big} + + +class MessageType(Enum): + method_call = 1 + method_return = 2 + error = 3 + signal = 4 + + +class MessageFlag(IntFlag): + no_reply_expected = 1 + no_auto_start = 2 + allow_interactive_authorization = 4 + + +class HeaderFields(IntEnum): + path = 1 + interface = 2 + member = 3 + error_name = 4 + reply_serial = 5 + destination = 6 + sender = 7 + signature = 8 + unix_fds = 9 + + +def padding(pos, step): + pad = step - (pos % step) + if pad == step: + return 0 + return pad + + +class FixedType: + def __init__(self, size, struct_code): + self.size = self.alignment = size + self.struct_code = struct_code + + def parse_data(self, buf, pos, endianness, fds=()): + pos += padding(pos, self.alignment) + code = endianness.struct_code() + self.struct_code + val = struct.unpack(code, buf[pos:pos + self.size])[0] + return val, pos + self.size + + def serialise(self, data, pos, endianness, fds=None): + pad = b'\0' * padding(pos, self.alignment) + code = endianness.struct_code() + self.struct_code + return pad + struct.pack(code, data) + + def __repr__(self): + return 'FixedType({!r}, {!r})'.format(self.size, self.struct_code) + + def __eq__(self, other): + return (type(other) is FixedType) and (self.size == other.size) \ + and (self.struct_code == other.struct_code) + + +class Boolean(FixedType): + def __init__(self): + super().__init__(4, 'I') # D-Bus booleans take 4 bytes + + def parse_data(self, buf, pos, endianness, fds=()): + val, new_pos = super().parse_data(buf, pos, endianness) + return bool(val), new_pos + + def __repr__(self): + return 'Boolean()' + + def __eq__(self, other): + return type(other) is Boolean + + +class FileDescriptor(FixedType): + def __init__(self): + 
super().__init__(4, 'I') + + def parse_data(self, buf, pos, endianness, fds=()): + idx, new_pos = super().parse_data(buf, pos, endianness) + return fds[idx], new_pos + + def serialise(self, data, pos, endianness, fds=None): + if fds is None: + raise RuntimeError("Sending FDs is not supported or not enabled") + + if hasattr(data, 'fileno'): + data = data.fileno() + if isinstance(data, bool) or not isinstance(data, int): + raise TypeError("Cannot use {data!r} as file descriptor. Expected " + "an int or an object with fileno() method") + + if data < 0: + raise ValueError(f"File descriptor can't be negative ({data})") + + fds.append(data) + return super().serialise(len(fds) - 1, pos, endianness) + + def __repr__(self): + return 'FileDescriptor()' + + def __eq__(self, other): + return type(other) is FileDescriptor + + +simple_types = { + 'y': FixedType(1, 'B'), # unsigned 8 bit + 'n': FixedType(2, 'h'), # signed 16 bit + 'q': FixedType(2, 'H'), # unsigned 16 bit + 'b': Boolean(), # bool (32-bit) + 'i': FixedType(4, 'i'), # signed 32-bit + 'u': FixedType(4, 'I'), # unsigned 32-bit + 'x': FixedType(8, 'q'), # signed 64-bit + 't': FixedType(8, 'Q'), # unsigned 64-bit + 'd': FixedType(8, 'd'), # double + 'h': FileDescriptor(), # file descriptor (uint32 index in a separate list) +} + + +class StringType: + def __init__(self, length_type): + self.length_type = length_type + + @property + def alignment(self): + return self.length_type.size + + def parse_data(self, buf, pos, endianness, fds=()): + length, pos = self.length_type.parse_data(buf, pos, endianness) + end = pos + length + val = buf[pos:end].decode('utf-8') + assert buf[end:end + 1] == b'\0' + return val, end + 1 + + def check_data(self, data): + if not isinstance(data, str): + raise TypeError("Expected str, not {!r}".format(data)) + + def serialise(self, data, pos, endianness, fds=None): + self.check_data(data) + encoded = data.encode('utf-8') + len_data = self.length_type.serialise(len(encoded), pos, endianness) + 
return len_data + encoded + b'\0' + + def __repr__(self): + return 'StringType({!r})'.format(self.length_type) + + def __eq__(self, other): + return (type(other) is StringType) \ + and (self.length_type == other.length_type) + + +class ObjectPathType(StringType): + def __init__(self): + super().__init__(simple_types['u']) + + def check_data(self, data): + super().check_data(data) + if not data.startswith('/'): + raise ValueError(f"Object path ({data!r}) must start with /") + if data.endswith('/') and len(data) > 1: + raise ValueError(f"Object path ({data!r}) cannot end with /") + if '//' in data: + raise ValueError(f"Object path ({data!r}) cannot contain double /") + valid_chars = string.ascii_letters + string.digits + '/_' + if any(c not in valid_chars for c in data): + raise ValueError( + f"Object path ({data!r}) can only contain A-Z, a-z, 0-9, / and _" + ) + + +simple_types.update({ + 's': StringType(simple_types['u']), # String + 'o': ObjectPathType(), # Object path + 'g': StringType(simple_types['y']), # Signature +}) + + +class Struct: + alignment = 8 + + def __init__(self, fields): + if any(isinstance(f, DictEntry) for f in fields): + raise TypeError("Found dict entry outside array") + self.fields = fields + + def parse_data(self, buf, pos, endianness, fds=()): + pos += padding(pos, 8) + res = [] + for field in self.fields: + v, pos = field.parse_data(buf, pos, endianness, fds=fds) + res.append(v) + return tuple(res), pos + + def serialise(self, data, pos, endianness, fds=None): + if not isinstance(data, tuple): + raise TypeError("Expected tuple, not {!r}".format(data)) + if len(data) != len(self.fields): + raise ValueError("{} entries for {} fields".format( + len(data), len(self.fields) + )) + pad = b'\0' * padding(pos, self.alignment) + pos += len(pad) + res_pieces = [] + for item, field in zip(data, self.fields): + res_pieces.append(field.serialise(item, pos, endianness, fds=fds)) + pos += len(res_pieces[-1]) + return pad + b''.join(res_pieces) + + def 
__repr__(self): + return "{}({!r})".format(type(self).__name__, self.fields) + + def __eq__(self, other): + return (type(other) is type(self)) and (self.fields == other.fields) + + +class DictEntry(Struct): + def __init__(self, fields): + if len(fields) != 2: + raise TypeError( + "Dict entry must have 2 fields, not %d" % len(fields)) + if not isinstance(fields[0], (FixedType, StringType)): + raise TypeError( + "First field in dict entry must be simple type, not {}" + .format(type(fields[0]))) + super().__init__(fields) + +class Array: + alignment = 4 + length_type = FixedType(4, 'I') + + def __init__(self, elt_type): + self.elt_type = elt_type + + def parse_data(self, buf, pos, endianness, fds=()): + # print('Array start', pos) + length, pos = self.length_type.parse_data(buf, pos, endianness) + pos += padding(pos, self.elt_type.alignment) + end = pos + length + if self.elt_type == simple_types['y']: # Array of bytes + return buf[pos:end], end + + res = [] + while pos < end: + # print('Array elem', pos) + v, pos = self.elt_type.parse_data(buf, pos, endianness, fds=fds) + res.append(v) + if isinstance(self.elt_type, DictEntry): + # Convert list of 2-tuples to dict + res = dict(res) + return res, pos + + def serialise(self, data, pos, endianness, fds=None): + data_is_bytes = False + if isinstance(self.elt_type, DictEntry) and isinstance(data, dict): + data = data.items() + elif (self.elt_type == simple_types['y']) and isinstance(data, bytes): + data_is_bytes = True + elif not isinstance(data, list): + raise TypeError("Not suitable for array: {!r}".format(data)) + + # Fail fast if we know in advance that the data is too big: + if isinstance(self.elt_type, FixedType): + if (self.elt_type.size * len(data)) > 2**26: + raise SizeLimitError("Array size exceeds 64 MiB limit") + + pad1 = padding(pos, self.alignment) + pos_after_length = pos + pad1 + 4 + pad2 = padding(pos_after_length, self.elt_type.alignment) + + if data_is_bytes: + buf = data + else: + data_pos = 
pos_after_length + pad2 + limit_pos = data_pos + 2 ** 26 + chunks = [] + for item in data: + chunks.append(self.elt_type.serialise( + item, data_pos, endianness, fds=fds + )) + data_pos += len(chunks[-1]) + if data_pos > limit_pos: + raise SizeLimitError("Array size exceeds 64 MiB limit") + buf = b''.join(chunks) + + len_data = self.length_type.serialise(len(buf), pos+pad1, endianness) + # print('Array ser: pad1={!r}, len_data={!r}, pad2={!r}, buf={!r}'.format( + # pad1, len_data, pad2, buf)) + return (b'\0' * pad1) + len_data + (b'\0' * pad2) + buf + + def __repr__(self): + return 'Array({!r})'.format(self.elt_type) + + def __eq__(self, other): + return (type(other) is Array) and (self.elt_type == other.elt_type) + + +class Variant: + alignment = 1 + + def parse_data(self, buf, pos, endianness, fds=()): + # print('variant', pos) + sig, pos = simple_types['g'].parse_data(buf, pos, endianness) + # print('variant sig:', repr(sig), pos) + valtype = parse_signature(list(sig)) + val, pos = valtype.parse_data(buf, pos, endianness, fds=fds) + # print('variant done', (sig, val), pos) + return (sig, val), pos + + def serialise(self, data, pos, endianness, fds=None): + sig, data = data + valtype = parse_signature(list(sig)) + sig_buf = simple_types['g'].serialise(sig, pos, endianness) + return sig_buf + valtype.serialise( + data, pos + len(sig_buf), endianness, fds=fds + ) + + def __repr__(self): + return 'Variant()' + + def __eq__(self, other): + return type(other) is Variant + +def parse_signature(sig): + """Parse a symbolic signature into objects. 
+ """ + # Based on http://norvig.com/lispy.html + token = sig.pop(0) + if token == 'a': + return Array(parse_signature(sig)) + if token == 'v': + return Variant() + elif token == '(': + fields = [] + while sig[0] != ')': + fields.append(parse_signature(sig)) + sig.pop(0) # ) + return Struct(fields) + elif token == '{': + de = [] + while sig[0] != '}': + de.append(parse_signature(sig)) + sig.pop(0) # } + return DictEntry(de) + elif token in ')}': + raise ValueError('Unexpected end of struct') + else: + return simple_types[token] + + +def calc_msg_size(buf): + endian, = struct.unpack('c', buf[:1]) + endian = endian_map[endian] + body_length, = struct.unpack(endian.struct_code() + 'I', buf[4:8]) + fields_array_len, = struct.unpack(endian.struct_code() + 'I', buf[12:16]) + header_len = 16 + fields_array_len + return header_len + padding(header_len, 8) + body_length + + +_header_fields_type = Array(Struct([simple_types['y'], Variant()])) + + +def parse_header_fields(buf, endianness): + l, pos = _header_fields_type.parse_data(buf, 12, endianness) + return {HeaderFields(k): v[1] for (k, v) in l}, pos + + +header_field_codes = { + 1: 'o', + 2: 's', + 3: 's', + 4: 's', + 5: 'u', + 6: 's', + 7: 's', + 8: 'g', + 9: 'u', +} + + +def serialise_header_fields(d, endianness): + l = [(i.value, (header_field_codes[i], v)) for (i, v) in sorted(d.items())] + return _header_fields_type.serialise(l, 12, endianness) + + +class Header: + def __init__(self, endianness, message_type, flags, protocol_version, + body_length, serial, fields): + """A D-Bus message header + + It's not normally necessary to construct this directly: use higher level + functions and methods instead. 
+ """ + self.endianness = endianness + self.message_type = MessageType(message_type) + self.flags = MessageFlag(flags) + self.protocol_version = protocol_version + self.body_length = body_length + self.serial = serial + self.fields = fields + + def __repr__(self): + return 'Header({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, fields={!r})'.format( + self.endianness, self.message_type, self.flags, + self.protocol_version, self.body_length, self.serial, self.fields) + + def serialise(self, serial=None): + s = self.endianness.struct_code() + 'cBBBII' + if serial is None: + serial = self.serial + return struct.pack(s, self.endianness.dbus_code(), + self.message_type.value, self.flags, + self.protocol_version, + self.body_length, serial) \ + + serialise_header_fields(self.fields, self.endianness) + + @classmethod + def from_buffer(cls, buf): + endian, msgtype, flags, pv = struct.unpack('cBBB', buf[:4]) + endian = endian_map[endian] + bodylen, serial = struct.unpack(endian.struct_code() + 'II', buf[4:12]) + fields, pos = parse_header_fields(buf, endian) + return cls(endian, msgtype, flags, pv, bodylen, serial, fields), pos + + +class Message: + """Object representing a DBus message. + + It's not normally necessary to construct this directly: use higher level + functions and methods instead. 
+ """ + def __init__(self, header, body): + self.header = header + self.body = body + + def __repr__(self): + return "{}({!r}, {!r})".format(type(self).__name__, self.header, self.body) + + @classmethod + def from_buffer(cls, buf: bytes, fds=()) -> 'Message': + header, pos = Header.from_buffer(buf) + n_fds = header.fields.get(HeaderFields.unix_fds, 0) + if n_fds > len(fds): + raise ValueError( + f"Message expects {n_fds} FDs, but only {len(fds)} were received" + ) + fds = fds[:n_fds] + body = () + if HeaderFields.signature in header.fields: + sig = header.fields[HeaderFields.signature] + body_type = parse_signature(list('(%s)' % sig)) + body = body_type.parse_data(buf, pos, header.endianness, fds=fds)[0] + return Message(header, body) + + def serialise(self, serial=None, fds=None) -> bytes: + """Convert this message to bytes. + + Specifying *serial* overrides the ``msg.header.serial`` field, so a + connection can use its own serial number without modifying the message. + + If file-descriptor support is in use, *fds* should be a + :class:`array.array` object with type ``'i'``. Any file descriptors in + the message will be added to the array. If the message contains FDs, + it can't be serialised without this array. + """ + endian = self.header.endianness + + if HeaderFields.signature in self.header.fields: + sig = self.header.fields[HeaderFields.signature] + body_type = parse_signature(list('(%s)' % sig)) + body_buf = body_type.serialise(self.body, 0, endian, fds=fds) + else: + body_buf = b'' + + self.header.body_length = len(body_buf) + if fds: + self.header.fields[HeaderFields.unix_fds] = len(fds) + + header_buf = self.header.serialise(serial=serial) + pad = b'\0' * padding(len(header_buf), 8) + return header_buf + pad + body_buf + + +class Parser: + """Parse DBus messages from a stream of incoming data. 
+ """ + def __init__(self): + self.buf = BufferPipe() + self.fds = [] + self.next_msg_size = None + + def add_data(self, data: bytes, fds=()): + """Provide newly received data to the parser""" + self.buf.write(data) + self.fds.extend(fds) + + def feed(self, data): + """Feed the parser newly read data. + + Returns a list of messages completed by the new data. + """ + self.add_data(data) + return list(iter(self.get_next_message, None)) + + def bytes_desired(self): + """How many bytes can be received without going beyond the next message? + + This is only used with file-descriptor passing, so we don't get too many + FDs in a single recvmsg call. + """ + got = self.buf.bytes_buffered + if got < 16: # The first 16 bytes tell us the message size + return 16 - got + + if self.next_msg_size is None: + self.next_msg_size = calc_msg_size(self.buf.peek(16)) + return self.next_msg_size - got + + def get_next_message(self) -> Optional[Message]: + """Parse one message, if there is enough data. + + Returns None if it doesn't have a complete message. + """ + if self.next_msg_size is None: + if self.buf.bytes_buffered >= 16: + self.next_msg_size = calc_msg_size(self.buf.peek(16)) + nms = self.next_msg_size + if (nms is not None) and self.buf.bytes_buffered >= nms: + raw_msg = self.buf.read(nms) + msg = Message.from_buffer(raw_msg, fds=self.fds) + self.next_msg_size = None + fds_consumed = msg.header.fields.get(HeaderFields.unix_fds, 0) + self.fds = self.fds[fds_consumed:] + return msg + + +class BufferPipe: + """A place to store received data until we can parse a complete message + + The main difference from io.BytesIO is that read & write operate at + opposite ends, like a pipe. 
+ """ + def __init__(self): + self.chunks = deque() + self.bytes_buffered = 0 + + def write(self, b: bytes): + self.chunks.append(b) + self.bytes_buffered += len(b) + + def _peek_iter(self, nbytes: int): + assert nbytes <= self.bytes_buffered + for chunk in self.chunks: + chunk = chunk[:nbytes] + nbytes -= len(chunk) + yield chunk + if nbytes <= 0: + break + + def peek(self, nbytes: int) -> bytes: + """Get exactly nbytes bytes from the front without removing them""" + return b''.join(self._peek_iter(nbytes)) + + def _read_iter(self, nbytes: int): + assert nbytes <= self.bytes_buffered + while True: + chunk = self.chunks.popleft() + self.bytes_buffered -= len(chunk) + if nbytes <= len(chunk): + break + nbytes -= len(chunk) + yield chunk + + # Final chunk + chunk, rem = chunk[:nbytes], chunk[nbytes:] + if rem: + self.chunks.appendleft(rem) + self.bytes_buffered += len(rem) + yield chunk + + def read(self, nbytes: int) -> bytes: + """Take & return exactly nbytes bytes from the front""" + return b''.join(self._read_iter(nbytes)) diff --git a/lib/jeepney/tests/__init__.py b/lib/jeepney/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/jeepney/tests/__pycache__/__init__.cpython-314.pyc b/lib/jeepney/tests/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..41f4a85 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_auth.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_auth.cpython-314.pyc new file mode 100644 index 0000000..48c99f9 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_auth.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_bindgen.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_bindgen.cpython-314.pyc new file mode 100644 index 0000000..c0e33d4 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_bindgen.cpython-314.pyc differ diff --git 
a/lib/jeepney/tests/__pycache__/test_bus.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_bus.cpython-314.pyc new file mode 100644 index 0000000..9a0956a Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_bus.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_bus_messages.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_bus_messages.cpython-314.pyc new file mode 100644 index 0000000..a9dc1cf Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_bus_messages.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_fds.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_fds.cpython-314.pyc new file mode 100644 index 0000000..3c22b76 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_fds.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_low_level.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_low_level.cpython-314.pyc new file mode 100644 index 0000000..ac4a1a7 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_low_level.cpython-314.pyc differ diff --git a/lib/jeepney/tests/__pycache__/test_wrappers.cpython-314.pyc b/lib/jeepney/tests/__pycache__/test_wrappers.cpython-314.pyc new file mode 100644 index 0000000..1b73210 Binary files /dev/null and b/lib/jeepney/tests/__pycache__/test_wrappers.cpython-314.pyc differ diff --git a/lib/jeepney/tests/secrets_introspect.xml b/lib/jeepney/tests/secrets_introspect.xml new file mode 100644 index 0000000..edabf81 --- /dev/null +++ b/lib/jeepney/tests/secrets_introspect.xml @@ -0,0 +1,116 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lib/jeepney/tests/test_auth.py b/lib/jeepney/tests/test_auth.py new file mode 100644 index 0000000..62d9228 --- /dev/null +++ b/lib/jeepney/tests/test_auth.py @@ -0,0 +1,24 @@ +import 
pytest + +from jeepney import auth + +def test_make_auth_external(): + b = auth.make_auth_external() + assert b.startswith(b'AUTH EXTERNAL') + +def test_make_auth_anonymous(): + b = auth.make_auth_anonymous() + assert b.startswith(b'AUTH ANONYMOUS') + +def test_parser(): + p = auth.SASLParser() + p.feed(b'OK 728d62bc2eb394') + assert not p.authenticated + p.feed(b'1ebbb0b42958b1e0d6\r\n') + assert p.authenticated + +def test_parser_rejected(): + p = auth.SASLParser() + with pytest.raises(auth.AuthenticationError): + p.feed(b'REJECTED EXTERNAL\r\n') + assert not p.authenticated diff --git a/lib/jeepney/tests/test_bindgen.py b/lib/jeepney/tests/test_bindgen.py new file mode 100644 index 0000000..ef9571b --- /dev/null +++ b/lib/jeepney/tests/test_bindgen.py @@ -0,0 +1,28 @@ +from io import StringIO +import os.path + +from jeepney.low_level import MessageType, HeaderFields +from jeepney.bindgen import code_from_xml + +sample_file = os.path.join(os.path.dirname(__file__), 'secrets_introspect.xml') + +def test_bindgen(): + with open(sample_file) as f: + xml = f.read() + sio = StringIO() + n_interfaces = code_from_xml(xml, path='/org/freedesktop/secrets', + bus_name='org.freedesktop.secrets', + fh=sio) + # 5 interfaces defined, but we ignore Properties, Introspectable, Peer + assert n_interfaces == 2 + + # Run the generated code, defining the message generator classes. 
+ binding_ns = {} + exec(sio.getvalue(), binding_ns) + Service = binding_ns['Service'] + + # Check basic functionality of the Service class + assert Service.interface == 'org.freedesktop.Secret.Service' + msg = Service().SearchItems({"service": "foo", "user": "bar"}) + assert msg.header.message_type is MessageType.method_call + assert msg.header.fields[HeaderFields.destination] == 'org.freedesktop.secrets' diff --git a/lib/jeepney/tests/test_bus.py b/lib/jeepney/tests/test_bus.py new file mode 100644 index 0000000..70dfa36 --- /dev/null +++ b/lib/jeepney/tests/test_bus.py @@ -0,0 +1,24 @@ +import pytest +from testpath import modified_env + +from jeepney import bus + +def test_get_connectable_addresses(): + a = list(bus.get_connectable_addresses('unix:path=/run/user/1000/bus')) + assert a == ['/run/user/1000/bus'] + + a = list(bus.get_connectable_addresses('unix:abstract=/tmp/foo')) + assert a == ['\0/tmp/foo'] + + with pytest.raises(RuntimeError): + list(bus.get_connectable_addresses('unix:tmpdir=/tmp')) + +def test_get_bus(): + with modified_env({ + 'DBUS_SESSION_BUS_ADDRESS':'unix:path=/run/user/1000/bus', + 'DBUS_SYSTEM_BUS_ADDRESS': 'unix:path=/var/run/dbus/system_bus_socket' + }): + assert bus.get_bus('SESSION') == '/run/user/1000/bus' + assert bus.get_bus('SYSTEM') == '/var/run/dbus/system_bus_socket' + + assert bus.get_bus('unix:path=/run/user/1002/bus') == '/run/user/1002/bus' diff --git a/lib/jeepney/tests/test_bus_messages.py b/lib/jeepney/tests/test_bus_messages.py new file mode 100644 index 0000000..50069b5 --- /dev/null +++ b/lib/jeepney/tests/test_bus_messages.py @@ -0,0 +1,112 @@ +from jeepney import DBusAddress, new_signal, new_method_call +from jeepney.bus_messages import MatchRule, message_bus + +portal = DBusAddress( + object_path='/org/freedesktop/portal/desktop', + bus_name='org.freedesktop.portal.Desktop', +) +portal_req_iface = portal.with_interface('org.freedesktop.portal.Request') + + +def test_match_rule_simple(): + rule = MatchRule( + 
type='signal', interface='org.freedesktop.portal.Request', + ) + assert rule.matches(new_signal(portal_req_iface, 'Response')) + + # Wrong message type + assert not rule.matches(new_method_call(portal_req_iface, 'Boo')) + + # Wrong interface + assert not rule.matches(new_signal( + portal.with_interface('org.freedesktop.portal.FileChooser'), 'Response' + )) + + +def test_match_rule_path_namespace(): + assert MatchRule(path_namespace='/org/freedesktop/portal').matches( + new_signal(portal_req_iface, 'Response') + ) + assert "/freedesktop/" in ( + MatchRule(path_namespace='/org/freedesktop/portal').serialise() + ) + + # Prefix but not a parent in the path hierarchy + assert not MatchRule(path_namespace='/org/freedesktop/por').matches( + new_signal(portal_req_iface, 'Response') + ) + + +def test_match_rule_arg(): + rule = MatchRule(type='method_call') + rule.add_arg_condition(0, 'foo') + + assert rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('foo',) + )) + + assert not rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('foobar',) + )) + + # No such argument + assert not rule.matches(new_method_call(portal_req_iface, 'Boo')) + + +def test_match_rule_arg_path(): + rule = MatchRule(type='method_call') + rule.add_arg_condition(0, '/aa/bb/', kind='path') + + # Exact match + assert rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('/aa/bb/',) + )) + + # Match a prefix + assert rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('/aa/bb/cc',) + )) + + # Argument is a prefix, ending with / + assert rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('/aa/',) + )) + + # Argument is a prefix, but NOT ending with / + assert not rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('/aa',) + )) + + assert not rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='s', body=('/aa/bb',) + )) + + # Not a string 
+ assert not rule.matches(new_method_call( + portal_req_iface, 'Boo', signature='u', body=(12,) + )) + + +def test_match_rule_arg_namespace(): + rule = MatchRule(member='NameOwnerChanged') + rule.add_arg_condition(0, 'com.example.backend1', kind='namespace') + + # Exact match + assert rule.matches(new_signal( + message_bus, 'NameOwnerChanged', 's', ('com.example.backend1',) + )) + + # Parent of the name + assert rule.matches(new_signal( + message_bus, 'NameOwnerChanged', 's', ('com.example.backend1.foo.bar',) + )) + + # Prefix but not a parent in the namespace + assert not rule.matches(new_signal( + message_bus, 'NameOwnerChanged', 's', ('com.example.backend12',) + )) + + # Not a string + assert not rule.matches(new_signal( + message_bus, 'NameOwnerChanged', 'u', (1,) + )) diff --git a/lib/jeepney/tests/test_fds.py b/lib/jeepney/tests/test_fds.py new file mode 100644 index 0000000..5d197fb --- /dev/null +++ b/lib/jeepney/tests/test_fds.py @@ -0,0 +1,80 @@ +import errno +import os +import socket + +import pytest + +from jeepney import FileDescriptor, NoFDError + +def assert_not_fd(fd: int): + """Check that the given number is not open as a file descriptor""" + with pytest.raises(OSError) as exc_info: + os.stat(fd) + assert exc_info.value.errno == errno.EBADF + + +def test_close(tmp_path): + fd = os.open(tmp_path / 'a', os.O_CREAT | os.O_RDWR) + + with FileDescriptor(fd) as wfd: + assert wfd.fileno() == fd + # Leaving the with block is equivalent to calling .close() + + assert 'closed' in repr(wfd) + with pytest.raises(NoFDError): + wfd.fileno() + + assert_not_fd(fd) + + +def test_to_raw_fd(tmp_path): + fd = os.open(tmp_path / 'a', os.O_CREAT) + wfd = FileDescriptor(fd) + assert wfd.fileno() == fd + + assert wfd.to_raw_fd() == fd + + try: + assert 'converted' in repr(wfd) + with pytest.raises(NoFDError): + wfd.fileno() + finally: + os.close(fd) + + +def test_to_file(tmp_path): + fd = os.open(tmp_path / 'a', os.O_CREAT | os.O_RDWR) + wfd = FileDescriptor(fd) + + with 
wfd.to_file('w') as f: + assert f.write('abc') + + assert 'converted' in repr(wfd) + with pytest.raises(NoFDError): + wfd.fileno() + + assert_not_fd(fd) # Check FD was closed by file object + + assert (tmp_path / 'a').read_text() == 'abc' + + +def test_to_socket(): + s1, s2 = socket.socketpair() + try: + s1.sendall(b'abcd') + sfd = s2.detach() + wfd = FileDescriptor(sfd) + + with wfd.to_socket() as sock: + b = sock.recv(16) + assert b and b'abcd'.startswith(b) + + assert 'converted' in repr(wfd) + with pytest.raises(NoFDError): + wfd.fileno() + + assert_not_fd(sfd) # Check FD was closed by socket object + finally: + s1.close() + + diff --git a/lib/jeepney/tests/test_low_level.py b/lib/jeepney/tests/test_low_level.py new file mode 100644 index 0000000..ce8b4ee --- /dev/null +++ b/lib/jeepney/tests/test_low_level.py @@ -0,0 +1,101 @@ +import pytest +from jeepney.low_level import * + +HELLO_METHOD_CALL = ( + b'l\x01\x00\x01\x00\x00\x00\x00\x01\x00\x00\x00m\x00\x00\x00\x01\x01o\x00\x15' + b'\x00\x00\x00/org/freedesktop/DBus\x00\x00\x00\x02\x01s\x00\x14\x00\x00\x00' + b'org.freedesktop.DBus\x00\x00\x00\x00\x03\x01s\x00\x05\x00\x00\x00Hello\x00' + b'\x00\x00\x06\x01s\x00\x14\x00\x00\x00org.freedesktop.DBus\x00\x00\x00\x00') + + +def test_parser_simple(): + msg = Parser().feed(HELLO_METHOD_CALL)[0] + assert msg.header.fields[HeaderFields.member] == 'Hello' + +def chunks(src, size): + pos = 0 + while pos < len(src): + end = pos + size + yield src[pos:end] + pos = end + +def test_parser_chunks(): + p = Parser() + chunked = list(chunks(HELLO_METHOD_CALL, 16)) + for c in chunked[:-1]: + assert p.feed(c) == [] + msg = p.feed(chunked[-1])[0] + assert msg.header.fields[HeaderFields.member] == 'Hello' + +def test_multiple(): + msgs = Parser().feed(HELLO_METHOD_CALL * 6) + assert len(msgs) == 6 + for msg in msgs: + assert msg.header.fields[HeaderFields.member] == 'Hello' + +def test_roundtrip(): + msg = Parser().feed(HELLO_METHOD_CALL)[0] + assert msg.serialise() == 
HELLO_METHOD_CALL + +def test_serialise_dict(): + data = { + 'a': 'b', + 'de': 'f', + } + string_type = simple_types['s'] + sig = Array(DictEntry([string_type, string_type])) + print(sig.serialise(data, 0, Endianness.little)) + assert sig.serialise(data, 0, Endianness.little) == ( + b'\x1e\0\0\0' + # Length + b'\0\0\0\0' + # Padding + b'\x01\0\0\0a\0\0\0' + + b'\x01\0\0\0b\0\0\0' + + b'\x02\0\0\0de\0\0' + + b'\x01\0\0\0f\0' + ) + +def test_parse_signature(): + sig = parse_signature(list('(a{sv}(oayays)b)')) + print(sig) + assert sig == Struct([ + Array(DictEntry([simple_types['s'], Variant()])), + Struct([ + simple_types['o'], + Array(simple_types['y']), + Array(simple_types['y']), + simple_types['s'] + ]), + simple_types['b'], + ]) + +class fake_list(list): + def __init__(self, n): + super().__init__() + self._n = n + + def __len__(self): + return self._n + + def __iter__(self): + return iter(range(self._n)) + +def test_array_limit(): + # The spec limits arrays to 64 MiB + a = Array(FixedType(8, 'Q')) # 'at' - array of uint64 + a.serialise(fake_list(100), 0, Endianness.little) + with pytest.raises(SizeLimitError): + a.serialise(fake_list(2**23 + 1), 0, Endianness.little) + + +def test_bad_object_path(): + with pytest.raises(ValueError): + ObjectPathType().check_data('org/freedesktop/DBus') + + with pytest.raises(ValueError): + ObjectPathType().check_data('/org/freedesktop/DBus/') + + with pytest.raises(ValueError): + ObjectPathType().check_data('/org//freedesktop/DBus') + + with pytest.raises(ValueError): + ObjectPathType().check_data('/org/freedesktop/DBüs') # Non-ASCII character diff --git a/lib/jeepney/tests/test_wrappers.py b/lib/jeepney/tests/test_wrappers.py new file mode 100644 index 0000000..636feef --- /dev/null +++ b/lib/jeepney/tests/test_wrappers.py @@ -0,0 +1,74 @@ +import pytest + +from jeepney.wrappers import * + +def test_bad_bus_name(): + obj = '/com/example/foo' + DBusAddress(obj, 'com.example.a') # Valid (well known name) + DBusAddress(obj, 
'com.example.a-b') # Valid but discouraged + DBusAddress(obj, ':1.13') # Valid (unique name) + + with pytest.raises(ValueError, match='too long'): + DBusAddress(obj, 'com.example.' + ('a' * 256)) + + with pytest.raises(ValueError): + DBusAddress(obj, '.com.example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, 'com..example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, 'com.2example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, 'cöm.example.a') # Non-ASCII character + + with pytest.raises(ValueError): + DBusAddress(obj, 'com') + +def test_bad_interface(): + obj = '/com/example/foo' + busname = 'com.example.foo' + DBusAddress(obj, 'com.example.a', 'com.example.a_b') # Valid + + with pytest.raises(ValueError, match='too long'): + DBusAddress(obj, 'com.example.a', 'com.example.' + ('a' * 256)) + + with pytest.raises(ValueError): + DBusAddress(obj, 'com.example.a', 'com.example.a-b') # No hyphens + + with pytest.raises(ValueError): + DBusAddress(obj, busname, '.com.example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, busname, 'com..example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, busname, 'com.2example.a') + + with pytest.raises(ValueError): + DBusAddress(obj, busname, 'cöm.example.a') # Non-ASCII character + + with pytest.raises(ValueError): + DBusAddress(obj, busname, 'com') + + +def test_bad_member_name(): + addr = DBusAddress( + '/org/freedesktop/DBus', + bus_name='org.freedesktop.DBus', + interface='org.freedesktop.DBus', + ) + new_method_call(addr, 'Hello') + + with pytest.raises(ValueError, match='too long'): + new_method_call(addr, 'Hell' + ('o' * 256)) + + with pytest.raises(ValueError): + new_method_call(addr, 'org.Hello') + + with pytest.raises(ValueError): + new_method_call(addr, '9Hello') + + with pytest.raises(ValueError): + new_method_call(addr, '') diff --git a/lib/jeepney/wrappers.py b/lib/jeepney/wrappers.py new file mode 100644 index 0000000..a3b62c5 --- /dev/null +++ 
b/lib/jeepney/wrappers.py @@ -0,0 +1,265 @@ +import re +from typing import Union +from warnings import warn + +from .low_level import * + +__all__ = [ + 'DBusAddress', + 'new_method_call', + 'new_method_return', + 'new_error', + 'new_signal', + 'MessageGenerator', + 'Properties', + 'Introspectable', + 'DBusErrorResponse', +] + +bus_name_pat = re.compile( + r'([A-Za-z_-][A-Za-z0-9_-]*(\.[A-Za-z_-][A-Za-z0-9_-]*)+' # Well known name + r'|:[A-Za-z0-9_-]+(\.[A-Za-z0-9_-]+))$', # Unique name +) + +def check_bus_name(name): + if len(name) > 255: + abbr = name[:8] + '...' + raise ValueError(f"Bus name ({abbr!r}) is too long (> 255 characters)") + if not bus_name_pat.match(name): + raise ValueError(f"Bus name ({name!r}) is not valid") + +interface_pat = re.compile(r'[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)+$') + +def check_interface(name): + if len(name) > 255: + abbr = name[:8] + '...' + raise ValueError(f"Interface name ({abbr!r}) is too long (> 255 characters)") + if not interface_pat.match(name): + raise ValueError(f"Interface name ({name!r}) is not valid") + +member_name_pat = re.compile(r'[A-Za-z_][A-Za-z0-9_]*$') + +def check_member_name(name): + if len(name) > 255: + abbr = name[:8] + '...' + raise ValueError(f"Member name ({abbr!r}) is too long (> 255 characters)") + if not member_name_pat.match(name): + raise ValueError(f"Member name ({name!r} is not valid") + + +class DBusAddress: + """This identifies the object and interface a message is for. + + e.g. 
messages to display desktop notifications would have this address:: + + DBusAddress('/org/freedesktop/Notifications', + bus_name='org.freedesktop.Notifications', + interface='org.freedesktop.Notifications') + """ + def __init__(self, object_path, bus_name=None, interface=None): + ObjectPathType().check_data(object_path) + self.object_path = object_path + + if bus_name is not None: + check_bus_name(bus_name) + self.bus_name = bus_name + + if interface is not None: + check_interface(interface) + self.interface = interface + + def __repr__(self): + return '{}({!r}, bus_name={!r}, interface={!r})'.format(type(self).__name__, + self.object_path, self.bus_name, self.interface) + + def with_interface(self, interface): + check_interface(interface) + return type(self)(self.object_path, self.bus_name, interface) + +class DBusObject(DBusAddress): + def __init__(self, object_path, bus_name=None, interface=None): + super().__init__(object_path, bus_name, interface) + warn('Deprecated alias, use DBusAddress instead', stacklevel=2) + +def new_header(msg_type): + return Header(Endianness.little, msg_type, flags=0, protocol_version=1, + body_length=-1, serial=-1, fields={}) + +def new_method_call(remote_obj, method, signature=None, body=()): + """Construct a new method call message + + This is a relatively low-level method. In many cases, this will be called + from a :class:`MessageGenerator` subclass which provides a more convenient + API. + + :param DBusAddress remote_obj: The object to call a method on + :param str method: The name of the method to call + :param str signature: The DBus signature of the body data + :param tuple body: Body data (i.e. 
method parameters) + """ + check_member_name(method) + header = new_header(MessageType.method_call) + header.fields[HeaderFields.path] = remote_obj.object_path + if remote_obj.bus_name is None: + raise ValueError("remote_obj.bus_name cannot be None for method calls") + header.fields[HeaderFields.destination] = remote_obj.bus_name + if remote_obj.interface is not None: + header.fields[HeaderFields.interface] = remote_obj.interface + header.fields[HeaderFields.member] = method + if signature is not None: + header.fields[HeaderFields.signature] = signature + + return Message(header, body) + +def new_method_return(parent_msg, signature=None, body=()): + """Construct a new response message + + :param Message parent_msg: The method call this is a reply to + :param str signature: The DBus signature of the body data + :param tuple body: Body data + """ + header = new_header(MessageType.method_return) + header.fields[HeaderFields.reply_serial] = parent_msg.header.serial + sender = parent_msg.header.fields.get(HeaderFields.sender, None) + if sender is not None: + header.fields[HeaderFields.destination] = sender + if signature is not None: + header.fields[HeaderFields.signature] = signature + return Message(header, body) + +def new_error(parent_msg, error_name, signature=None, body=()): + """Construct a new error response message + + :param Message parent_msg: The method call this is a reply to + :param str error_name: The name of the error + :param str signature: The DBus signature of the body data + :param tuple body: Body data + """ + header = new_header(MessageType.error) + header.fields[HeaderFields.reply_serial] = parent_msg.header.serial + header.fields[HeaderFields.error_name] = error_name + sender = parent_msg.header.fields.get(HeaderFields.sender, None) + if sender is not None: + header.fields[HeaderFields.destination] = sender + if signature is not None: + header.fields[HeaderFields.signature] = signature + return Message(header, body) + +def new_signal(emitter, 
signal, signature=None, body=()): + """Construct a new signal message + + :param DBusAddress emitter: The object sending the signal + :param str signal: The name of the signal + :param str signature: The DBus signature of the body data + :param tuple body: Body data + """ + check_member_name(signal) + header = new_header(MessageType.signal) + header.fields[HeaderFields.path] = emitter.object_path + if emitter.interface is None: + raise ValueError("emitter.interface cannot be None for signals") + header.fields[HeaderFields.interface] = emitter.interface + header.fields[HeaderFields.member] = signal + if signature is not None: + header.fields[HeaderFields.signature] = signature + return Message(header, body) + + +class MessageGenerator: + """Subclass this to define the methods available on a DBus interface. + + jeepney.bindgen can automatically create subclasses using introspection. + """ + interface: Optional[str] = None + + def __init__(self, object_path, bus_name): + ObjectPathType().check_data(object_path) + check_bus_name(bus_name) + if self.interface is not None: + check_interface(self.interface) + + self.object_path = object_path + self.bus_name = bus_name + + def __repr__(self): + return "{}({!r}, bus_name={!r})".format(type(self).__name__, + self.object_path, self.bus_name) + + +class ProxyBase: + """A proxy is an IO-aware wrapper around a MessageGenerator + + Calling methods on a proxy object will send a message and wait for the + reply. This is a base class for proxy implementations in jeepney.io. 
+ """ + def __init__(self, msggen): + self._msggen = msggen + + def __getattr__(self, item): + if item.startswith('__'): + raise AttributeError(item) + + make_msg = getattr(self._msggen, item, None) + if callable(make_msg): + return self._method_call(make_msg) + + raise AttributeError(item) + + def _method_call(self, make_msg): + raise NotImplementedError("Needs to be implemented in subclass") + +class Properties: + """Build messages for accessing object properties + + If a D-Bus object has multiple interfaces, each interface has its own + set of properties. + + This uses the standard DBus interface ``org.freedesktop.DBus.Properties`` + """ + def __init__(self, obj: Union[DBusAddress, MessageGenerator]): + self.obj = obj + self.props_if = DBusAddress(obj.object_path, bus_name=obj.bus_name, + interface='org.freedesktop.DBus.Properties') + + def get(self, name): + """Get the value of the property *name*""" + return new_method_call(self.props_if, 'Get', 'ss', + (self.obj.interface, name)) + + def get_all(self): + """Get all property values for this interface""" + return new_method_call(self.props_if, 'GetAll', 's', + (self.obj.interface,)) + + def set(self, name, signature, value): + """Set the property *name* to *value* (with appropriate signature)""" + return new_method_call(self.props_if, 'Set', 'ssv', + (self.obj.interface, name, (signature, value))) + +class Introspectable(MessageGenerator): + interface = 'org.freedesktop.DBus.Introspectable' + + def Introspect(self): + """Request D-Bus introspection XML for a remote object""" + return new_method_call(self, 'Introspect') + +class DBusErrorResponse(Exception): + """Raised by proxy method calls when the reply is an error message""" + def __init__(self, msg): + self.name = msg.header.fields.get(HeaderFields.error_name) + self.data = msg.body + + def __str__(self): + return '[{}] {}'.format(self.name, self.data) + + +def unwrap_msg(msg: Message): + """Get the body of a message, raising DBusErrorResponse for error 
messages + + This is to be used with replies to method_call messages, which may be + method_return or error. + """ + if msg.header.message_type == MessageType.error: + raise DBusErrorResponse(msg) + + return msg.body diff --git a/lib/keyring-25.7.0.dist-info/INSTALLER b/lib/keyring-25.7.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/keyring-25.7.0.dist-info/METADATA b/lib/keyring-25.7.0.dist-info/METADATA new file mode 100644 index 0000000..76689bd --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/METADATA @@ -0,0 +1,554 @@ +Metadata-Version: 2.4 +Name: keyring +Version: 25.7.0 +Summary: Store and access your passwords safely. +Author-email: Kang Zhang +Maintainer-email: "Jason R. Coombs" +License-Expression: MIT +Project-URL: Source, https://github.com/jaraco/keyring +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: pywin32-ctypes>=0.2.0; sys_platform == "win32" +Requires-Dist: SecretStorage>=3.2; sys_platform == "linux" +Requires-Dist: jeepney>=0.4.2; sys_platform == "linux" +Requires-Dist: importlib_metadata>=4.11.4; python_version < "3.12" +Requires-Dist: jaraco.classes +Requires-Dist: jaraco.functools +Requires-Dist: jaraco.context +Provides-Extra: test +Requires-Dist: pytest!=8.1.*,>=6; extra == "test" +Requires-Dist: pyfakefs; extra == "test" +Provides-Extra: doc +Requires-Dist: sphinx>=3.5; extra == "doc" +Requires-Dist: jaraco.packaging>=9.3; extra == "doc" +Requires-Dist: rst.linker>=1.9; extra == "doc" +Requires-Dist: furo; extra == "doc" +Requires-Dist: sphinx-lint; extra == "doc" +Requires-Dist: jaraco.tidelift>=1.4; extra == "doc" +Provides-Extra: check +Requires-Dist: 
pytest-checkdocs>=2.4; extra == "check" +Requires-Dist: pytest-ruff>=0.2.1; sys_platform != "cygwin" and extra == "check" +Provides-Extra: cover +Requires-Dist: pytest-cov; extra == "cover" +Provides-Extra: enabler +Requires-Dist: pytest-enabler>=3.4; extra == "enabler" +Provides-Extra: type +Requires-Dist: pytest-mypy>=1.0.1; extra == "type" +Requires-Dist: pygobject-stubs; extra == "type" +Requires-Dist: shtab; extra == "type" +Requires-Dist: types-pywin32; extra == "type" +Provides-Extra: completion +Requires-Dist: shtab>=1.1.0; extra == "completion" +Dynamic: license-file + +.. image:: https://img.shields.io/pypi/v/keyring.svg + :target: https://pypi.org/project/keyring + +.. image:: https://img.shields.io/pypi/pyversions/keyring.svg + +.. image:: https://github.com/jaraco/keyring/actions/workflows/main.yml/badge.svg + :target: https://github.com/jaraco/keyring/actions?query=workflow%3A%22tests%22 + :alt: tests + +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff + +.. image:: https://readthedocs.org/projects/keyring/badge/?version=latest + :target: https://keyring.readthedocs.io/en/latest/?badge=latest + +.. image:: https://img.shields.io/badge/skeleton-2025-informational + :target: https://blog.jaraco.com/skeleton + +.. image:: https://tidelift.com/badges/package/pypi/keyring + :target: https://tidelift.com/subscription/pkg/pypi-keyring?utm_source=pypi-keyring&utm_medium=readme + +.. image:: https://badges.gitter.im/jaraco/keyring.svg + :alt: Join the chat at https://gitter.im/jaraco/keyring + :target: https://gitter.im/jaraco/keyring?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge + +The Python keyring library provides an easy way to access the +system keyring service from python. It can be used in any +application that needs safe password storage. 
+ +These recommended keyring backends are supported: + +* macOS `Keychain + `_ +* Freedesktop `Secret Service + `_ supports many DE including + GNOME (requires `secretstorage `_) +* KDE4 & KDE5 `KWallet `_ + (requires `dbus `_) +* `Windows Credential Locker + `_ + +Other keyring implementations are available through `Third-Party Backends`_. + +Installation - Linux +==================== + +On Linux, the KWallet backend relies on dbus-python_, which does not always +install correctly when using pip (compilation is needed). For best results, +install dbus-python as a system package. + +.. _dbus-python: https://gitlab.freedesktop.org/dbus/dbus-python + +Compatibility - macOS +===================== + +macOS keychain supports macOS 11 (Big Sur) and later requires Python 3.8.7 +or later with the "universal2" binary. See +`#525 `_ for details. + +Using Keyring +============= + +The basic usage of keyring is pretty simple: just call +``keyring.set_password`` and ``keyring.get_password``:: + + >>> import keyring + >>> keyring.set_password("system", "username", "password") + >>> keyring.get_password("system", "username") + 'password' + +Command-line Utility +-------------------- + +Keyring supplies a ``keyring`` command which is installed with the +package. After installing keyring in most environments, the +command should be available for setting, getting, and deleting +passwords. 
For more usage information, invoke with no arguments +or with ``--help`` as so:: + + $ keyring --help + $ keyring set system username + Password for 'username' in 'system': + $ keyring get system username + password + +The command-line functionality is also exposed as an executable +package, suitable for invoking from Python like so:: + + $ python -m keyring --help + $ python -m keyring set system username + Password for 'username' in 'system': + $ python -m keyring get system username + password + +Tab Completion +-------------- + +If installed via a package manager (apt, pacman, nix, homebrew, etc), +these shell completions may already have been distributed with the package +(no action required). + +Keyring provides tab completion if the ``completion`` extra is installed:: + + $ pip install 'keyring[completion]' + +Then, generate shell completions, something like:: + + $ keyring --print-completion bash | sudo tee /usr/share/bash-completion/completions/keyring + $ keyring --print-completion zsh | sudo tee /usr/share/zsh/site-functions/_keyring + $ keyring --print-completion tcsh | sudo tee /etc/profile.d/keyring.csh + +**Note**: the path of `/usr/share` is mainly for GNU/Linux. For other OSs, +consider: + +- macOS (Homebrew x86): /usr/local/share +- macOS (Homebrew ARM): /opt/homebrew/share +- Android (Termux): /data/data/com.termux/files/usr/share +- Windows (mingw64 of msys2): /mingw64/share +- ... + +After installing the shell completions, enable them following your shell's +recommended instructions. e.g.: + +- bash: install `bash-completion `_, + and ensure ``. /usr/share/bash-completion/bash_completion`` in ``~/.bashrc``. +- zsh: ensure ``autoload -Uz compinit && compinit`` appears in ``~/.zshrc``, + then ``grep -w keyring ~/.zcompdump`` to verify keyring appears, indicating + it was installed correctly. + +Configuring +=========== + +The python keyring lib contains implementations for several backends. 
The +library will attempt to +automatically choose the most suitable backend for the current +environment. Users may also specify the preferred keyring in a +config file or by calling the ``set_keyring()`` function. + +Config file path +---------------- + +The configuration is stored in a file named "keyringrc.cfg" +found in a platform-specific location. To determine +where the config file is stored, run ``keyring diagnose``. + +Config file content +------------------- + +To specify a keyring backend, set the **default-keyring** option to the +full path of the class for that backend, such as +``keyring.backends.macOS.Keyring``. + +If **keyring-path** is indicated, keyring will add that path to the Python +module search path before loading the backend. + +For example, this config might be used to load the +``SimpleKeyring`` from the ``simplekeyring`` module in +the ``./demo`` directory (not implemented):: + + [backend] + default-keyring=simplekeyring.SimpleKeyring + keyring-path=demo + +Third-Party Backends +==================== + +In addition to the backends provided by the core keyring package for +the most common and secure use cases, there +are additional keyring backend implementations available for other +use cases. Simply install them to make them available: + +- `keyrings.cryptfile `_ + - Encrypted text file storage. +- `keyrings.alt `_ - "alternate", + possibly-insecure backends, originally part of the core package, but + available for opt-in. +- `gsheet-keyring `_ + - a backend that stores secrets in a Google Sheet. For use with + `ipython-secrets `_. +- `bitwarden-keyring `_ + - a backend that stores secrets in the `BitWarden `_ + password manager. +- `onepassword-keyring `_ + - a backend that stores secrets in the `1Password `_ password manager. +- `sagecipher `_ - an encryption + backend which uses the ssh agent protocol's signature operation to + derive the cipher key. 
+- `keyrings.osx_keychain_keys `_ + - OSX keychain key-management, for private, public, and symmetric keys. +- `keyring_pass.PasswordStoreBackend `_ + - Password Store (pass) backend for python's keyring +- `keyring_jeepney `__ - a + pure Python backend using the secret service DBus API for desktop + Linux (requires ``keyring<24``). + + +Write your own keyring backend +============================== + +The interface for the backend is defined by ``keyring.backend.KeyringBackend``. +Every backend should derive from that base class and define a ``priority`` +attribute and three functions: ``get_password()``, ``set_password()``, and +``delete_password()``. The ``get_credential()`` function may be defined if +desired. + +See the ``backend`` module for more detail on the interface of this class. + +Keyring employs entry points to allow any third-party package to implement +backends without any modification to the keyring itself. Those interested in +creating new backends are encouraged to create new, third-party packages +in the ``keyrings`` namespace, in a manner modeled by the `keyrings.alt +package `_. See the +``setup.cfg`` file +in that project for hints on how to create the requisite entry points. +Backends that prove essential may be considered for inclusion in the core +library, although the ease of installing these third-party packages should +mean that extensions may be readily available. + +To create an extension for Keyring, please submit a pull request to +have your extension mentioned as an available extension. + +Runtime Configuration +===================== + +Keyring additionally allows programmatic configuration of the +backend calling the api ``set_keyring()``. The indicated backend +will subsequently be used to store and retrieve passwords. 
+ +To invoke ``set_keyring``:: + + # define a new keyring class which extends the KeyringBackend + import keyring.backend + + class TestKeyring(keyring.backend.KeyringBackend): + """A test keyring which always outputs the same password + """ + priority = 1 + + def set_password(self, servicename, username, password): + pass + + def get_password(self, servicename, username): + return "password from TestKeyring" + + def delete_password(self, servicename, username): + pass + + # set the keyring for keyring lib + keyring.set_keyring(TestKeyring()) + + # invoke the keyring lib + try: + keyring.set_password("demo-service", "tarek", "passexample") + print("password stored successfully") + except keyring.errors.PasswordSetError: + print("failed to store password") + print("password", keyring.get_password("demo-service", "tarek")) + + +Disabling Keyring +================= + +In many cases, uninstalling keyring will never be necessary. +Especially on Windows and macOS, the behavior of keyring is +usually degenerate, meaning it will return empty values to +the caller, allowing the caller to fall back to some other +behavior. + +In some cases, the default behavior of keyring is undesirable and +it would be preferable to disable the keyring behavior altogether. +There are several mechanisms to disable keyring: + +- Uninstall keyring. Most applications are tolerant to keyring + not being installed. Uninstalling keyring should cause those + applications to fall back to the behavior without keyring. + This approach affects the Python environment where keyring + would otherwise have been installed. + +- Configure the Null keyring in the environment. Set + ``PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring`` + in the environment, and the ``Null`` (degenerate) backend + will be used. This approach affects all uses of Keyring where + that variable is set. + +- Permanently configure the Null keyring for the user by running + ``keyring --disable`` or ``python -m keyring --disable``. 
+ This approach affects all uses of keyring for that user. + + +Altering Keyring Behavior +========================= + +Keyring provides a mechanism to alter the keyring's behavior through +environment variables. Each backend implements a +``KeyringBackend.set_properties_from_env``, which +when invoked will find all environment variables beginning with +``KEYRING_PROPERTY_{NAME}`` and will set a property for each +``{NAME.lower()}`` on the keyring. This method is invoked during +initialization for the default/configured keyring. + +This mechanism may be used to set some useful values on various +keyrings, including: + +- keychain; macOS, path to an alternate keychain file +- appid; Linux/SecretService, alternate ID for the application + + +Using Keyring on Ubuntu 16.04 +============================= + +The following is a complete transcript for installing keyring in a +virtual environment on Ubuntu 16.04. No config file was used:: + + $ sudo apt install python3-venv libdbus-glib-1-dev + $ cd /tmp + $ pyvenv py3 + $ source py3/bin/activate + $ pip install -U pip + $ pip install secretstorage dbus-python + $ pip install keyring + $ python + >>> import keyring + >>> keyring.get_keyring() + + >>> keyring.set_password("system", "username", "password") + >>> keyring.get_password("system", "username") + 'password' + + +Using Keyring on headless Linux systems +======================================= + +It is possible to use the SecretService backend on Linux systems without +X11 server available (only D-Bus is required). In this case: + +* Install the `GNOME Keyring`_ daemon. +* Start a D-Bus session, e.g. run ``dbus-run-session -- sh`` and run + the following commands inside that shell. +* Run ``gnome-keyring-daemon`` with ``--unlock`` option. The description of + that option says: + + Read a password from stdin, and use it to unlock the login keyring + or create it if the login keyring does not exist. 
+ + When that command is started, enter a password into stdin and + press Ctrl+D (end of data). After that, the daemon will fork into + the background (use ``--foreground`` option to block). +* Now you can use the SecretService backend of Keyring. Remember to + run your application in the same D-Bus session as the daemon. + +.. _GNOME Keyring: https://wiki.gnome.org/Projects/GnomeKeyring + +Using Keyring on headless Linux systems in a Docker container +============================================================= + +It is possible to use keyring with the SecretService backend in Docker containers as well. +All you need to do is install the necessary dependencies and add the `--privileged` flag +to avoid any `Operation not permitted` errors when attempting to unlock the system's keyring. + +The following is a complete transcript for installing keyring on a Ubuntu 18:04 container:: + + docker run -it -d --privileged ubuntu:18.04 + + $ apt-get update + $ apt install -y gnome-keyring python3-venv python3-dev + $ python3 -m venv venv + $ source venv/bin/activate # source a virtual environment to avoid polluting your system + $ pip3 install --upgrade pip + $ pip3 install keyring + $ dbus-run-session -- sh # this will drop you into a new D-bus shell + $ echo 'somecredstorepass' | gnome-keyring-daemon --unlock # unlock the system's keyring + + $ python + >>> import keyring + >>> keyring.get_keyring() + + >>> keyring.set_password("system", "username", "password") + >>> keyring.get_password("system", "username") + 'password' + +Using Keyring with tox +====================== + +Some backends rely on environment variables to operate correctly, and ``tox`` filters most environment variables by default. + +For example, when using Keyring to store credentials for pip, one may encounter the following error when +running tests under ``tox`` when using a backend reliant on D-Bus: + + RuntimeError: No recommended backend was available. 
Install the keyrings.alt package if you want to use the non-recommended backends. See README.rst for details. + +This error is caused by Keyring KWallet backend not able to resolve the backing service. + +To work around the issue, add ``DBUS_SESSION_BUS_ADDRESS`` to ``pass_env`` in the +``tox`` configuration. Consider adding other necessary variables, such as ``DISPLAY`` and ``WAYLAND_DISPLAY`` (if using ``pinentry``). + +Integration +=========== + +API +--- + +The keyring lib has a few functions: + +* ``get_keyring()``: Return the currently-loaded keyring implementation. +* ``get_password(service, username)``: Returns the password stored in the + active keyring. If the password does not exist, it will return None. +* ``get_credential(service, username)``: Return a credential object stored + in the active keyring. This object contains at least ``username`` and + ``password`` attributes for the specified service, where the returned + ``username`` may be different from the argument. +* ``set_password(service, username, password)``: Store the password in the + keyring. +* ``delete_password(service, username)``: Delete the password stored in + keyring. If the password does not exist, it will raise an exception. + +In all cases, the parameters (``service``, ``username``, ``password``) +should be Unicode text. + + +Exceptions +---------- + +The keyring lib raises the following exceptions: + +* ``keyring.errors.KeyringError``: Base Error class for all exceptions in keyring lib. +* ``keyring.errors.InitError``: Raised when the keyring cannot be initialized. +* ``keyring.errors.PasswordSetError``: Raised when the password cannot be set in the keyring. +* ``keyring.errors.PasswordDeleteError``: Raised when the password cannot be deleted in the keyring. + +Get Involved +============ + +Python keyring lib is an open community project and eagerly +welcomes contributors. 
+ +* Repository: https://github.com/jaraco/keyring/ +* Bug Tracker: https://github.com/jaraco/keyring/issues/ +* Mailing list: http://groups.google.com/group/python-keyring + +Security Considerations +======================= + +Each built-in backend may have security considerations to understand +before using this library. Authors of tools or libraries utilizing +``keyring`` are encouraged to consider these concerns. + +As with any list of known security concerns, this list is not exhaustive. +Additional issues can be added as needed. + +- macOS Keychain + - Any Python script or application can access secrets created by + ``keyring`` from that same Python executable without the operating + system prompting the user for a password. To cause any specific + secret to prompt for a password every time it is accessed, locate + the credential using the ``Keychain Access`` application, and in + the ``Access Control`` settings, remove ``Python`` from the list + of allowed applications. + +- Freedesktop Secret Service + - No analysis has been performed + +- KDE4 & KDE5 KWallet + - No analysis has been performed + +- Windows Credential Locker + - No analysis has been performed + +Making Releases +=============== + +This project makes use of automated releases and continuous +integration. The +simple workflow is to tag a commit and push it to Github. If it +passes tests in CI, it will be automatically deployed to PyPI. + +Other things to consider when making a release: + +- Check that the changelog is current for the intended release. + +Running Tests +============= + +Tests are continuously run in Github Actions. + +To run the tests locally, install and invoke +`tox `_. + +Background +========== + +The project was based on Tarek Ziade's idea in `this post`_. Kang Zhang +initially carried it out as a `Google Summer of Code`_ project, and Tarek +mentored Kang on this project. + +.. 
_this post: http://tarekziade.wordpress.com/2009/03/27/pycon-hallway-session-1-a-keyring-library-for-python/ +.. _Google Summer of Code: http://socghop.appspot.com/ + +For Enterprise +============== + +Available as part of the Tidelift Subscription. + +This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use. + +`Learn more `_. diff --git a/lib/keyring-25.7.0.dist-info/RECORD b/lib/keyring-25.7.0.dist-info/RECORD new file mode 100644 index 0000000..f0863be --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/RECORD @@ -0,0 +1,68 @@ +../../bin/keyring,sha256=k2jzr6sNOFrQzL2whxOi7SHxOw1QaRbmHwC8EU32iQ4,157 +keyring-25.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +keyring-25.7.0.dist-info/METADATA,sha256=vVTemP7ebcPh882JtON8ldiEKlI727nlKrQO3_GDcWM,21447 +keyring-25.7.0.dist-info/RECORD,, +keyring-25.7.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +keyring-25.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 +keyring-25.7.0.dist-info/entry_points.txt,sha256=8ibyc9zH2ST1JDZHWlQZHEUPx9kVaXfVy8z5af_6OUk,334 +keyring-25.7.0.dist-info/licenses/LICENSE,sha256=WlfLTbheKi3YjCkGKJCK3VfjRRRJ4KmnH9-zh3b9dZ0,1076 +keyring-25.7.0.dist-info/top_level.txt,sha256=ohh1dke28_NdSNkZ6nkVSwIKkLJTOwIfEwnXKva3pkg,8 +keyring/__init__.py,sha256=4bk66hxOsw5JRhyy4I9U8c_VXK-pLusB-YB-aS86ot0,271 +keyring/__main__.py,sha256=vB_vOSk4pIZrkevBQeHXy6GYv7Nd0_vieKe44Xf1i9g,71 +keyring/__pycache__/__init__.cpython-314.pyc,, +keyring/__pycache__/__main__.cpython-314.pyc,, +keyring/__pycache__/backend.cpython-314.pyc,, +keyring/__pycache__/cli.cpython-314.pyc,, +keyring/__pycache__/completion.cpython-314.pyc,, +keyring/__pycache__/core.cpython-314.pyc,, +keyring/__pycache__/credentials.cpython-314.pyc,, +keyring/__pycache__/devpi_client.cpython-314.pyc,, +keyring/__pycache__/errors.cpython-314.pyc,, 
+keyring/__pycache__/http.cpython-314.pyc,, +keyring/backend.py,sha256=hg5qqlLy2K_KSh2sZ6BM_nFbgIKjFhjz5iJwwsdqIHs,9069 +keyring/backend_complete.bash,sha256=I3bRA3fGR_duzLrJyki94CaxxnelhiiXYyXLvUmlbec,397 +keyring/backend_complete.zsh,sha256=Je9QAn0CbF8_8ssGSkroa4HMcJDB3g20yL8XhhW50fI,451 +keyring/backends/SecretService.py,sha256=qt9lQpa8h6rGnjzTOE8GMIDH2e2J40RIhV3yc1TXSsc,4712 +keyring/backends/Windows.py,sha256=2pi3LSV2RCwXrLYeNplIUVJgPLH5uMnyYcSBgo-6kmw,5727 +keyring/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +keyring/backends/__pycache__/SecretService.cpython-314.pyc,, +keyring/backends/__pycache__/Windows.cpython-314.pyc,, +keyring/backends/__pycache__/__init__.cpython-314.pyc,, +keyring/backends/__pycache__/chainer.cpython-314.pyc,, +keyring/backends/__pycache__/fail.cpython-314.pyc,, +keyring/backends/__pycache__/kwallet.cpython-314.pyc,, +keyring/backends/__pycache__/libsecret.cpython-314.pyc,, +keyring/backends/__pycache__/null.cpython-314.pyc,, +keyring/backends/chainer.py,sha256=-hhe-UWbCn0PAUK-00cWjHz_JJNQf_N4OyHUn89yCOw,2175 +keyring/backends/fail.py,sha256=ef5uP3Ddj2apq2pe08LXI2lLgpkmN0UrKZmOx58UHIU,914 +keyring/backends/kwallet.py,sha256=Le-bwfJVN7dNUiMLYLE66e0HzM5gmJZpXnmLQkDlCEo,5824 +keyring/backends/libsecret.py,sha256=gWeUveE44wZH0j7t2w2L-leYMpJOEHV0OqSUiC-sHQE,5942 +keyring/backends/macOS/__init__.py,sha256=-CIONvwrJFbeuj60opbCMZw4wWtiGyHuGCshocd4Ndg,2589 +keyring/backends/macOS/__pycache__/__init__.cpython-314.pyc,, +keyring/backends/macOS/__pycache__/api.cpython-314.pyc,, +keyring/backends/macOS/api.py,sha256=eikiBaGcYCQpqDsNdLy8wNoB_nFBYfY41j_38vsMKpo,4576 +keyring/backends/null.py,sha256=HW-Ovygh78UebL-ICPTilmCOk37h5WFPvVlMnNP8ElA,438 +keyring/cli.py,sha256=B9084Rmlt4atfQCw2qugMmovVQzeFjkeLRf6vTNcMTI,6605 +keyring/compat/__init__.py,sha256=WXWOxJd1wdBdrTNjKqjt8jOmfIahcIipDahbqdlQ6g8,169 +keyring/compat/__pycache__/__init__.cpython-314.pyc,, +keyring/compat/__pycache__/properties.cpython-314.pyc,, 
+keyring/compat/__pycache__/py312.cpython-314.pyc,, +keyring/compat/properties.py,sha256=JTlR3v7A5AgK93grI2nIW1sj0efYePgWQURDsWHwzj4,3886 +keyring/compat/py312.py,sha256=euMz5d91tbdrG2JkpoqDu3bBg3Pjzd3pEyWVxSK4IkA,159 +keyring/completion.py,sha256=MSj0qPtLAhhN9kSk34LRzGSYIhS19aG05wlYl_RHG_Q,1450 +keyring/core.py,sha256=2zEOVKitYardvqPDHzMFCRfIB812cuXLbIVh9udbxc0,5848 +keyring/credentials.py,sha256=PWFUzeAEX9FqjYonSIST4y6WHqQ2lKceLcvicKSaipY,2092 +keyring/devpi_client.py,sha256=IpkyYAso0BH9tXpsZ3K1UjJG_Obtj6kTflrpDatNzoQ,603 +keyring/errors.py,sha256=hiHZxG3e1WABMDw80iT0Yg6qrccaVuVUpTNFK7iVmnY,1625 +keyring/http.py,sha256=udH83q5BIrfKYm-4AOuefQ3Avb-J9UbpXBYu49Ik_iA,1214 +keyring/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +keyring/testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +keyring/testing/__pycache__/__init__.cpython-314.pyc,, +keyring/testing/__pycache__/backend.cpython-314.pyc,, +keyring/testing/__pycache__/util.cpython-314.pyc,, +keyring/testing/backend.py,sha256=HuCE8NL1rXMIZBrFELce2aO-N5pY3UEtQLDsNdCgvyA,7551 +keyring/testing/util.py,sha256=O15JsfcLIBcnsF1O8LfnbWkeEuiEfbovzQ1h8oN7XUA,1884 +keyring/util/__init__.py,sha256=ilEB7cz4cWl7acmrubGF9142ZeBer1mFqaL0U-7UXAc,302 +keyring/util/__pycache__/__init__.cpython-314.pyc,, +keyring/util/__pycache__/platform_.cpython-314.pyc,, +keyring/util/platform_.py,sha256=lhsGKWZobEvsztNOkotUoNqiHUhJ7G4ENCfdDwp2wVA,1092 diff --git a/lib/keyring-25.7.0.dist-info/REQUESTED b/lib/keyring-25.7.0.dist-info/REQUESTED new file mode 100644 index 0000000..e69de29 diff --git a/lib/keyring-25.7.0.dist-info/WHEEL b/lib/keyring-25.7.0.dist-info/WHEEL new file mode 100644 index 0000000..e7fa31b --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/lib/keyring-25.7.0.dist-info/entry_points.txt b/lib/keyring-25.7.0.dist-info/entry_points.txt new file mode 100644 index 
0000000..802929d --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/entry_points.txt @@ -0,0 +1,13 @@ +[console_scripts] +keyring = keyring.cli:main + +[devpi_client] +keyring = keyring.devpi_client + +[keyring.backends] +KWallet = keyring.backends.kwallet +SecretService = keyring.backends.SecretService +Windows = keyring.backends.Windows +chainer = keyring.backends.chainer +libsecret = keyring.backends.libsecret +macOS = keyring.backends.macOS diff --git a/lib/keyring-25.7.0.dist-info/licenses/LICENSE b/lib/keyring-25.7.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..f60bd57 --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/licenses/LICENSE @@ -0,0 +1,18 @@ +MIT License + +Copyright (c) 2025 + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/lib/keyring-25.7.0.dist-info/top_level.txt b/lib/keyring-25.7.0.dist-info/top_level.txt new file mode 100644 index 0000000..d6fa9c2 --- /dev/null +++ b/lib/keyring-25.7.0.dist-info/top_level.txt @@ -0,0 +1 @@ +keyring diff --git a/lib/keyring/__init__.py b/lib/keyring/__init__.py new file mode 100644 index 0000000..e1ee7a8 --- /dev/null +++ b/lib/keyring/__init__.py @@ -0,0 +1,17 @@ +from .core import ( + delete_password, + get_credential, + get_keyring, + get_password, + set_keyring, + set_password, +) + +__all__ = ( + 'set_keyring', + 'get_keyring', + 'set_password', + 'get_password', + 'delete_password', + 'get_credential', +) diff --git a/lib/keyring/__main__.py b/lib/keyring/__main__.py new file mode 100644 index 0000000..5dd75f4 --- /dev/null +++ b/lib/keyring/__main__.py @@ -0,0 +1,4 @@ +if __name__ == '__main__': + from keyring import cli + + cli.main() diff --git a/lib/keyring/__pycache__/__init__.cpython-314.pyc b/lib/keyring/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..f1ba56c Binary files /dev/null and b/lib/keyring/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/__main__.cpython-314.pyc b/lib/keyring/__pycache__/__main__.cpython-314.pyc new file mode 100644 index 0000000..ef37aa7 Binary files /dev/null and b/lib/keyring/__pycache__/__main__.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/backend.cpython-314.pyc b/lib/keyring/__pycache__/backend.cpython-314.pyc new file mode 100644 index 0000000..265cc6f Binary files /dev/null and b/lib/keyring/__pycache__/backend.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/cli.cpython-314.pyc b/lib/keyring/__pycache__/cli.cpython-314.pyc new file mode 100644 index 0000000..50150ae Binary files /dev/null and b/lib/keyring/__pycache__/cli.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/completion.cpython-314.pyc b/lib/keyring/__pycache__/completion.cpython-314.pyc new file mode 100644 index 0000000..4226811 
Binary files /dev/null and b/lib/keyring/__pycache__/completion.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/core.cpython-314.pyc b/lib/keyring/__pycache__/core.cpython-314.pyc new file mode 100644 index 0000000..d9f8c4c Binary files /dev/null and b/lib/keyring/__pycache__/core.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/credentials.cpython-314.pyc b/lib/keyring/__pycache__/credentials.cpython-314.pyc new file mode 100644 index 0000000..ad5cc11 Binary files /dev/null and b/lib/keyring/__pycache__/credentials.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/devpi_client.cpython-314.pyc b/lib/keyring/__pycache__/devpi_client.cpython-314.pyc new file mode 100644 index 0000000..6bb826c Binary files /dev/null and b/lib/keyring/__pycache__/devpi_client.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/errors.cpython-314.pyc b/lib/keyring/__pycache__/errors.cpython-314.pyc new file mode 100644 index 0000000..d84f4f8 Binary files /dev/null and b/lib/keyring/__pycache__/errors.cpython-314.pyc differ diff --git a/lib/keyring/__pycache__/http.cpython-314.pyc b/lib/keyring/__pycache__/http.cpython-314.pyc new file mode 100644 index 0000000..58ed5c2 Binary files /dev/null and b/lib/keyring/__pycache__/http.cpython-314.pyc differ diff --git a/lib/keyring/backend.py b/lib/keyring/backend.py new file mode 100644 index 0000000..4f91e16 --- /dev/null +++ b/lib/keyring/backend.py @@ -0,0 +1,300 @@ +""" +Keyring implementation support +""" + +from __future__ import annotations + +import abc +import copy +import functools +import logging +import operator +import os +import typing +import warnings + +from jaraco.context import ExceptionTrap +from jaraco.functools import once + +from . 
import credentials, errors, util +from .compat import properties +from .compat.py312 import metadata + +log = logging.getLogger(__name__) + + +by_priority = operator.attrgetter('priority') +_limit: typing.Callable[[KeyringBackend], bool] | None = None + + +class KeyringBackendMeta(abc.ABCMeta): + """ + Specialized subclass behavior. + + Keeps a registry of all (non-abstract) types. + + Wraps set_password to validate the username. + """ + + def __init__(cls, name, bases, dict): + super().__init__(name, bases, dict) + cls._register() + cls._validate_username_in_set_password() + + def _register(cls): + if not hasattr(cls, '_classes'): + cls._classes = set() + classes = cls._classes + if not cls.__abstractmethods__: + classes.add(cls) + + def _validate_username_in_set_password(cls): + """ + Wrap ``set_password`` such to validate the passed username. + """ + orig = cls.set_password + + @functools.wraps(orig) + def wrapper(self, system, username, *args, **kwargs): + self._validate_username(username) + return orig(self, system, username, *args, **kwargs) + + cls.set_password = wrapper + + +class KeyringBackend(metaclass=KeyringBackendMeta): + """The abstract base class of the keyring, every backend must implement + this interface. + """ + + def __init__(self): + self.set_properties_from_env() + + @properties.classproperty + def priority(self) -> float: + """ + Each backend class must supply a priority, a number (float or integer) + indicating the priority of the backend relative to all other backends. + The priority need not be static -- it may (and should) vary based + attributes of the environment in which is runs (platform, available + packages, etc.). + + A higher number indicates a higher priority. The priority should raise + a RuntimeError with a message indicating the underlying cause if the + backend is not suitable for the current environment. 
+ + As a rule of thumb, a priority between zero but less than one is + suitable, but a priority of one or greater is recommended. + """ + raise NotImplementedError + + # Python 3.8 compatibility + passes = ExceptionTrap().passes + + @properties.classproperty + @passes + def viable(cls): + cls.priority # noqa: B018 + + @classmethod + def get_viable_backends( + cls: type[KeyringBackend], + ) -> filter[type[KeyringBackend]]: + """ + Return all subclasses deemed viable. + """ + return filter(operator.attrgetter('viable'), cls._classes) + + @properties.classproperty + def name(cls) -> str: + """ + The keyring name, suitable for display. + + The name is derived from module and class name. + """ + parent, sep, mod_name = cls.__module__.rpartition('.') + mod_name = mod_name.replace('_', ' ') + # mypy doesn't see `cls` is `type[Self]`, might be fixable in jaraco.classes + return ' '.join([mod_name, cls.__name__]) # type: ignore[attr-defined] + + def __str__(self) -> str: + keyring_class = type(self) + return f"{keyring_class.__module__}.{keyring_class.__name__} (priority: {keyring_class.priority:g})" + + @abc.abstractmethod + def get_password(self, service: str, username: str) -> str | None: + """Get password of the username for the service""" + return None + + def _validate_username(self, username: str) -> None: + """ + Ensure the username is not empty. + """ + if not username: + warnings.warn( + "Empty usernames are deprecated. See #668", + DeprecationWarning, + stacklevel=3, + ) + # raise ValueError("Username cannot be empty") + + @abc.abstractmethod + def set_password(self, service: str, username: str, password: str) -> None: + """Set password for the username of the service. + + If the backend cannot store passwords, raise + PasswordSetError. 
+ """ + raise errors.PasswordSetError("reason") + + # for backward-compatibility, don't require a backend to implement + # delete_password + # @abc.abstractmethod + def delete_password(self, service: str, username: str) -> None: + """Delete the password for the username of the service. + + If the backend cannot delete passwords, raise + PasswordDeleteError. + """ + raise errors.PasswordDeleteError("reason") + + # for backward-compatibility, don't require a backend to implement + # get_credential + # @abc.abstractmethod + def get_credential( + self, + service: str, + username: str | None, + ) -> credentials.Credential | None: + """Gets the username and password for the service. + Returns a Credential instance. + + The *username* argument is optional and may be omitted by + the caller or ignored by the backend. Callers must use the + returned username. + """ + # The default implementation requires a username here. + if username is not None: + password = self.get_password(service, username) + if password is not None: + return credentials.SimpleCredential(username, password) + return None + + def set_properties_from_env(self) -> None: + """For all KEYRING_PROPERTY_* env var, set that property.""" + + def parse(item: tuple[str, str]): + key, value = item + pre, sep, name = key.partition('KEYRING_PROPERTY_') + return sep and (name.lower(), value) + + props: filter[tuple[str, str]] = filter(None, map(parse, os.environ.items())) + for name, value in props: + setattr(self, name, value) + + def with_properties(self, **kwargs: typing.Any) -> KeyringBackend: + alt = copy.copy(self) + vars(alt).update(kwargs) + return alt + + +class Crypter: + """Base class providing encryption and decryption""" + + @abc.abstractmethod + def encrypt(self, value): + """Encrypt the value.""" + pass + + @abc.abstractmethod + def decrypt(self, value): + """Decrypt the value.""" + pass + + +class NullCrypter(Crypter): + """A crypter that does nothing""" + + def encrypt(self, value): + return value + 
+ def decrypt(self, value): + return value + + +def _load_plugins() -> None: + """ + Locate all setuptools entry points by the name 'keyring backends' + and initialize them. + Any third-party library may register an entry point by adding the + following to their setup.cfg:: + + [options.entry_points] + keyring.backends = + plugin_name = mylib.mymodule:initialize_func + + `plugin_name` can be anything, and is only used to display the name + of the plugin at initialization time. + + `initialize_func` is optional, but will be invoked if callable. + """ + for ep in metadata.entry_points(group='keyring.backends'): + try: + log.debug('Loading %s', ep.name) + init_func = ep.load() + if callable(init_func): + init_func() + except Exception: + log.exception(f"Error initializing plugin {ep}.") + + +@once +def get_all_keyring() -> list[KeyringBackend]: + """ + Return a list of all implemented keyrings that can be constructed without + parameters. + """ + _load_plugins() + viable_classes = KeyringBackend.get_viable_backends() + rings = util.suppress_exceptions(viable_classes, exceptions=TypeError) + return list(rings) + + +class SchemeSelectable: + """ + Allow a backend to select different "schemes" for the + username and service. 
+ + >>> backend = SchemeSelectable() + >>> backend._query('contoso', 'alice') + {'username': 'alice', 'service': 'contoso'} + >>> backend._query('contoso') + {'service': 'contoso'} + >>> backend.scheme = 'KeePassXC' + >>> backend._query('contoso', 'alice') + {'UserName': 'alice', 'Title': 'contoso'} + >>> backend._query('contoso', 'alice', foo='bar') + {'UserName': 'alice', 'Title': 'contoso', 'foo': 'bar'} + """ + + scheme = 'default' + schemes = dict( + default=dict(username='username', service='service'), + KeePassXC=dict(username='UserName', service='Title'), + ) + + def _query( + self, service: str, username: str | None = None, **base: typing.Any + ) -> dict[str, str]: + scheme = self.schemes[self.scheme] + return dict( + { + scheme['username']: username, + scheme['service']: service, + } + if username is not None + else { + scheme['service']: service, + }, + **base, + ) diff --git a/lib/keyring/backend_complete.bash b/lib/keyring/backend_complete.bash new file mode 100644 index 0000000..1248d95 --- /dev/null +++ b/lib/keyring/backend_complete.bash @@ -0,0 +1,14 @@ +# Complete keyring backends for `keyring -b` from `keyring --list-backends` +# # keyring -b +# keyring.backends.chainer.ChainerBackend keyring.backends.fail.Keyring ... + +_keyring_backends() { + local choices + choices=$( + "${COMP_WORDS[0]}" --list-backends 2>/dev/null | + while IFS=$' \t' read -r backend rest; do + printf "%s\n" "$backend" + done + ) + compgen -W "${choices[*]}" -- "$1" +} diff --git a/lib/keyring/backend_complete.zsh b/lib/keyring/backend_complete.zsh new file mode 100644 index 0000000..eba76c6 --- /dev/null +++ b/lib/keyring/backend_complete.zsh @@ -0,0 +1,14 @@ +# Complete keyring backends for `keyring -b` from `keyring --list-backends` +# % keyring -b +# keyring priority +# keyring.backends.chainer.ChainerBackend 10 +# keyring.backends.fail.Keyring 0 +# ... ... 
+ +backend_complete() { + local line + while read -r line; do + choices+=(${${line/ \(priority: /\\\\:}/)/}) + done <<< "$($words[1] --list-backends)" + _arguments "*:keyring priority:(($choices))" +} diff --git a/lib/keyring/backends/SecretService.py b/lib/keyring/backends/SecretService.py new file mode 100644 index 0000000..41aa788 --- /dev/null +++ b/lib/keyring/backends/SecretService.py @@ -0,0 +1,120 @@ +import logging +from contextlib import closing + +from jaraco.context import ExceptionTrap + +from .. import backend +from ..backend import KeyringBackend +from ..compat import properties +from ..credentials import SimpleCredential +from ..errors import ( + InitError, + KeyringLocked, + PasswordDeleteError, +) + +try: + import secretstorage + import secretstorage.exceptions as exceptions +except ImportError: + pass +except AttributeError: + # See https://github.com/jaraco/keyring/issues/296 + pass + +log = logging.getLogger(__name__) + + +class Keyring(backend.SchemeSelectable, KeyringBackend): + """Secret Service Keyring""" + + appid = 'Python keyring library' + + @properties.classproperty + def priority(cls) -> float: + with ExceptionTrap() as exc: + secretstorage.__name__ # noqa: B018 + if exc: + raise RuntimeError("SecretStorage required") + if secretstorage.__version_tuple__ < (3, 2): + raise RuntimeError("SecretStorage 3.2 or newer required") + try: + with closing(secretstorage.dbus_init()) as connection: + if not secretstorage.check_service_availability(connection): + raise RuntimeError( + "The Secret Service daemon is neither running nor " + "activatable through D-Bus" + ) + except exceptions.SecretStorageException as e: + raise RuntimeError(f"Unable to initialize SecretService: {e}") from e + return 5 + + def get_preferred_collection(self): + """If self.preferred_collection contains a D-Bus path, + the collection at that address is returned. Otherwise, + the default collection is returned. 
+ """ + bus = secretstorage.dbus_init() + try: + if hasattr(self, 'preferred_collection'): + collection = secretstorage.Collection(bus, self.preferred_collection) + else: + collection = secretstorage.get_default_collection(bus) + except exceptions.SecretStorageException as e: + raise InitError(f"Failed to create the collection: {e}.") from e + if collection.is_locked(): + collection.unlock() + if collection.is_locked(): # User dismissed the prompt + raise KeyringLocked("Failed to unlock the collection!") + return collection + + def unlock(self, item): + if hasattr(item, 'unlock'): + item.unlock() + if item.is_locked(): # User dismissed the prompt + raise KeyringLocked('Failed to unlock the item!') + + def get_password(self, service, username): + """Get password of the username for the service""" + collection = self.get_preferred_collection() + with closing(collection.connection): + items = collection.search_items(self._query(service, username)) + for item in items: + self.unlock(item) + return item.get_secret().decode('utf-8') + + def set_password(self, service, username, password): + """Set password for the username of the service""" + collection = self.get_preferred_collection() + attributes = self._query(service, username, application=self.appid) + label = f"Password for '{username}' on '{service}'" + with closing(collection.connection): + collection.create_item(label, attributes, password, replace=True) + + def delete_password(self, service, username): + """Delete the stored password (only the first one)""" + collection = self.get_preferred_collection() + with closing(collection.connection): + items = collection.search_items(self._query(service, username)) + for item in items: + return item.delete() + raise PasswordDeleteError("No such password!") + + def get_credential(self, service, username): + """Gets the first username and password for a service. 
+ Returns a Credential instance + + The username can be omitted, but if there is one, it will use get_password + and return a SimpleCredential containing the username and password + Otherwise, it will return the first username and password combo that it finds. + """ + scheme = self.schemes[self.scheme] + query = self._query(service, username) + collection = self.get_preferred_collection() + + with closing(collection.connection): + items = collection.search_items(query) + for item in items: + self.unlock(item) + username = item.get_attributes().get(scheme['username']) + return SimpleCredential(username, item.get_secret().decode('utf-8')) diff --git a/lib/keyring/backends/Windows.py b/lib/keyring/backends/Windows.py new file mode 100644 index 0000000..110075b --- /dev/null +++ b/lib/keyring/backends/Windows.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import logging + +from jaraco.context import ExceptionTrap + +from ..backend import KeyringBackend +from ..compat import properties +from ..credentials import SimpleCredential +from ..errors import PasswordDeleteError + +with ExceptionTrap() as missing_deps: + try: + # prefer pywin32-ctypes + from win32ctypes.pywin32 import pywintypes, win32cred + + # force demand import to raise ImportError + win32cred.__name__ # noqa: B018 + except ImportError: + # fallback to pywin32 + import pywintypes + import win32cred + + # force demand import to raise ImportError + win32cred.__name__ # noqa: B018 + +log = logging.getLogger(__name__) + + +class Persistence: + def __get__(self, keyring, type=None): + return getattr(keyring, '_persist', win32cred.CRED_PERSIST_ENTERPRISE) + + def __set__(self, keyring, value): + """ + Set the persistence value on the Keyring. Value may be + one of the win32cred.CRED_PERSIST_* constants or a + string representing one of those constants. For example, + 'local machine' or 'session'. 
+ """ + if isinstance(value, str): + attr = 'CRED_PERSIST_' + value.replace(' ', '_').upper() + value = getattr(win32cred, attr) + keyring._persist = value + + +class DecodingCredential(dict): + @property + def value(self): + """ + Attempt to decode the credential blob as UTF-16 then UTF-8. + """ + cred = self['CredentialBlob'] + try: + return cred.decode('utf-16') + except UnicodeDecodeError: + decoded_cred_utf8 = cred.decode('utf-8') + log.warning( + "Retrieved a UTF-8 encoded credential. Please be aware that " + "this library only writes credentials in UTF-16." + ) + return decoded_cred_utf8 + + +class WinVaultKeyring(KeyringBackend): + """ + WinVaultKeyring stores encrypted passwords using the Windows Credential + Manager. + + Requires pywin32 + + This backend does some gymnastics to simulate multi-user support, + which WinVault doesn't support natively. See + https://github.com/jaraco/keyring/issues/47#issuecomment-75763152 + for details on the implementation, but here's the gist: + + Passwords are stored under the service name unless there is a collision + (another password with the same service name but different user name), + in which case the previous password is moved into a compound name: + {username}@{service} + """ + + persist = Persistence() + + @properties.classproperty + def priority(cls) -> float: + """ + If available, the preferred backend on Windows. 
+ """ + if missing_deps: + raise RuntimeError("Requires Windows and pywin32") + return 5 + + @staticmethod + def _compound_name(username, service): + return f'{username}@{service}' + + def get_password(self, service, username): + res = self._resolve_credential(service, username) + return res and res.value + + def _resolve_credential( + self, service: str, username: str | None + ) -> DecodingCredential | None: + # first attempt to get the password under the service name + res = self._read_credential(service) + if not res or username and res['UserName'] != username: + # It wasn't found so attempt to get it with the compound name + res = self._read_credential(self._compound_name(username, service)) + return res + + def _read_credential(self, target): + try: + res = win32cred.CredRead( + Type=win32cred.CRED_TYPE_GENERIC, TargetName=target + ) + except pywintypes.error as e: + if e.winerror == 1168 and e.funcname == 'CredRead': # not found + return None + raise + return DecodingCredential(res) + + def set_password(self, service, username, password): + existing_pw = self._read_credential(service) + if existing_pw: + # resave the existing password using a compound target + existing_username = existing_pw['UserName'] + target = self._compound_name(existing_username, service) + self._set_password( + target, + existing_username, + existing_pw.value, + ) + self._set_password(service, username, str(password)) + + def _set_password(self, target, username, password): + credential = dict( + Type=win32cred.CRED_TYPE_GENERIC, + TargetName=target, + UserName=username, + CredentialBlob=password, + Comment="Stored using python-keyring", + Persist=self.persist, + ) + win32cred.CredWrite(credential, 0) + + def delete_password(self, service, username): + compound = self._compound_name(username, service) + deleted = False + for target in service, compound: + existing_pw = self._read_credential(target) + if existing_pw and existing_pw['UserName'] == username: + deleted = True + 
self._delete_password(target) + if not deleted: + raise PasswordDeleteError(service) + + def _delete_password(self, target): + try: + win32cred.CredDelete(Type=win32cred.CRED_TYPE_GENERIC, TargetName=target) + except pywintypes.error as e: + if e.winerror == 1168 and e.funcname == 'CredDelete': # not found + return + raise + + def get_credential(self, service, username): + res = self._resolve_credential(service, username) + return res and SimpleCredential(res['UserName'], res.value) diff --git a/lib/keyring/backends/__init__.py b/lib/keyring/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/keyring/backends/__pycache__/SecretService.cpython-314.pyc b/lib/keyring/backends/__pycache__/SecretService.cpython-314.pyc new file mode 100644 index 0000000..1f4d395 Binary files /dev/null and b/lib/keyring/backends/__pycache__/SecretService.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/Windows.cpython-314.pyc b/lib/keyring/backends/__pycache__/Windows.cpython-314.pyc new file mode 100644 index 0000000..e48f911 Binary files /dev/null and b/lib/keyring/backends/__pycache__/Windows.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/__init__.cpython-314.pyc b/lib/keyring/backends/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..4d33085 Binary files /dev/null and b/lib/keyring/backends/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/chainer.cpython-314.pyc b/lib/keyring/backends/__pycache__/chainer.cpython-314.pyc new file mode 100644 index 0000000..aa3fb62 Binary files /dev/null and b/lib/keyring/backends/__pycache__/chainer.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/fail.cpython-314.pyc b/lib/keyring/backends/__pycache__/fail.cpython-314.pyc new file mode 100644 index 0000000..7ecd33d Binary files /dev/null and b/lib/keyring/backends/__pycache__/fail.cpython-314.pyc differ diff --git 
a/lib/keyring/backends/__pycache__/kwallet.cpython-314.pyc b/lib/keyring/backends/__pycache__/kwallet.cpython-314.pyc new file mode 100644 index 0000000..ce6aa46 Binary files /dev/null and b/lib/keyring/backends/__pycache__/kwallet.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/libsecret.cpython-314.pyc b/lib/keyring/backends/__pycache__/libsecret.cpython-314.pyc new file mode 100644 index 0000000..aa4ed7f Binary files /dev/null and b/lib/keyring/backends/__pycache__/libsecret.cpython-314.pyc differ diff --git a/lib/keyring/backends/__pycache__/null.cpython-314.pyc b/lib/keyring/backends/__pycache__/null.cpython-314.pyc new file mode 100644 index 0000000..362406f Binary files /dev/null and b/lib/keyring/backends/__pycache__/null.cpython-314.pyc differ diff --git a/lib/keyring/backends/chainer.py b/lib/keyring/backends/chainer.py new file mode 100644 index 0000000..6bc711f --- /dev/null +++ b/lib/keyring/backends/chainer.py @@ -0,0 +1,71 @@ +""" +Keyring Chainer - iterates over other viable backends to +discover passwords in each. +""" + +from .. import backend +from ..compat import properties +from . import fail + + +class ChainerBackend(backend.KeyringBackend): + """ + >>> ChainerBackend() + + """ + + # override viability as 'priority' cannot be determined + # until other backends have been constructed + viable = True + + @properties.classproperty + def priority(cls) -> float: + """ + If there are backends to chain, high priority + Otherwise very low priority since our operation when empty + is the same as null. + """ + return 10 if len(cls.backends) > 1 else (fail.Keyring.priority - 1) + + @properties.classproperty + def backends(cls): + """ + Discover all keyrings for chaining. 
+ """ + + def allow(keyring): + limit = backend._limit or bool + return ( + not isinstance(keyring, ChainerBackend) + and limit(keyring) + and keyring.priority > 0 + ) + + allowed = filter(allow, backend.get_all_keyring()) + return sorted(allowed, key=backend.by_priority, reverse=True) + + def get_password(self, service, username): + for keyring in self.backends: + password = keyring.get_password(service, username) + if password is not None: + return password + + def set_password(self, service, username, password): + for keyring in self.backends: + try: + return keyring.set_password(service, username, password) + except NotImplementedError: + pass + + def delete_password(self, service, username): + for keyring in self.backends: + try: + return keyring.delete_password(service, username) + except NotImplementedError: + pass + + def get_credential(self, service, username): + for keyring in self.backends: + credential = keyring.get_credential(service, username) + if credential is not None: + return credential diff --git a/lib/keyring/backends/fail.py b/lib/keyring/backends/fail.py new file mode 100644 index 0000000..5007ab9 --- /dev/null +++ b/lib/keyring/backends/fail.py @@ -0,0 +1,30 @@ +from ..backend import KeyringBackend +from ..compat import properties +from ..errors import NoKeyringError + + +class Keyring(KeyringBackend): + """ + Keyring that raises error on every operation. + + >>> kr = Keyring() + >>> kr.get_password('svc', 'user') + Traceback (most recent call last): + ... + keyring.errors.NoKeyringError: ...No recommended backend... + """ + + @properties.classproperty + def priority(cls) -> float: + return 0 + + def get_password(self, service, username, password=None): + msg = ( + "No recommended backend was available. Install a recommended 3rd " + "party backend package; or, install the keyrings.alt package if " + "you want to use the non-recommended backends. See " + "https://pypi.org/project/keyring for details." 
+ ) + raise NoKeyringError(msg) + + set_password = delete_password = get_password diff --git a/lib/keyring/backends/kwallet.py b/lib/keyring/backends/kwallet.py new file mode 100644 index 0000000..b1d9c8e --- /dev/null +++ b/lib/keyring/backends/kwallet.py @@ -0,0 +1,164 @@ +import contextlib +import os +import sys + +from ..backend import KeyringBackend +from ..compat import properties +from ..credentials import SimpleCredential +from ..errors import InitError, KeyringLocked, PasswordDeleteError, PasswordSetError + +try: + import dbus + from dbus.mainloop.glib import DBusGMainLoop +except ImportError: + pass +except AttributeError: + # See https://github.com/jaraco/keyring/issues/296 + pass + + +def _id_from_argv(): + """ + Safely infer an app id from sys.argv. + """ + allowed = AttributeError, IndexError, TypeError + with contextlib.suppress(allowed): + return sys.argv[0] + + +class DBusKeyring(KeyringBackend): + """ + KDE KWallet 5 via D-Bus + """ + + appid = _id_from_argv() or 'Python keyring library' + wallet = None + bus_name = 'org.kde.kwalletd5' + object_path = '/modules/kwalletd5' + + @properties.classproperty + def priority(cls) -> float: + if 'dbus' not in globals(): + raise RuntimeError('python-dbus not installed') + try: + bus = dbus.SessionBus(mainloop=DBusGMainLoop()) + except dbus.DBusException as exc: + raise RuntimeError(exc.get_dbus_message()) from exc + if not ( + bus.name_has_owner(cls.bus_name) + and cls.bus_name in bus.list_activatable_names() + ): + raise RuntimeError( + "The KWallet daemon is neither running nor activatable through D-Bus" + ) + if "KDE" in os.getenv("XDG_CURRENT_DESKTOP", "").split(":"): + return 5.1 + return 4.9 + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self.handle = -1 + + def _migrate(self, service): + old_folder = 'Python' + entry_list = [] + if self.iface.hasFolder(self.handle, old_folder, self.appid): + entry_list = self.iface.readPasswordList( + self.handle, old_folder, '*@*', self.appid + 
) + + for entry in entry_list.items(): + key = entry[0] + password = entry[1] + + username, service = key.rsplit('@', 1) + ret = self.iface.writePassword( + self.handle, service, username, password, self.appid + ) + if ret == 0: + self.iface.removeEntry(self.handle, old_folder, key, self.appid) + + entry_list = self.iface.readPasswordList( + self.handle, old_folder, '*', self.appid + ) + if not entry_list: + self.iface.removeFolder(self.handle, old_folder, self.appid) + + def connected(self, service): + if self.handle >= 0: + if self.iface.isOpen(self.handle): + return True + + bus = dbus.SessionBus(mainloop=DBusGMainLoop()) + wId = 0 + try: + remote_obj = bus.get_object(self.bus_name, self.object_path) + self.iface = dbus.Interface(remote_obj, 'org.kde.KWallet') + self.handle = self.iface.open(self.iface.networkWallet(), wId, self.appid) + except dbus.DBusException as e: + raise InitError(f'Failed to open keyring: {e}.') from e + + if self.handle < 0: + return False + self._migrate(service) + return True + + def get_password(self, service, username): + """Get password of the username for the service""" + if not self.connected(service): + # the user pressed "cancel" when prompted to unlock their keyring. + raise KeyringLocked("Failed to unlock the keyring!") + if not self.iface.hasEntry(self.handle, service, username, self.appid): + return None + password = self.iface.readPassword(self.handle, service, username, self.appid) + return str(password) + + def get_credential(self, service, username): + """Gets the first username and password for a service. + Returns a Credential instance + + The username can be omitted, but if there is one, it will forward to + get_password. + Otherwise, it will return the first username and password combo that it finds. + """ + if username is not None: + return super().get_credential(service, username) + + if not self.connected(service): + # the user pressed "cancel" when prompted to unlock their keyring. 
+ raise KeyringLocked("Failed to unlock the keyring!") + + for username in self.iface.entryList(self.handle, service, self.appid): + password = self.iface.readPassword( + self.handle, service, username, self.appid + ) + return SimpleCredential(str(username), str(password)) + + def set_password(self, service, username, password): + """Set password for the username of the service""" + if not self.connected(service): + # the user pressed "cancel" when prompted to unlock their keyring. + raise PasswordSetError("Cancelled by user") + self.iface.writePassword(self.handle, service, username, password, self.appid) + + def delete_password(self, service, username): + """Delete the password for the username of the service.""" + if not self.connected(service): + # the user pressed "cancel" when prompted to unlock their keyring. + raise PasswordDeleteError("Cancelled by user") + if not self.iface.hasEntry(self.handle, service, username, self.appid): + raise PasswordDeleteError("Password not found") + self.iface.removeEntry(self.handle, service, username, self.appid) + + +class DBusKeyringKWallet4(DBusKeyring): + """ + KDE KWallet 4 via D-Bus + """ + + bus_name = 'org.kde.kwalletd' + object_path = '/modules/kwalletd' + + @properties.classproperty + def priority(cls): + return super().priority - 1 diff --git a/lib/keyring/backends/libsecret.py b/lib/keyring/backends/libsecret.py new file mode 100644 index 0000000..b92b3c2 --- /dev/null +++ b/lib/keyring/backends/libsecret.py @@ -0,0 +1,155 @@ +import logging + +from .. 
import backend +from ..backend import KeyringBackend +from ..compat import properties +from ..credentials import SimpleCredential +from ..errors import ( + KeyringLocked, + PasswordDeleteError, + PasswordSetError, +) + +available = False +try: + import gi + from gi.repository import Gio, GLib + + gi.require_version('Secret', '1') + from gi.repository import Secret + + available = True +except (AttributeError, ImportError, ValueError): + pass + +log = logging.getLogger(__name__) + + +class Keyring(backend.SchemeSelectable, KeyringBackend): + """libsecret Keyring""" + + appid = 'Python keyring library' + + @property + def schema(self): + return Secret.Schema.new( + "org.freedesktop.Secret.Generic", + Secret.SchemaFlags.NONE, + self._query( + Secret.SchemaAttributeType.STRING, + Secret.SchemaAttributeType.STRING, + application=Secret.SchemaAttributeType.STRING, + ), + ) + + @properties.NonDataProperty + def collection(self): + return Secret.COLLECTION_DEFAULT + + @properties.classproperty + def priority(cls) -> float: + if not available: + raise RuntimeError("libsecret required") + + # Make sure there is actually a secret service running + try: + Secret.Service.get_sync(Secret.ServiceFlags.OPEN_SESSION, None) + except GLib.Error as error: + raise RuntimeError("Can't open a session to the secret service") from error + + return 4.8 + + def get_password(self, service, username): + """Get password of the username for the service""" + attributes = self._query(service, username, application=self.appid) + try: + items = Secret.password_search_sync( + self.schema, attributes, Secret.SearchFlags.UNLOCK, None + ) + except GLib.Error as error: + quark = GLib.quark_try_string('g-io-error-quark') + if error.matches(quark, Gio.IOErrorEnum.FAILED): + raise KeyringLocked('Failed to unlock the item!') from error + raise + for item in items: + try: + return item.retrieve_secret_sync().get_text() + except GLib.Error as error: + quark = GLib.quark_try_string('secret-error') + if 
error.matches(quark, Secret.Error.IS_LOCKED): + raise KeyringLocked('Failed to unlock the item!') from error + raise + + def set_password(self, service, username, password): + """Set password for the username of the service""" + attributes = self._query(service, username, application=self.appid) + label = f"Password for '{username}' on '{service}'" + try: + stored = Secret.password_store_sync( + self.schema, attributes, self.collection, label, password, None + ) + except GLib.Error as error: + quark = GLib.quark_try_string('secret-error') + if error.matches(quark, Secret.Error.IS_LOCKED): + raise KeyringLocked("Failed to unlock the collection!") from error + quark = GLib.quark_try_string('g-io-error-quark') + if error.matches(quark, Gio.IOErrorEnum.FAILED): + raise KeyringLocked("Failed to unlock the collection!") from error + raise + if not stored: + raise PasswordSetError("Failed to store password!") + + def delete_password(self, service, username): + """Delete the stored password (only the first one)""" + attributes = self._query(service, username, application=self.appid) + try: + items = Secret.password_search_sync( + self.schema, attributes, Secret.SearchFlags.UNLOCK, None + ) + except GLib.Error as error: + quark = GLib.quark_try_string('g-io-error-quark') + if error.matches(quark, Gio.IOErrorEnum.FAILED): + raise KeyringLocked('Failed to unlock the item!') from error + raise + for item in items: + try: + removed = Secret.password_clear_sync( + self.schema, item.get_attributes(), None + ) + except GLib.Error as error: + quark = GLib.quark_try_string('secret-error') + if error.matches(quark, Secret.Error.IS_LOCKED): + raise KeyringLocked('Failed to unlock the item!') from error + raise + return removed + raise PasswordDeleteError("No such password!") + + def get_credential(self, service, username): + """Get the first username and password for a service. 
+ Return a Credential instance + + The username can be omitted, but if there is one, it will use get_password + and return a SimpleCredential containing the username and password + Otherwise, it will return the first username and password combo that it finds. + """ + query = self._query(service, username) + try: + items = Secret.password_search_sync( + self.schema, query, Secret.SearchFlags.UNLOCK, None + ) + except GLib.Error as error: + quark = GLib.quark_try_string('g-io-error-quark') + if error.matches(quark, Gio.IOErrorEnum.FAILED): + raise KeyringLocked('Failed to unlock the item!') from error + raise + for item in items: + username = item.get_attributes().get("username") + try: + return SimpleCredential( + username, item.retrieve_secret_sync().get_text() + ) + except GLib.Error as error: + quark = GLib.quark_try_string('secret-error') + if error.matches(quark, Secret.Error.IS_LOCKED): + raise KeyringLocked('Failed to unlock the item!') from error + raise diff --git a/lib/keyring/backends/macOS/__init__.py b/lib/keyring/backends/macOS/__init__.py new file mode 100644 index 0000000..c3734e2 --- /dev/null +++ b/lib/keyring/backends/macOS/__init__.py @@ -0,0 +1,85 @@ +import functools +import os +import platform +import warnings + +from ...backend import KeyringBackend +from ...compat import properties +from ...errors import KeyringError, KeyringLocked, PasswordDeleteError, PasswordSetError + +try: + from . import api +except Exception: + pass + + +def warn_keychain(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if self.keychain: + warnings.warn("Specified keychain is ignored. See #623", stacklevel=2) + return func(self, *args, **kwargs) + + return wrapper + + +class Keyring(KeyringBackend): + """macOS Keychain""" + + keychain = os.environ.get('KEYCHAIN_PATH') + "Path to keychain file, overriding default" + + @properties.classproperty + def priority(cls): + """ + Preferred for all macOS environments. 
+ """ + if platform.system() != 'Darwin': + raise RuntimeError("macOS required") + if 'api' not in globals(): + raise RuntimeError("Security API unavailable") + return 5 + + @warn_keychain + def set_password(self, service, username, password): + if username is None: + username = '' + + try: + api.set_generic_password(self.keychain, service, username, password) + except api.KeychainDenied as e: + raise KeyringLocked(f"Can't store password on keychain: {e}") from e + except api.Error as e: + raise PasswordSetError(f"Can't store password on keychain: {e}") from e + + @warn_keychain + def get_password(self, service, username): + if username is None: + username = '' + + try: + return api.find_generic_password(self.keychain, service, username) + except api.NotFound: + pass + except api.KeychainDenied as e: + raise KeyringLocked(f"Can't get password from keychain: {e}") from e + except api.Error as e: + raise KeyringError(f"Can't get password from keychain: {e}") from e + + @warn_keychain + def delete_password(self, service, username): + if username is None: + username = '' + + try: + return api.delete_generic_password(self.keychain, service, username) + except api.Error as e: + raise PasswordDeleteError(f"Can't delete password in keychain: {e}") from e + + def with_keychain(self, keychain): + warnings.warn( + "macOS.Keyring.with_keychain is deprecated. 
Use with_properties instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.with_properties(keychain=keychain) diff --git a/lib/keyring/backends/macOS/__pycache__/__init__.cpython-314.pyc b/lib/keyring/backends/macOS/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..f9eb476 Binary files /dev/null and b/lib/keyring/backends/macOS/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/backends/macOS/__pycache__/api.cpython-314.pyc b/lib/keyring/backends/macOS/__pycache__/api.cpython-314.pyc new file mode 100644 index 0000000..9e96163 Binary files /dev/null and b/lib/keyring/backends/macOS/__pycache__/api.cpython-314.pyc differ diff --git a/lib/keyring/backends/macOS/api.py b/lib/keyring/backends/macOS/api.py new file mode 100644 index 0000000..d837308 --- /dev/null +++ b/lib/keyring/backends/macOS/api.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import contextlib +import ctypes +import functools +from ctypes import ( + byref, + c_int32, + c_uint32, + c_void_p, +) +from ctypes.util import find_library + +OS_status = c_int32 + + +class error: + item_not_found = -25300 + keychain_denied = -128 + sec_auth_failed = -25293 + plist_missing = -67030 + sec_interaction_not_allowed = -25308 + + +_sec = ctypes.CDLL(find_library('Security')) +_core = ctypes.CDLL(find_library('CoreServices')) +_found = ctypes.CDLL(find_library('Foundation')) + +CFDictionaryCreate = _found.CFDictionaryCreate +CFDictionaryCreate.restype = c_void_p +CFDictionaryCreate.argtypes = ( + c_void_p, + c_void_p, + c_void_p, + c_int32, + c_void_p, + c_void_p, +) + +CFStringCreateWithCString = _found.CFStringCreateWithCString +CFStringCreateWithCString.restype = c_void_p +CFStringCreateWithCString.argtypes = [c_void_p, c_void_p, c_uint32] + +CFNumberCreate = _found.CFNumberCreate +CFNumberCreate.restype = c_void_p +CFNumberCreate.argtypes = [c_void_p, c_uint32, ctypes.c_void_p] + +SecItemAdd = _sec.SecItemAdd +SecItemAdd.restype = OS_status 
+SecItemAdd.argtypes = (c_void_p, c_void_p) + +SecItemCopyMatching = _sec.SecItemCopyMatching +SecItemCopyMatching.restype = OS_status +SecItemCopyMatching.argtypes = (c_void_p, c_void_p) + +SecItemDelete = _sec.SecItemDelete +SecItemDelete.restype = OS_status +SecItemDelete.argtypes = (c_void_p,) + +CFDataGetBytePtr = _found.CFDataGetBytePtr +CFDataGetBytePtr.restype = c_void_p +CFDataGetBytePtr.argtypes = (c_void_p,) + +CFDataGetLength = _found.CFDataGetLength +CFDataGetLength.restype = c_int32 +CFDataGetLength.argtypes = (c_void_p,) + + +def k_(s): + return c_void_p.in_dll(_sec, s) + + +@functools.singledispatch +def create_cf(ob): + return ob + + +# explicit bool and int required for Python 3.10 compatibility +@create_cf.register(bool) +@create_cf.register(int) +def _(val: bool | int): + if val.bit_length() > 31: + raise OverflowError(val) + int32 = 0x9 + return CFNumberCreate(None, int32, ctypes.byref(c_int32(val))) + + +@create_cf.register +def _(s: str): + kCFStringEncodingUTF8 = 0x08000100 + return CFStringCreateWithCString(None, s.encode('utf8'), kCFStringEncodingUTF8) + + +def create_query(**kwargs): + return CFDictionaryCreate( + None, + (c_void_p * len(kwargs))(*map(k_, kwargs.keys())), + (c_void_p * len(kwargs))(*map(create_cf, kwargs.values())), + len(kwargs), + _found.kCFTypeDictionaryKeyCallBacks, + _found.kCFTypeDictionaryValueCallBacks, + ) + + +def cfstr_to_str(data): + return ctypes.string_at(CFDataGetBytePtr(data), CFDataGetLength(data)).decode( + 'utf-8' + ) + + +class Error(Exception): + @classmethod + def raise_for_status(cls, status): + if status == 0: + return + if status == error.item_not_found: + raise NotFound(status, "Item not found") + if status == error.keychain_denied: + raise KeychainDenied(status, "Keychain Access Denied") + if status == error.sec_auth_failed or status == error.plist_missing: + raise SecAuthFailure( + status, + "Security Auth Failure: make sure " + "executable is signed with codesign util", + ) + raise cls(status, 
"Unknown Error") + + +class NotFound(Error): + pass + + +class KeychainDenied(Error): + pass + + +class SecAuthFailure(Error): + pass + + +def find_generic_password(kc_name, service, username, not_found_ok=False): + q = create_query( + kSecClass=k_('kSecClassGenericPassword'), + kSecMatchLimit=k_('kSecMatchLimitOne'), + kSecAttrService=service, + kSecAttrAccount=username, + kSecReturnData=True, + ) + + data = c_void_p() + status = SecItemCopyMatching(q, byref(data)) + + if status == error.item_not_found and not_found_ok: + return + + Error.raise_for_status(status) + + return cfstr_to_str(data) + + +def set_generic_password(name, service, username, password): + with contextlib.suppress(NotFound): + delete_generic_password(name, service, username) + + q = create_query( + kSecClass=k_('kSecClassGenericPassword'), + kSecAttrService=service, + kSecAttrAccount=username, + kSecValueData=password, + ) + + status = SecItemAdd(q, None) + Error.raise_for_status(status) + + +def delete_generic_password(name, service, username): + q = create_query( + kSecClass=k_('kSecClassGenericPassword'), + kSecAttrService=service, + kSecAttrAccount=username, + ) + + status = SecItemDelete(q) + Error.raise_for_status(status) diff --git a/lib/keyring/backends/null.py b/lib/keyring/backends/null.py new file mode 100644 index 0000000..6b4c3b0 --- /dev/null +++ b/lib/keyring/backends/null.py @@ -0,0 +1,20 @@ +from ..backend import KeyringBackend +from ..compat import properties + + +class Keyring(KeyringBackend): + """ + Keyring that return None on every operation. 
+ + >>> kr = Keyring() + >>> kr.get_password('svc', 'user') + """ + + @properties.classproperty + def priority(cls) -> float: + return -1 + + def get_password(self, service, username, password=None): + pass + + set_password = delete_password = get_password diff --git a/lib/keyring/cli.py b/lib/keyring/cli.py new file mode 100644 index 0000000..2c0ba4d --- /dev/null +++ b/lib/keyring/cli.py @@ -0,0 +1,220 @@ +"""Simple command line interface to get/set password from a keyring""" + +from __future__ import annotations + +import argparse +import getpass +import json +import sys + +from . import ( + backend, + completion, + core, + credentials, + delete_password, + get_credential, + get_password, + set_keyring, + set_password, +) +from .util import platform_ + + +class CommandLineTool: + # Attributes set dynamically by the ArgumentParser + keyring_path: str | None + keyring_backend: str | None + get_mode: str + output_format: str + operation: str + service: str + username: str + + def __init__(self): + self.parser = argparse.ArgumentParser() + self.parser.add_argument( + "-p", + "--keyring-path", + dest="keyring_path", + default=None, + help="Path to the keyring backend", + ) + self.parser.add_argument( + "-b", + "--keyring-backend", + dest="keyring_backend", + default=None, + help="Name of the keyring backend", + ) + self.parser.add_argument( + "--list-backends", + action="store_true", + help="List keyring backends and exit", + ) + self.parser.add_argument( + "--disable", action="store_true", help="Disable keyring and exit" + ) + self.parser._get_modes = ["password", "creds"] + self.parser.add_argument( + "--mode", + choices=self.parser._get_modes, + dest="get_mode", + default="password", + help=""" + Mode for 'get' operation. + 'password' requires a username and will return only the password. + 'creds' does not require a username and will return both the username and password separated by a newline. 
+ + Default is 'password' + """, + ) + self.parser._output_formats = ["plain", "json"] + self.parser.add_argument( + "--output", + choices=self.parser._output_formats, + dest="output_format", + default="plain", + help=""" + Output format for 'get' operation. + + Default is 'plain' + """, + ) + self.parser._operations = ["get", "set", "del", "diagnose"] + self.parser.add_argument( + 'operation', + choices=self.parser._operations, + nargs="?", + ) + self.parser.add_argument( + 'service', + nargs="?", + ) + self.parser.add_argument( + 'username', + nargs="?", + ) + completion.install(self.parser) + + def run(self, argv): + args = self.parser.parse_args(argv) + vars(self).update(vars(args)) + + if args.list_backends: + for k in backend.get_all_keyring(): + print(k) + return + + if args.disable: + core.disable() + return + + if args.operation == 'diagnose': + self.diagnose() + return + + self._check_args() + self._load_spec_backend() + method = getattr(self, f'do_{self.operation}', self.invalid_op) + return method() + + def _check_args(self): + needs_username = self.operation != 'get' or self.get_mode != 'creds' + required = (['service'] + ['username'] * needs_username) * bool(self.operation) + if any(getattr(self, param) is None for param in required): + self.parser.error(f"{self.operation} requires {' and '.join(required)}") + + def do_get(self): + credential = getattr(self, f'_get_{self.get_mode}')() + if credential is None: + raise SystemExit(1) + getattr(self, f'_emit_{self.output_format}')(credential) + + def _emit_json(self, credential: credentials.Credential): + print(json.dumps(credential._vars())) + + def _emit_plain(self, credential: credentials.Credential): + for val in credential._vars().values(): + print(val) + + def _get_creds(self) -> credentials.Credential | None: + return get_credential(self.service, self.username) + + def _get_password(self) -> credentials.Credential | None: + password = get_password(self.service, self.username) + return ( + 
credentials.AnonymousCredential(password) if password is not None else None + ) + + def do_set(self): + password = self.input_password( + f"Password for '{self.username}' in '{self.service}': " + ) + set_password(self.service, self.username, password) + + def do_del(self): + delete_password(self.service, self.username) + + def diagnose(self): + config_root = core._config_path() + if config_root.exists(): + print("config path:", config_root) + else: + print("config path:", config_root, "(absent)") + print("data root:", platform_.data_root()) + + def invalid_op(self): + self.parser.error(f"Specify operation ({', '.join(self.parser._operations)}).") + + def _load_spec_backend(self): + if self.keyring_backend is None: + return + + try: + if self.keyring_path: + sys.path.insert(0, self.keyring_path) + set_keyring(core.load_keyring(self.keyring_backend)) + except Exception as exc: + # Tons of things can go wrong here: + # ImportError when using "fjkljfljkl" + # AttributeError when using "os.path.bar" + # TypeError when using "__builtins__.str" + # So, we play on the safe side, and catch everything. + self.parser.error(f"Unable to load specified keyring: {exc}") + + def input_password(self, prompt): + """Retrieve password from input.""" + return self.pass_from_pipe() or getpass.getpass(prompt) + + @classmethod + def pass_from_pipe(cls): + """Return password from pipe if not on TTY, else False.""" + is_pipe = not sys.stdin.isatty() + return is_pipe and cls.strip_last_newline(sys.stdin.read()) + + @staticmethod + def strip_last_newline(str): + r"""Strip one last newline, if present. 
+ + >>> CommandLineTool.strip_last_newline('foo') + 'foo' + >>> CommandLineTool.strip_last_newline('foo\n') + 'foo' + """ + slc = slice(-1 if str.endswith('\n') else None) + return str[slc] + + +def main(argv=None): + """Main command line interface.""" + + if argv is None: + argv = sys.argv[1:] + + cli = CommandLineTool() + return cli.run(argv) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/lib/keyring/compat/__init__.py b/lib/keyring/compat/__init__.py new file mode 100644 index 0000000..22f1e1c --- /dev/null +++ b/lib/keyring/compat/__init__.py @@ -0,0 +1,7 @@ +__all__ = ['properties'] + + +try: + from jaraco.classes import properties +except ImportError: # pragma: no cover + from . import properties # type: ignore[no-redef] diff --git a/lib/keyring/compat/__pycache__/__init__.cpython-314.pyc b/lib/keyring/compat/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..4eb69a5 Binary files /dev/null and b/lib/keyring/compat/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/compat/__pycache__/properties.cpython-314.pyc b/lib/keyring/compat/__pycache__/properties.cpython-314.pyc new file mode 100644 index 0000000..376d0b9 Binary files /dev/null and b/lib/keyring/compat/__pycache__/properties.cpython-314.pyc differ diff --git a/lib/keyring/compat/__pycache__/py312.cpython-314.pyc b/lib/keyring/compat/__pycache__/py312.cpython-314.pyc new file mode 100644 index 0000000..a8dac9c Binary files /dev/null and b/lib/keyring/compat/__pycache__/py312.cpython-314.pyc differ diff --git a/lib/keyring/compat/properties.py b/lib/keyring/compat/properties.py new file mode 100644 index 0000000..ea993e1 --- /dev/null +++ b/lib/keyring/compat/properties.py @@ -0,0 +1,169 @@ +# from jaraco.classes 3.2.2 + + +class NonDataProperty: + """Much like the property builtin, but only implements __get__, + making it a non-data property, and can be subsequently reset. 
class NonDataProperty:
    """
    A read-only-descriptor variant of ``property``.

    Implements only ``__get__``, so a plain instance assignment shadows
    it: the computed value is produced until the attribute is set, after
    which the stored value wins.
    """

    def __init__(self, fget):
        assert fget is not None, "fget cannot be none"
        assert callable(fget), "fget must be callable"
        self.fget = fget

    def __get__(self, obj, objtype=None):
        return self if obj is None else self.fget(obj)


class classproperty:
    """
    A property evaluated (and optionally settable) at the class level.

    Full get/set behavior requires the owning class to use
    ``classproperty.Meta`` as its metaclass; without it, only the legacy
    (get-biased) behavior applies.  The wrapped callable may already be a
    ``classmethod`` or ``staticmethod``.
    """

    class Meta(type):
        # Route class-level attribute assignment through the descriptor's
        # __set__ whenever the existing attribute is a classproperty.
        def __setattr__(cls, name, value):
            current = cls.__dict__.get(name, None)
            if type(current) is classproperty:
                return current.__set__(cls, value)
            return super().__setattr__(name, value)

    def __init__(self, fget, fset=None):
        self.fget = self._ensure_method(fget)
        self.fset = fset
        if fset:
            self.setter(fset)

    def __get__(self, instance, owner=None):
        # Bind the underlying classmethod to the owning class and call it.
        return self.fget.__get__(None, owner)()

    def __set__(self, owner, value):
        if not self.fset:
            raise AttributeError("can't set attribute")
        if type(owner) is not classproperty.Meta:
            # Invoked through an instance; redirect to its class.
            owner = type(owner)
        return self.fset.__get__(None, owner)(value)

    def setter(self, fset):
        """Register *fset* as the setter and return self (decorator usage)."""
        self.fset = self._ensure_method(fset)
        return self

    @classmethod
    def _ensure_method(cls, fn):
        """Wrap *fn* in ``classmethod`` unless it is already a class/static method."""
        if isinstance(fn, (classmethod, staticmethod)):
            return fn
        return classmethod(fn)
"""
Core API functions and initialization routines.
"""

from __future__ import annotations

import configparser
import logging
import os
import sys
import typing

from . import backend, credentials
from .backends import fail
from .util import platform_ as platform

# Signature of a filter deciding whether a backend may be considered.
LimitCallable = typing.Callable[[backend.KeyringBackend], bool]

log = logging.getLogger(__name__)

# Module-level singleton holding the active backend (lazily initialized).
_keyring_backend = None


def set_keyring(keyring: backend.KeyringBackend) -> None:
    """Set current keyring backend."""
    global _keyring_backend
    if not isinstance(keyring, backend.KeyringBackend):
        raise TypeError("The keyring must be an instance of KeyringBackend")
    _keyring_backend = keyring


def get_keyring() -> backend.KeyringBackend:
    """Get current keyring backend, detecting one on first use."""
    if _keyring_backend is None:
        init_backend()
    return typing.cast(backend.KeyringBackend, _keyring_backend)


def disable() -> None:
    """
    Configure the null keyring as the default.

    Raises RuntimeError (naming the config file) if a configuration
    already exists, to avoid clobbering it.

    >>> fs = getfixture('fs')
    >>> disable()
    >>> disable()
    Traceback (most recent call last):
    ...
    RuntimeError: Refusing to overwrite...
    """
    root = platform.config_root()
    try:
        os.makedirs(root)
    except OSError:
        # Directory likely exists already; any real problem surfaces
        # below when the file is opened for writing.
        pass
    filename = os.path.join(root, 'keyringrc.cfg')
    if os.path.exists(filename):
        # Bug fix: the message previously read "(unknown)" — an f-string
        # with no placeholder — hiding which file blocks the operation.
        msg = f"Refusing to overwrite {filename}"
        raise RuntimeError(msg)
    with open(filename, 'w', encoding='utf-8') as file:
        file.write('[backend]\ndefault-keyring=keyring.backends.null.Keyring')


def get_password(service_name: str, username: str) -> str | None:
    """Get password from the specified service."""
    return get_keyring().get_password(service_name, username)


def set_password(service_name: str, username: str, password: str) -> None:
    """Set password for the user in the specified service."""
    get_keyring().set_password(service_name, username, password)


def delete_password(service_name: str, username: str) -> None:
    """Delete the password for the user in the specified service."""
    get_keyring().delete_password(service_name, username)


def get_credential(
    service_name: str, username: str | None
) -> credentials.Credential | None:
    """Get a Credential for the specified service."""
    return get_keyring().get_credential(service_name, username)


def recommended(backend) -> bool:
    """Return True if *backend* is viable for automatic selection."""
    return backend.priority >= 1


def init_backend(limit: LimitCallable | None = None):
    """
    Load a detected backend and install it as the active keyring.
    """
    set_keyring(_detect_backend(limit))
+ """ + + # save the limit for the chainer to honor + backend._limit = limit + return ( + load_env() + or load_config() + or max( + # all keyrings passing the limit filter + filter(limit, backend.get_all_keyring()), + default=fail.Keyring(), + key=backend.by_priority, + ) + ) + + +def _load_keyring_class(keyring_name: str) -> type[backend.KeyringBackend]: + """ + Load the keyring class indicated by name. + + These popular names are tested to ensure their presence. + + >>> popular_names = [ + ... 'keyring.backends.Windows.WinVaultKeyring', + ... 'keyring.backends.macOS.Keyring', + ... 'keyring.backends.kwallet.DBusKeyring', + ... 'keyring.backends.SecretService.Keyring', + ... ] + >>> list(map(_load_keyring_class, popular_names)) + [...] + """ + module_name, sep, class_name = keyring_name.rpartition('.') + __import__(module_name) + module = sys.modules[module_name] + return getattr(module, class_name) + + +def load_keyring(keyring_name: str) -> backend.KeyringBackend: + """ + Load the specified keyring by name (a fully-qualified name to the + keyring, such as 'keyring.backends.file.PlaintextKeyring') + """ + class_ = _load_keyring_class(keyring_name) + # invoke the priority to ensure it is viable, or raise a RuntimeError + class_.priority # noqa: B018 + return class_() + + +def load_env() -> backend.KeyringBackend | None: + """Load a keyring configured in the environment variable.""" + try: + return load_keyring(os.environ['PYTHON_KEYRING_BACKEND']) + except KeyError: + return None + + +def _config_path(): + return platform.config_root() / 'keyringrc.cfg' + + +def _ensure_path(path): + if not path.exists(): + raise FileNotFoundError(path) + return path + + +def load_config() -> backend.KeyringBackend | None: + """Load a keyring using the config file in the config root.""" + + config = configparser.RawConfigParser() + try: + config.read(_ensure_path(_config_path()), encoding='utf-8') + except FileNotFoundError: + return None + _load_keyring_path(config) + + # load 
import abc
import os


class Credential(metaclass=abc.ABCMeta):
    """Abstract base for credential objects (a username/password pair)."""

    @property
    @abc.abstractmethod
    def username(self) -> str: ...

    @property
    @abc.abstractmethod
    def password(self) -> str: ...

    def _vars(self) -> dict:
        """Mapping of the credential's public fields (used for CLI output)."""
        return {'username': self.username, 'password': self.password}


class SimpleCredential(Credential):
    """Credential backed by plain attributes."""

    def __init__(self, username: str, password: str):
        self._username = username
        self._password = password

    @property
    def username(self) -> str:
        return self._username

    @property
    def password(self) -> str:
        return self._password


class AnonymousCredential(SimpleCredential):
    """A password with no username; reading the username is an error."""

    def __init__(self, password: str):
        self._password = password

    @property
    def username(self) -> str:
        raise ValueError("Anonymous credential has no username")

    def _vars(self) -> dict:
        return {'password': self.password}


class EnvironCredential(Credential):
    """
    Source credentials from environment variables.

    Resolution is deferred until an attribute is read.  Two instances
    compare equal when they reference the same pair of variables.

    >>> e1 = EnvironCredential('a', 'b')
    >>> e2 = EnvironCredential('a', 'b')
    >>> e3 = EnvironCredential('a', 'c')
    >>> e1 == e2
    True
    >>> e2 == e3
    False
    """

    def __init__(self, user_env_var: str, pwd_env_var: str):
        self.user_env_var = user_env_var
        self.pwd_env_var = pwd_env_var

    def __eq__(self, other: object) -> bool:
        return vars(self) == vars(other)

    def _get_env(self, env_var: str) -> str:
        """Read *env_var*, raising ValueError when unset or empty."""
        value = os.environ.get(env_var)
        if not value:
            raise ValueError(f'Missing environment variable:{env_var}')
        return value

    @property
    def username(self) -> str:
        return self._get_env(self.user_env_var)

    @property
    def password(self) -> str:
        return self._get_env(self.pwd_env_var)
+ + +def restore_signature(func): + # workaround for pytest-dev/pluggy#358 + @functools.wraps(func) + def wrapper(url, username): + return func(url, username) + + return wrapper + + +@hookimpl() +@restore_signature +@suppress(keyring.errors.KeyringError) +def devpiclient_get_password(url, username): + """ + >>> pluggy._hooks.varnames(devpiclient_get_password) + (('url', 'username'), ()) + >>> + """ + return keyring.get_password(url, username) diff --git a/lib/keyring/errors.py b/lib/keyring/errors.py new file mode 100644 index 0000000..ed97cf9 --- /dev/null +++ b/lib/keyring/errors.py @@ -0,0 +1,67 @@ +import sys +import warnings + + +class KeyringError(Exception): + """Base class for exceptions in keyring""" + + +class PasswordSetError(KeyringError): + """Raised when the password can't be set.""" + + +class PasswordDeleteError(KeyringError): + """Raised when the password can't be deleted.""" + + +class InitError(KeyringError): + """Raised when the keyring could not be initialised""" + + +class KeyringLocked(KeyringError): + """Raised when the keyring failed unlocking""" + + +class NoKeyringError(KeyringError, RuntimeError): + """Raised when there is no keyring backend""" + + +class ExceptionRaisedContext: + """ + An exception-trapping context that indicates whether an exception was + raised. 
+ """ + + def __init__(self, ExpectedException=Exception): + warnings.warn( + "ExceptionRaisedContext is deprecated; use `jaraco.context.ExceptionTrap`", + DeprecationWarning, + stacklevel=2, + ) + self.ExpectedException = ExpectedException + self.exc_info = None + + def __enter__(self): + self.exc_info = object.__new__(ExceptionInfo) + return self.exc_info + + def __exit__(self, *exc_info): + self.exc_info.__init__(*exc_info) + return self.exc_info.type and issubclass( + self.exc_info.type, self.ExpectedException + ) + + +class ExceptionInfo: + def __init__(self, *info): + if not info: + info = sys.exc_info() + self.type, self.value, _ = info + + def __bool__(self): + """ + Return True if an exception occurred + """ + return bool(self.type) + + __nonzero__ = __bool__ diff --git a/lib/keyring/http.py b/lib/keyring/http.py new file mode 100644 index 0000000..2561535 --- /dev/null +++ b/lib/keyring/http.py @@ -0,0 +1,39 @@ +""" +urllib2.HTTPPasswordMgr object using the keyring, for use with the +urllib2.HTTPBasicAuthHandler. + +usage: + import urllib2 + handlers = [urllib2.HTTPBasicAuthHandler(PasswordMgr())] + urllib2.install_opener(handlers) + urllib2.urlopen(...) + +This will prompt for a password if one is required and isn't already +in the keyring. Then, it adds it to the keyring for subsequent use. +""" + +import getpass + +from . 
import delete_password, get_password, set_password + + +class PasswordMgr: + def get_username(self, realm, authuri): + return getpass.getuser() + + def add_password(self, realm, authuri, password): + user = self.get_username(realm, authuri) + set_password(realm, user, password) + + def find_user_password(self, realm, authuri): + user = self.get_username(realm, authuri) + password = get_password(realm, user) + if password is None: + prompt = f'password for {user}@{realm} for {authuri}: ' + password = getpass.getpass(prompt) + set_password(realm, user, password) + return user, password + + def clear_password(self, realm, authuri): + user = self.get_username(realm, authuri) + delete_password(realm, user) diff --git a/lib/keyring/py.typed b/lib/keyring/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/keyring/testing/__init__.py b/lib/keyring/testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/keyring/testing/__pycache__/__init__.cpython-314.pyc b/lib/keyring/testing/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..add4385 Binary files /dev/null and b/lib/keyring/testing/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/testing/__pycache__/backend.cpython-314.pyc b/lib/keyring/testing/__pycache__/backend.cpython-314.pyc new file mode 100644 index 0000000..d50f0d5 Binary files /dev/null and b/lib/keyring/testing/__pycache__/backend.cpython-314.pyc differ diff --git a/lib/keyring/testing/__pycache__/util.cpython-314.pyc b/lib/keyring/testing/__pycache__/util.cpython-314.pyc new file mode 100644 index 0000000..ad4e374 Binary files /dev/null and b/lib/keyring/testing/__pycache__/util.cpython-314.pyc differ diff --git a/lib/keyring/testing/backend.py b/lib/keyring/testing/backend.py new file mode 100644 index 0000000..89a414b --- /dev/null +++ b/lib/keyring/testing/backend.py @@ -0,0 +1,200 @@ +""" +Common test functionality for backends. 
+""" + +import os +import string + +import pytest + +from keyring import errors + +from .util import random_string + +# unicode only characters +# Sourced from The Quick Brown Fox... Pangrams +# http://www.columbia.edu/~fdc/utf8/ +UNICODE_CHARS = ( + "זהכיףסתםלשמועאיךתנצחקרפדעץטובבגן" + "ξεσκεπάζωτηνψυχοφθόραβδελυγμία" + "Съешьжеещёэтихмягкихфранцузскихбулокдавыпейчаю" + "Жълтатадюлябешещастливачепухъткойтоцъфназамръзнакатогьон" +) + +# ensure no-ascii chars slip by - watch your editor! +assert min(ord(char) for char in UNICODE_CHARS) > 127 + + +def is_ascii_printable(s): + return all(32 <= ord(c) < 127 for c in s) + + +class BackendBasicTests: + """Test for the keyring's basic functions. password_set and password_get""" + + DIFFICULT_CHARS = string.whitespace + string.punctuation + + @pytest.fixture(autouse=True) + def _init_properties(self, request): + self.keyring = self.init_keyring() + self.credentials_created = set() + request.addfinalizer(self.cleanup) + + def cleanup(self): + for item in self.credentials_created: + self.keyring.delete_password(*item) + + def set_password(self, service, username, password): + # set the password and save the result so the test runner can clean + # up after if necessary. 
+ self.keyring.set_password(service, username, password) + self.credentials_created.add((service, username)) + + def check_set_get(self, service, username, password): + keyring = self.keyring + + # for the non-existent password + assert keyring.get_password(service, username) is None + + # common usage + self.set_password(service, username, password) + assert keyring.get_password(service, username) == password + + # for the empty password + self.set_password(service, username, "") + assert keyring.get_password(service, username) == "" + + def test_password_set_get(self): + password = random_string(20) + username = random_string(20) + service = random_string(20) + self.check_set_get(service, username, password) + + def test_set_after_set_blank(self): + service = random_string(20) + username = random_string(20) + self.keyring.set_password(service, username, "") + self.keyring.set_password(service, username, "non-blank") + + def test_difficult_chars(self): + password = random_string(20, self.DIFFICULT_CHARS) + username = random_string(20, self.DIFFICULT_CHARS) + service = random_string(20, self.DIFFICULT_CHARS) + self.check_set_get(service, username, password) + + def test_delete_present(self): + password = random_string(20, self.DIFFICULT_CHARS) + username = random_string(20, self.DIFFICULT_CHARS) + service = random_string(20, self.DIFFICULT_CHARS) + self.keyring.set_password(service, username, password) + self.keyring.delete_password(service, username) + assert self.keyring.get_password(service, username) is None + + def test_delete_not_present(self): + username = random_string(20, self.DIFFICULT_CHARS) + service = random_string(20, self.DIFFICULT_CHARS) + with pytest.raises(errors.PasswordDeleteError): + self.keyring.delete_password(service, username) + + def test_delete_one_in_group(self): + username1 = random_string(20, self.DIFFICULT_CHARS) + username2 = random_string(20, self.DIFFICULT_CHARS) + password = random_string(20, self.DIFFICULT_CHARS) + service = 
random_string(20, self.DIFFICULT_CHARS) + self.keyring.set_password(service, username1, password) + self.set_password(service, username2, password) + self.keyring.delete_password(service, username1) + assert self.keyring.get_password(service, username2) == password + + def test_name_property(self): + assert is_ascii_printable(self.keyring.name) + + def test_unicode_chars(self): + password = random_string(20, UNICODE_CHARS) + username = random_string(20, UNICODE_CHARS) + service = random_string(20, UNICODE_CHARS) + self.check_set_get(service, username, password) + + def test_unicode_and_ascii_chars(self): + source = ( + random_string(10, UNICODE_CHARS) + + random_string(10) + + random_string(10, self.DIFFICULT_CHARS) + ) + password = random_string(20, source) + username = random_string(20, source) + service = random_string(20, source) + self.check_set_get(service, username, password) + + def test_different_user(self): + """ + Issue #47 reports that WinVault isn't storing passwords for + multiple users. This test exercises that test for each of the + backends. 
+ """ + + keyring = self.keyring + self.set_password('service1', 'user1', 'password1') + self.set_password('service1', 'user2', 'password2') + assert keyring.get_password('service1', 'user1') == 'password1' + assert keyring.get_password('service1', 'user2') == 'password2' + self.set_password('service2', 'user3', 'password3') + assert keyring.get_password('service1', 'user1') == 'password1' + + def test_credential(self): + keyring = self.keyring + + cred = keyring.get_credential('service', None) + assert cred is None + + self.set_password('service1', 'user1', 'password1') + self.set_password('service1', 'user2', 'password2') + + cred = keyring.get_credential('service1', None) + assert cred is None or (cred.username, cred.password) in ( + ('user1', 'password1'), + ('user2', 'password2'), + ) + + cred = keyring.get_credential('service1', 'user2') + assert cred is not None + assert (cred.username, cred.password) in ( + ('user1', 'password1'), + ('user2', 'password2'), + ) + + @pytest.mark.xfail("platform.system() == 'Windows'", reason="#668") + def test_empty_username(self): + with pytest.deprecated_call(): + self.set_password('service1', '', 'password1') + assert self.keyring.get_password('service1', '') == 'password1' + + def test_set_properties(self, monkeypatch): + env = dict(KEYRING_PROPERTY_FOO_BAR='fizz buzz', OTHER_SETTING='ignore me') + monkeypatch.setattr(os, 'environ', env) + self.keyring.set_properties_from_env() + assert self.keyring.foo_bar == 'fizz buzz' + + def test_new_with_properties(self): + alt = self.keyring.with_properties(foo='bar') + assert alt is not self.keyring + assert alt.foo == 'bar' + with pytest.raises(AttributeError): + self.keyring.foo # noqa: B018 + + def test_wrong_username_returns_none(self): + keyring = self.keyring + service = 'test_wrong_username_returns_none' + cred = keyring.get_credential(service, None) + assert cred is None + + password_1 = 'password1' + password_2 = 'password2' + self.set_password(service, 'user1', 
password_1) + self.set_password(service, 'user2', password_2) + + assert keyring.get_credential(service, "user1").password == password_1 + assert keyring.get_credential(service, "user2").password == password_2 + + # Missing/wrong username should not return a cred + assert keyring.get_credential(service, "nobody!") is None diff --git a/lib/keyring/testing/util.py b/lib/keyring/testing/util.py new file mode 100644 index 0000000..b8ef4c6 --- /dev/null +++ b/lib/keyring/testing/util.py @@ -0,0 +1,68 @@ +import contextlib +import os +import random +import string +import sys + + +class ImportKiller: + "Context manager to make an import of a given name or names fail." + + def __init__(self, *names): + self.names = names + + def find_module(self, fullname, path=None): + if fullname in self.names: + return self + + def load_module(self, fullname): + assert fullname in self.names + raise ImportError(fullname) + + def __enter__(self): + self.original = {} + for name in self.names: + self.original[name] = sys.modules.pop(name, None) + sys.meta_path.insert(0, self) + + def __exit__(self, *args): + sys.meta_path.remove(self) + for key, value in self.original.items(): + if value is not None: + sys.modules[key] = value + + +@contextlib.contextmanager +def NoNoneDictMutator(destination, **changes): + """Helper context manager to make and unmake changes to a dict. 
+ + A None is not a valid value for the destination, and so means that the + associated name should be removed.""" + original = {} + for key, value in changes.items(): + original[key] = destination.get(key) + if value is None: + if key in destination: + del destination[key] + else: + destination[key] = value + yield + for key, value in original.items(): + if value is None: + if key in destination: + del destination[key] + else: + destination[key] = value + + +def Environ(**changes): + """A context manager to temporarily change the os.environ""" + return NoNoneDictMutator(os.environ, **changes) + + +ALPHABET = string.ascii_letters + string.digits + + +def random_string(k, source=ALPHABET): + """Generate a random string with length k""" + return ''.join(random.choice(source) for _unused in range(k)) diff --git a/lib/keyring/util/__init__.py b/lib/keyring/util/__init__.py new file mode 100644 index 0000000..097a943 --- /dev/null +++ b/lib/keyring/util/__init__.py @@ -0,0 +1,11 @@ +import contextlib + + +def suppress_exceptions(callables, exceptions=Exception): + """ + yield the results of calling each element of callables, suppressing + any indicated exceptions. 
+ """ + for callable in callables: + with contextlib.suppress(exceptions): + yield callable() diff --git a/lib/keyring/util/__pycache__/__init__.cpython-314.pyc b/lib/keyring/util/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..2f8fd4a Binary files /dev/null and b/lib/keyring/util/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/keyring/util/__pycache__/platform_.cpython-314.pyc b/lib/keyring/util/__pycache__/platform_.cpython-314.pyc new file mode 100644 index 0000000..693f7b3 Binary files /dev/null and b/lib/keyring/util/__pycache__/platform_.cpython-314.pyc differ diff --git a/lib/keyring/util/platform_.py b/lib/keyring/util/platform_.py new file mode 100644 index 0000000..cb5f77c --- /dev/null +++ b/lib/keyring/util/platform_.py @@ -0,0 +1,40 @@ +import os +import pathlib +import platform + + +def _data_root_Windows(): + release, version, csd, ptype = platform.win32_ver() + root = pathlib.Path( + os.environ.get('LOCALAPPDATA', os.environ.get('ProgramData', '.')) + ) + return root / 'Python Keyring' + + +def _data_root_Linux(): + """ + Use freedesktop.org Base Dir Specification to determine storage + location. + """ + fallback = pathlib.Path.home() / '.local/share' + root = os.environ.get('XDG_DATA_HOME', None) or fallback + return pathlib.Path(root, 'python_keyring') + + +_config_root_Windows = _data_root_Windows + + +def _config_root_Linux(): + """ + Use freedesktop.org Base Dir Specification to determine config + location. 
+ """ + fallback = pathlib.Path.home() / '.config' + key = 'XDG_CONFIG_HOME' + root = os.environ.get(key, None) or fallback + return pathlib.Path(root, 'python_keyring') + + +# by default, use Unix convention +data_root = globals().get('_data_root_' + platform.system(), _data_root_Linux) +config_root = globals().get('_config_root_' + platform.system(), _config_root_Linux) diff --git a/lib/more_itertools-10.8.0.dist-info/INSTALLER b/lib/more_itertools-10.8.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/more_itertools-10.8.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/more_itertools-10.8.0.dist-info/METADATA b/lib/more_itertools-10.8.0.dist-info/METADATA new file mode 100644 index 0000000..bb7a3db --- /dev/null +++ b/lib/more_itertools-10.8.0.dist-info/METADATA @@ -0,0 +1,283 @@ +Metadata-Version: 2.4 +Name: more-itertools +Version: 10.8.0 +Summary: More routines for operating on iterables, beyond itertools +Keywords: itertools,iterator,iteration,filter,peek,peekable,chunk,chunked +Author-email: Erik Rose +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-Expression: MIT +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Natural Language :: English +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Software Development :: Libraries +License-File: LICENSE +Project-URL: Documentation, https://more-itertools.readthedocs.io/en/stable/ +Project-URL: Homepage, 
https://github.com/more-itertools/more-itertools + +============== +More Itertools +============== + +.. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest + :target: https://more-itertools.readthedocs.io/en/stable/ + +Python's ``itertools`` library is a gem - you can compose elegant solutions +for a variety of problems with the functions it provides. In ``more-itertools`` +we collect additional building blocks, recipes, and routines for working with +Python iterables. + ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Grouping | `chunked `_, | +| | `ichunked `_, | +| | `chunked_even `_, | +| | `sliced `_, | +| | `constrained_batches `_, | +| | `distribute `_, | +| | `divide `_, | +| | `split_at `_, | +| | `split_before `_, | +| | `split_after `_, | +| | `split_into `_, | +| | `split_when `_, | +| | `bucket `_, | +| | `unzip `_, | +| | `batched `_, | +| | `grouper `_, | +| | `partition `_, | +| | `transpose `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Lookahead and lookback | `spy `_, | +| | `peekable `_, | +| | `seekable `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Windowing | `windowed `_, | +| | `substrings `_, | +| | `substrings_indexes `_, | +| | `stagger `_, | +| | `windowed_complete `_, | +| | `pairwise `_, | +| | `triplewise `_, | +| | `sliding_window `_, | +| | `subslices `_ | 
++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Augmenting | `count_cycle `_, | +| | `intersperse `_, | +| | `padded `_, | +| | `repeat_each `_, | +| | `mark_ends `_, | +| | `repeat_last `_, | +| | `adjacent `_, | +| | `groupby_transform `_, | +| | `pad_none `_, | +| | `ncycles `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Combining | `collapse `_, | +| | `sort_together `_, | +| | `interleave `_, | +| | `interleave_longest `_, | +| | `interleave_evenly `_, | +| | `interleave_randomly `_, | +| | `zip_offset `_, | +| | `zip_equal `_, | +| | `zip_broadcast `_, | +| | `flatten `_, | +| | `roundrobin `_, | +| | `prepend `_, | +| | `value_chain `_, | +| | `partial_product `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Summarizing | `ilen `_, | +| | `unique_to_each `_, | +| | `sample `_, | +| | `consecutive_groups `_, | +| | `run_length `_, | +| | `map_reduce `_, | +| | `join_mappings `_, | +| | `exactly_n `_, | +| | `is_sorted `_, | +| | `all_equal `_, | +| | `all_unique `_, | +| | `argmin `_, | +| | `argmax `_, | +| | `minmax `_, | +| | `first_true `_, | +| | `quantify `_, | +| | `iequals `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Selecting | `islice_extended `_, | +| | `first `_, | +| | `last `_, | +| | `one `_, | +| | `only `_, | +| | `strictly_n `_, | +| | `strip `_, | +| | `lstrip `_, | +| | `rstrip `_, | +| | `filter_except 
`_, | +| | `map_except `_, | +| | `filter_map `_, | +| | `iter_suppress `_, | +| | `nth_or_last `_, | +| | `extract `_, | +| | `unique_in_window `_, | +| | `before_and_after `_, | +| | `nth `_, | +| | `take `_, | +| | `tail `_, | +| | `unique_everseen `_, | +| | `unique_justseen `_, | +| | `unique `_, | +| | `duplicates_everseen `_, | +| | `duplicates_justseen `_, | +| | `classify_unique `_, | +| | `longest_common_prefix `_, | +| | `takewhile_inclusive `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Math | `dft `_, | +| | `idft `_, | +| | `convolve `_, | +| | `dotproduct `_, | +| | `matmul `_, | +| | `polynomial_from_roots `_, | +| | `polynomial_derivative `_, | +| | `polynomial_eval `_, | +| | `sum_of_squares `_, | +| | `running_median `_, | +| | `totient `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Integer math | `factor `_, | +| | `is_prime `_, | +| | `multinomial `_, | +| | `nth_prime `_, | +| | `sieve `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Combinatorics | `circular_shifts `_, | +| | `derangements `_, | +| | `gray_product `_, | +| | `outer_product `_, | +| | `partitions `_, | +| | `set_partitions `_, | +| | `powerset `_, | +| | `powerset_of_sets `_ | +| +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | `distinct_combinations `_, | +| | `distinct_permutations `_ | +| 
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | `combination_index `_, | +| | `combination_with_replacement_index `_, | +| | `permutation_index `_, | +| | `product_index `_ | +| +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | `nth_combination `_, | +| | `nth_combination_with_replacement `_, | +| | `nth_permutation `_, | +| | `nth_product `_ | +| +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | `random_combination `_, | +| | `random_combination_with_replacement `_, | +| | `random_permutation `_, | +| | `random_product `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Wrapping | `always_iterable `_, | +| | `always_reversible `_, | +| | `countable `_, | +| | `consumer `_, | +| | `with_iter `_, | +| | `iter_except `_ | ++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Others | `locate `_, | +| | `rlocate `_, | +| | `replace `_, | +| | `numeric_range `_, | +| | `side_effect `_, | +| | `iterate `_, | +| | `loops `_, | +| | `difference `_, | +| | `make_decorator `_, | +| | `SequenceView `_, | +| | `time_limited `_, | +| | `map_if `_, | +| | `iter_index `_, | +| | `consume `_, | +| | `tabulate `_, | +| | `repeatfunc `_, | +| | `reshape `_, | +| | `doublestarmap `_ | 
++------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + + +Getting started +=============== + +To get started, install the library with `pip `_: + +.. code-block:: shell + + pip install more-itertools + +The recipes from the `itertools docs `_ +are included in the top-level package: + +.. code-block:: python + + >>> from more_itertools import flatten + >>> iterable = [(0, 1), (2, 3)] + >>> list(flatten(iterable)) + [0, 1, 2, 3] + +Several new recipes are available as well: + +.. code-block:: python + + >>> from more_itertools import chunked + >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8] + >>> list(chunked(iterable, 3)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + >>> from more_itertools import spy + >>> iterable = (x * x for x in range(1, 6)) + >>> head, iterable = spy(iterable, n=3) + >>> list(head) + [1, 4, 9] + >>> list(iterable) + [1, 4, 9, 16, 25] + + + +For the full listing of functions, see the `API documentation `_. + + +Links elsewhere +=============== + +Blog posts about ``more-itertools``: + +* `Yo, I heard you like decorators `__ +* `Tour of Python Itertools `__ (`Alternate `__) +* `Real-World Python More Itertools `_ + + +Development +=========== + +``more-itertools`` is maintained by `@erikrose `_ +and `@bbayles `_, with help from `many others `_. +If you have a problem or suggestion, please file a bug or pull request in this +repository. Thanks for contributing! + + +Version History +=============== + +The version history can be found in `documentation `_. 
+ diff --git a/lib/more_itertools-10.8.0.dist-info/RECORD b/lib/more_itertools-10.8.0.dist-info/RECORD new file mode 100644 index 0000000..03ae95c --- /dev/null +++ b/lib/more_itertools-10.8.0.dist-info/RECORD @@ -0,0 +1,15 @@ +more_itertools-10.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +more_itertools-10.8.0.dist-info/METADATA,sha256=arNRUUWr5YsGfwh8hnYxz0z11lP-2BuWQu4SCGw5BLg,39413 +more_itertools-10.8.0.dist-info/RECORD,, +more_itertools-10.8.0.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82 +more_itertools-10.8.0.dist-info/licenses/LICENSE,sha256=CfHIyelBrz5YTVlkHqm4fYPAyw_QB-te85Gn4mQ8GkY,1053 +more_itertools/__init__.py,sha256=5F7E_zpoGcEBW_T_3WE0WYYt8j-gJodIuiBcOJxrOv8,149 +more_itertools/__init__.pyi,sha256=5B3eTzON1BBuOLob1vCflyEb2lSd6usXQQ-Cv-hXkeA,43 +more_itertools/__pycache__/__init__.cpython-314.pyc,, +more_itertools/__pycache__/more.cpython-314.pyc,, +more_itertools/__pycache__/recipes.cpython-314.pyc,, +more_itertools/more.py,sha256=mNPKKu5UI7lRL460vgm0QTCWFiGMVCMosSPxVSdibos,163690 +more_itertools/more.pyi,sha256=fpEgNX3O66wY5cnT-s5VYDKNUpAcaCyU3iP84It3OOM,27119 +more_itertools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +more_itertools/recipes.py,sha256=Ma-kuBNZDFhaQDbIJgRmnrG86WzaupbOyUV3v8je3xw,41811 +more_itertools/recipes.pyi,sha256=LNRwN-OL3nkMfQAqx-PPc1fBaetUObb_Z6mdePyzh1c,6226 diff --git a/lib/more_itertools-10.8.0.dist-info/WHEEL b/lib/more_itertools-10.8.0.dist-info/WHEEL new file mode 100644 index 0000000..d8b9936 --- /dev/null +++ b/lib/more_itertools-10.8.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: flit 3.12.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/lib/more_itertools-10.8.0.dist-info/licenses/LICENSE b/lib/more_itertools-10.8.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..0a523be --- /dev/null +++ b/lib/more_itertools-10.8.0.dist-info/licenses/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2012 Erik Rose + 
+Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/lib/more_itertools/__init__.py b/lib/more_itertools/__init__.py new file mode 100644 index 0000000..24216c5 --- /dev/null +++ b/lib/more_itertools/__init__.py @@ -0,0 +1,6 @@ +"""More routines for operating on iterables, beyond itertools""" + +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '10.8.0' diff --git a/lib/more_itertools/__init__.pyi b/lib/more_itertools/__init__.pyi new file mode 100644 index 0000000..96f6e36 --- /dev/null +++ b/lib/more_itertools/__init__.pyi @@ -0,0 +1,2 @@ +from .more import * +from .recipes import * diff --git a/lib/more_itertools/__pycache__/__init__.cpython-314.pyc b/lib/more_itertools/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..a601ec0 Binary files /dev/null and b/lib/more_itertools/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/more_itertools/__pycache__/more.cpython-314.pyc b/lib/more_itertools/__pycache__/more.cpython-314.pyc new file mode 100644 index 0000000..c9bb61c Binary files /dev/null and b/lib/more_itertools/__pycache__/more.cpython-314.pyc differ diff --git a/lib/more_itertools/__pycache__/recipes.cpython-314.pyc b/lib/more_itertools/__pycache__/recipes.cpython-314.pyc new file mode 100644 index 0000000..4029046 Binary files /dev/null and b/lib/more_itertools/__pycache__/recipes.cpython-314.pyc differ diff --git a/lib/more_itertools/more.py b/lib/more_itertools/more.py new file mode 100755 index 0000000..bf50195 --- /dev/null +++ b/lib/more_itertools/more.py @@ -0,0 +1,5303 @@ +import math +import warnings + +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from contextlib import suppress +from functools import cached_property, partial, reduce, wraps +from heapq import heapify, heapreplace +from itertools import ( + chain, + combinations, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + permutations, + repeat, + starmap, + takewhile, + tee, + zip_longest, + product, +) +from 
math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau +from math import ceil +from queue import Empty, Queue +from random import random, randrange, shuffle, uniform +from operator import ( + attrgetter, + is_not, + itemgetter, + lt, + mul, + neg, + sub, + gt, +) +from sys import hexversion, maxsize +from time import monotonic + +from .recipes import ( + _marker, + _zip_equal, + UnequalIterablesError, + consume, + first_true, + flatten, + is_prime, + nth, + powerset, + sieve, + take, + unique_everseen, + all_equal, + batched, +) + +__all__ = [ + 'AbortThread', + 'SequenceView', + 'UnequalIterablesError', + 'adjacent', + 'all_unique', + 'always_iterable', + 'always_reversible', + 'argmax', + 'argmin', + 'bucket', + 'callback_iter', + 'chunked', + 'chunked_even', + 'circular_shifts', + 'collapse', + 'combination_index', + 'combination_with_replacement_index', + 'consecutive_groups', + 'constrained_batches', + 'consumer', + 'count_cycle', + 'countable', + 'derangements', + 'dft', + 'difference', + 'distinct_combinations', + 'distinct_permutations', + 'distribute', + 'divide', + 'doublestarmap', + 'duplicates_everseen', + 'duplicates_justseen', + 'classify_unique', + 'exactly_n', + 'extract', + 'filter_except', + 'filter_map', + 'first', + 'gray_product', + 'groupby_transform', + 'ichunked', + 'iequals', + 'idft', + 'ilen', + 'interleave', + 'interleave_evenly', + 'interleave_longest', + 'interleave_randomly', + 'intersperse', + 'is_sorted', + 'islice_extended', + 'iterate', + 'iter_suppress', + 'join_mappings', + 'last', + 'locate', + 'longest_common_prefix', + 'lstrip', + 'make_decorator', + 'map_except', + 'map_if', + 'map_reduce', + 'mark_ends', + 'minmax', + 'nth_or_last', + 'nth_permutation', + 'nth_prime', + 'nth_product', + 'nth_combination_with_replacement', + 'numeric_range', + 'one', + 'only', + 'outer_product', + 'padded', + 'partial_product', + 'partitions', + 'peekable', + 'permutation_index', + 'powerset_of_sets', + 'product_index', + 
'raise_', + 'repeat_each', + 'repeat_last', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'sample', + 'seekable', + 'set_partitions', + 'side_effect', + 'sliced', + 'sort_together', + 'split_after', + 'split_at', + 'split_before', + 'split_into', + 'split_when', + 'spy', + 'stagger', + 'strip', + 'strictly_n', + 'substrings', + 'substrings_indexes', + 'takewhile_inclusive', + 'time_limited', + 'unique_in_window', + 'unique_to_each', + 'unzip', + 'value_chain', + 'windowed', + 'windowed_complete', + 'with_iter', + 'zip_broadcast', + 'zip_equal', + 'zip_offset', +] + +# math.sumprod is available for Python 3.12+ +try: + from math import sumprod as _fsumprod + +except ImportError: # pragma: no cover + # Extended precision algorithms from T. J. Dekker, + # "A Floating-Point Technique for Extending the Available Precision" + # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf + # Formulas: (5.5) (5.6) and (5.8). Code: mul12() + + def dl_split(x: float): + "Split a float into two half-precision components." + t = x * 134217729.0 # Veltkamp constant = 2.0 ** 27 + 1 + hi = t - (t - x) + lo = x - hi + return hi, lo + + def dl_mul(x, y): + "Lossless multiplication." + xx_hi, xx_lo = dl_split(x) + yy_hi, yy_lo = dl_split(y) + p = xx_hi * yy_hi + q = xx_hi * yy_lo + xx_lo * yy_hi + z = p + q + zz = p - z + q + xx_lo * yy_lo + return z, zz + + def _fsumprod(p, q): + return fsum(chain.from_iterable(map(dl_mul, p, q))) + + +def chunked(iterable, n, strict=False): + """Break *iterable* into lists of length *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] + + By the default, the last yielded list will have fewer than *n* elements + if the length of *iterable* is not divisible by *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. 
+ + If the length of *iterable* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + list is yielded. + + """ + iterator = iter(partial(take, n, iter(iterable)), []) + if strict: + if n is None: + raise ValueError('n must not be None when using strict mode.') + + def ret(): + for chunk in iterator: + if len(chunk) != n: + raise ValueError('iterable is not divisible by n.') + yield chunk + + return ret() + else: + return iterator + + +def first(iterable, default=_marker): + """Return the first item of *iterable*, or *default* if *iterable* is + empty. + + >>> first([0, 1, 2, 3]) + 0 + >>> first([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + + :func:`first` is useful when you have a generator of expensive-to-retrieve + values and want any arbitrary one. It is marginally shorter than + ``next(iter(iterable), default)``. + + """ + for item in iterable: + return item + if default is _marker: + raise ValueError( + 'first() was called on an empty iterable, ' + 'and no default value was provided.' + ) + return default + + +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. + + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + if isinstance(iterable, Sequence): + return iterable[-1] + # Work around https://bugs.python.org/issue38525 + if getattr(iterable, '__reversed__', None): + return next(reversed(iterable)) + return deque(iterable, maxlen=1)[-1] + except (IndexError, TypeError, StopIteration): + if default is _marker: + raise ValueError( + 'last() was called on an empty iterable, ' + 'and no default value was provided.' 
+ ) + return default + + +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty. + + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: + """Wrap an iterator to allow lookahead and prepending elements. + + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. This won't advance the iterator: + + >>> p = peekable(['a', 'b']) + >>> p.peek() + 'a' + >>> next(p) + 'a' + + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. + + >>> p = peekable([]) + >>> p.peek('hi') + 'hi' + + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] + + peekables can be indexed. Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhausted + ... 
list(p) + [] + + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self._cache = deque() + + def __iter__(self): + return self + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + """Return the item that will be next returned from ``next()``. + + Return ``default`` if there are no items left. If ``default`` is not + provided, raise ``StopIteration``. + + """ + if not self._cache: + try: + self._cache.append(next(self._it)) + except StopIteration: + if default is _marker: + raise + return default + return self._cache[0] + + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. + + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. 
+ if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. + else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] + + +def consumer(func): + """Decorator that automatically advances a PEP-342-style "reverse iterator" + to its first yield point so you don't have to call ``next()`` on it + manually. + + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. + + Without the decorator, you would have to call ``next(t)`` before + ``t.send()`` could be used. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + + return wrapper + + +def ilen(iterable): + """Return the number of items in *iterable*. + + For example, there are 168 prime numbers below 1,000: + + >>> ilen(sieve(1000)) + 168 + + Equivalent to, but faster than:: + + def ilen(iterable): + count = 0 + for _ in iterable: + count += 1 + return count + + This fully consumes the iterable, so handle with care. + + """ + # This is the "most beautiful of the fast variants" of this function. + # If you think you can improve on it, please ensure that your version + # is both 10x faster and 10x more beautiful. + return sum(compress(repeat(1), zip(iterable))) + + +def iterate(func, start): + """Return ``start``, ``func(start)``, ``func(func(start))``, ... 
+ + Produces an infinite iterator. To add a stopping condition, + use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:. + + >>> take(10, iterate(lambda x: 2*x, 1)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2 + >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10))) + [10, 5, 16, 8, 4, 2, 1] + + """ + with suppress(StopIteration): + while True: + yield start + start = func(start) + + +def with_iter(context_manager): + """Wrap an iterable in a ``with`` statement, so it closes once exhausted. + + For example, this will close the file when the iterator is exhausted:: + + upper_lines = (line.upper() for line in with_iter(open('foo'))) + + Any context manager which returns an iterable is a candidate for + ``with_iter``. + + """ + with context_manager as iterable: + yield from iterable + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. 
You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. + + """ + iterator = iter(iterable) + for first in iterator: + for second in iterator: + msg = ( + f'Expected exactly one item in iterable, but got {first!r}, ' + f'{second!r}, and perhaps more.' + ) + raise too_long or ValueError(msg) + return first + raise too_short or ValueError('too few items in iterable (expected 1)') + + +def raise_(exception, *args): + raise exception(*args) + + +def strictly_n(iterable, n, too_short=None, too_long=None): + """Validate that *iterable* has exactly *n* items and return them if + it does. If it has fewer than *n* items, call function *too_short* + with the actual number of items. If it has more than *n* items, call function + *too_long* with the number ``n + 1``. + + >>> iterable = ['a', 'b', 'c', 'd'] + >>> n = 4 + >>> list(strictly_n(iterable, n)) + ['a', 'b', 'c', 'd'] + + Note that the returned iterable must be consumed in order for the check to + be made. + + By default, *too_short* and *too_long* are functions that raise + ``ValueError``. + + >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (got 2) + + >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... 
+ ValueError: too many items in iterable (got at least 3) + + You can instead supply functions that do something else. + *too_short* will be called with the number of items in *iterable*. + *too_long* will be called with `n + 1`. + + >>> def too_short(item_count): + ... raise RuntimeError + >>> it = strictly_n('abcd', 6, too_short=too_short) + >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + >>> def too_long(item_count): + ... print('The boss is going to hear about this') + >>> it = strictly_n('abcdef', 4, too_long=too_long) + >>> list(it) + The boss is going to hear about this + ['a', 'b', 'c', 'd'] + + """ + if too_short is None: + too_short = lambda item_count: raise_( + ValueError, + f'Too few items in iterable (got {item_count})', + ) + + if too_long is None: + too_long = lambda item_count: raise_( + ValueError, + f'Too many items in iterable (got at least {item_count})', + ) + + it = iter(iterable) + + sent = 0 + for item in islice(it, n): + yield item + sent += 1 + + if sent < n: + too_short(sent) + return + + for item in it: + too_long(n + 1) + return + + +def distinct_permutations(iterable, r=None): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to yielding from ``set(permutations(iterable))``, except + duplicates are not generated and thrown away. For larger input sequences + this is much more efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. The function :func:`multinomial` computes this directly. + + If *r* is given, only the *r*-length permutations are yielded. 
+ + >>> sorted(distinct_permutations([1, 0, 1], r=2)) + [(0, 1), (1, 0), (1, 1)] + >>> sorted(distinct_permutations(range(3), r=2)) + [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + + *iterable* need not be sortable, but note that using equal (``x == y``) + but non-identical (``id(x) != id(y)``) elements may produce surprising + behavior. For example, ``1`` and ``True`` are equal but non-identical: + + >>> list(distinct_permutations([1, True, '3'])) # doctest: +SKIP + [ + (1, True, '3'), + (1, '3', True), + ('3', 1, True) + ] + >>> list(distinct_permutations([1, 2, '3'])) # doctest: +SKIP + [ + (1, 2, '3'), + (1, '3', 2), + (2, 1, '3'), + (2, '3', 1), + ('3', 1, 2), + ('3', 2, 1) + ] + """ + + # Algorithm: https://w.wiki/Qai + def _full(A): + while True: + # Yield the permutation we have + yield tuple(A) + + # Find the largest index i such that A[i] < A[i + 1] + for i in range(size - 2, -1, -1): + if A[i] < A[i + 1]: + break + # If no such index exists, this permutation is the last one + else: + return + + # Find the largest index j greater than j such that A[i] < A[j] + for j in range(size - 1, i, -1): + if A[i] < A[j]: + break + + # Swap the value of A[i] with that of A[j], then reverse the + # sequence from A[i + 1] to form the new permutation + A[i], A[j] = A[j], A[i] + A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1] + + # Algorithm: modified from the above + def _partial(A, r): + # Split A into the first r items and the last r items + head, tail = A[:r], A[r:] + right_head_indexes = range(r - 1, -1, -1) + left_tail_indexes = range(len(tail)) + + while True: + # Yield the permutation we have + yield tuple(head) + + # Starting from the right, find the first index of the head with + # value smaller than the maximum value of the tail - call it i. 
+ pivot = tail[-1] + for i in right_head_indexes: + if head[i] < pivot: + break + pivot = head[i] + else: + return + + # Starting from the left, find the first value of the tail + # with a value greater than head[i] and swap. + for j in left_tail_indexes: + if tail[j] > head[i]: + head[i], tail[j] = tail[j], head[i] + break + # If we didn't find one, start from the right and find the first + # index of the head with a value greater than head[i] and swap. + else: + for j in right_head_indexes: + if head[j] > head[i]: + head[i], head[j] = head[j], head[i] + break + + # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] + tail += head[: i - r : -1] # head[i + 1:][::-1] + i += 1 + head[i:], tail[:] = tail[: r - i], tail[r - i :] + + items = list(iterable) + + try: + items.sort() + sortable = True + except TypeError: + sortable = False + + indices_dict = defaultdict(list) + + for item in items: + indices_dict[items.index(item)].append(item) + + indices = [items.index(item) for item in items] + indices.sort() + + equivalent_items = {k: cycle(v) for k, v in indices_dict.items()} + + def permuted_items(permuted_indices): + return tuple( + next(equivalent_items[index]) for index in permuted_indices + ) + + size = len(items) + if r is None: + r = size + + # functools.partial(_partial, ... ) + algorithm = _full if (r == size) else partial(_partial, r=r) + + if 0 < r <= size: + if sortable: + return algorithm(items) + else: + return ( + permuted_items(permuted_indices) + for permuted_indices in algorithm(indices) + ) + + return iter(() if r else ((),)) + + +def derangements(iterable, r=None): + """Yield successive derangements of the elements in *iterable*. + + A derangement is a permutation in which no element appears at its original + index. In other words, a derangement is a permutation that has no fixed points. + + Suppose Alice, Bob, Carol, and Dave are playing Secret Santa. 
+ The code below outputs all of the different ways to assign gift recipients + such that nobody is assigned to himself or herself: + + >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']): + ... print(', '.join(d)) + Bob, Alice, Dave, Carol + Bob, Carol, Dave, Alice + Bob, Dave, Alice, Carol + Carol, Alice, Dave, Bob + Carol, Dave, Alice, Bob + Carol, Dave, Bob, Alice + Dave, Alice, Bob, Carol + Dave, Carol, Alice, Bob + Dave, Carol, Bob, Alice + + If *r* is given, only the *r*-length derangements are yielded. + + >>> sorted(derangements(range(3), 2)) + [(1, 0), (1, 2), (2, 0)] + >>> sorted(derangements([0, 2, 3], 2)) + [(2, 0), (2, 3), (3, 0)] + + Elements are treated as unique based on their position, not on their value. + + Consider the Secret Santa example with two *different* people who have + the *same* name. Then there are two valid gift assignments even though + it might appear that a person is assigned to themselves: + + >>> names = ['Alice', 'Bob', 'Bob'] + >>> list(derangements(names)) + [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')] + + To avoid confusion, make the inputs distinct: + + >>> deduped = [f'{name}{index}' for index, name in enumerate(names)] + >>> list(derangements(deduped)) + [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')] + + The number of derangements of a set of size *n* is known as the + "subfactorial of n". For n > 0, the subfactorial is: + ``round(math.factorial(n) / math.e)``. + + References: + + * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics + * Sizes: https://oeis.org/A000166 + """ + xs = tuple(iterable) + ys = tuple(range(len(xs))) + return compress( + permutations(xs, r=r), + map(all, map(map, repeat(is_not), repeat(ys), permutations(ys, r=r))), + ) + + +def intersperse(e, iterable, n=1): + """Intersperse filler element *e* among the items in *iterable*, leaving + *n* items between each filler element. 
+ + >>> list(intersperse('!', [1, 2, 3, 4, 5])) + [1, '!', 2, '!', 3, '!', 4, '!', 5] + + >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2)) + [1, 2, None, 3, 4, None, 5] + + """ + if n == 0: + raise ValueError('n must be > 0') + elif n == 1: + # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, x_1, e, x_2... + return islice(interleave(repeat(e), iterable), 1, None) + else: + # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... + # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]... + # flatten(...) -> x_0, x_1, e, x_2, x_3... + filler = repeat([e]) + chunks = chunked(iterable, n) + return flatten(islice(interleave(filler, chunks), 1, None)) + + +def unique_to_each(*iterables): + """Return the elements from each of the input iterables that aren't in the + other input iterables. + + For example, suppose you have a set of packages, each with a set of + dependencies:: + + {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}} + + If you remove one package, which dependencies can also be removed? + + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. 
+ + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. + + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield () + return + if step < 1: + raise ValueError('step must be >= 1') + + iterator = iter(seq) + + # Generate first window + window = deque(islice(iterator, n), maxlen=n) + + # Deal with the first window not being full + if not window: + return + if len(window) < n: + yield tuple(window) + ((fillvalue,) * (n - len(window))) + return + yield tuple(window) + + # Create the filler for the next windows. The padding ensures + # we have just enough elements to fill the last window. + padding = (fillvalue,) * (n - 1 if step >= n else step - 1) + filler = map(window.append, chain(iterator, padding)) + + # Generate the rest of the windows + for _ in islice(filler, step - 1, None, step): + yield tuple(window) + + +def substrings(iterable): + """Yield all of the substrings of *iterable*. 
+ + >>> [''.join(s) for s in substrings('more')] + ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more'] + + Note that non-string iterables can also be subdivided. + + >>> list(substrings([0, 1, 2])) + [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)] + + """ + # The length-1 substrings + seq = [] + for item in iterable: + seq.append(item) + yield (item,) + seq = tuple(seq) + item_count = len(seq) + + # And the rest + for n in range(2, item_count + 1): + for i in range(item_count - n + 1): + yield seq[i : i + n] + + +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. + + + """ + r = range(1, len(seq) + 1) + if reverse: + r = reversed(r) + return ( + (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) + ) + + +class bucket: + """Wrap *iterable* and return an object that buckets the iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. 
+ + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. 
+ else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + return iter(self._cache) + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) + + +def spy(iterable, n=1): + """Return a 2-tuple with a list containing the first *n* elements of + *iterable*, and an iterator with the same items as *iterable*. + This allows you to "look ahead" at the items in the iterable without + advancing it. + + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + p, q = tee(iterable) + return take(n, q), p + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. 
+ + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + for xs in zip_longest(*iterables, fillvalue=_marker): + for x in xs: + if x is not _marker: + yield x + + +def interleave_evenly(iterables, lengths=None): + """ + Interleave multiple iterables so that their elements are evenly distributed + throughout the output sequence. + + >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] + >>> list(interleave_evenly(iterables)) + [1, 2, 'a', 3, 4, 'b', 5] + + >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] + >>> list(interleave_evenly(iterables)) + [1, 6, 4, 2, 7, 3, 8, 5] + + This function requires iterables of known length. Iterables without + ``__len__()`` can be used by manually specifying lengths with *lengths*: + + >>> from itertools import combinations, repeat + >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] + >>> lengths = [4 * (4 - 1) // 2, 3] + >>> list(interleave_evenly(iterables, lengths=lengths)) + [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] + + Based on Bresenham's algorithm. + """ + if lengths is None: + try: + lengths = [len(it) for it in iterables] + except TypeError: + raise ValueError( + 'Iterable lengths could not be determined automatically. ' + 'Specify them with the lengths keyword.' 
+ ) + elif len(iterables) != len(lengths): + raise ValueError('Mismatching number of iterables and lengths.') + + dims = len(lengths) + + # sort iterables by length, descending + lengths_permute = sorted( + range(dims), key=lambda i: lengths[i], reverse=True + ) + lengths_desc = [lengths[i] for i in lengths_permute] + iters_desc = [iter(iterables[i]) for i in lengths_permute] + + # the longest iterable is the primary one (Bresenham: the longest + # distance along an axis) + delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] + iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] + errors = [delta_primary // dims] * len(deltas_secondary) + + to_yield = sum(lengths) + while to_yield: + yield next(iter_primary) + to_yield -= 1 + # update errors for each secondary iterable + errors = [e - delta for e, delta in zip(errors, deltas_secondary)] + + # those iterables for which the error is negative are yielded + # ("diagonal step" in Bresenham) + for i, e_ in enumerate(errors): + if e_ < 0: + yield next(iters_secondary[i]) + to_yield -= 1 + errors[i] += delta_primary + + +def interleave_randomly(*iterables): + """Repeatedly select one of the input *iterables* at random and yield the next + item from it. + + >>> iterables = [1, 2, 3], 'abc', (True, False, None) + >>> list(interleave_randomly(*iterables)) # doctest: +SKIP + ['a', 'b', 1, 'c', True, False, None, 2, 3] + + The relative order of the items in each input iterable will preserved. Note the + sequences of items with this property are not equally likely to be generated. 
+ + """ + iterators = [iter(e) for e in iterables] + while iterators: + idx = randrange(len(iterators)) + try: + yield next(iterators[idx]) + except StopIteration: + # equivalent to `list.pop` but slightly faster + iterators[idx] = iterators[-1] + del iterators[-1] + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and + will not be collapsed. + + To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + stack = deque() + # Add our first node group, treat the iterable as a single node + stack.appendleft((0, repeat(iterable, 1))) + + while stack: + node_group = stack.popleft() + level, nodes = node_group + + # Check if beyond max level + if levels is not None and level > levels: + yield from nodes + continue + + for node in nodes: + # Check if done iterating + if isinstance(node, (str, bytes)) or ( + (base_type is not None) and isinstance(node, base_type) + ): + yield node + # Otherwise try to create child nodes + else: + try: + tree = iter(node) + except TypeError: + yield node + else: + # Save our current location + stack.appendleft(node_group) + # Append the new child node + stack.appendleft((level + 1, tree)) + # Break to process child node + break + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* 
(or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." + + Emitting a status message: + + >>> from more_itertools import consume + >>> func = lambda item: print('Received {}'.format(item)) + >>> consume(side_effect(func, range(2))) + Received 0 + Received 1 + + Operating on chunks of items: + + >>> pair_sums = [] + >>> func = lambda chunk: pair_sums.append(sum(chunk)) + >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2)) + [0, 1, 2, 3, 4, 5] + >>> list(pair_sums) + [1, 5, 9] + + Writing to a file-like object: + + >>> from io import StringIO + >>> from more_itertools import consume + >>> f = StringIO() + >>> func = lambda x: print(x, file=f) + >>> before = lambda: print(u'HEADER', file=f) + >>> after = f.close + >>> it = [u'a', u'b', u'c'] + >>> consume(side_effect(func, it, before=before, after=after)) + >>> f.closed + True + + """ + try: + if before is not None: + before() + + if chunk_size is None: + for item in iterable: + func(item) + yield item + else: + for chunk in chunked(iterable, chunk_size): + func(chunk) + yield from chunk + finally: + if after is not None: + after() + + +def sliced(seq, n, strict=False): + """Yield slices of length *n* from the sequence *seq*. + + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] + + By the default, the last yielded slice will have fewer than *n* elements + if the length of *seq* is not divisible by *n*: + + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + If the length of *seq* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + slice is yielded. 
+ + This function will only work for iterables that support slicing. + For non-sliceable iterables, see :func:`chunked`. + + """ + iterator = takewhile(len, (seq[i : i + n] for i in count(0, n))) + if strict: + + def ret(): + for _slice in iterator: + if len(_slice) != n: + raise ValueError("seq is not divisible by n.") + yield _slice + + return ret() + else: + return iterator + + +def split_at(iterable, pred, maxsplit=-1, keep_separator=False): + """Yield lists of items from *iterable*, where each list is delimited by + an item where callable *pred* returns ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b')) + [['a'], ['c', 'd', 'c'], ['a']] + + >>> list(split_at(range(10), lambda n: n % 2 == 1)) + [[0], [2], [4], [6], [8], []] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) + [[0], [2], [4, 5, 6, 7, 8, 9]] + + By default, the delimiting items are not included in the output. + To include them, set *keep_separator* to ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) + [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item): + yield buf + if keep_separator: + yield [item] + if maxsplit == 1: + yield list(it) + return + buf = [] + maxsplit -= 1 + else: + buf.append(item) + yield buf + + +def split_before(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends just before + an item for which callable *pred* returns ``True``: + + >>> list(split_before('OneTwo', lambda s: s.isupper())) + [['O', 'n', 'e'], ['T', 'w', 'o']] + + >>> list(split_before(range(10), lambda n: n % 3 == 0)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + At most *maxsplit* splits are done. 
If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item) and buf: + yield buf + if maxsplit == 1: + yield [item, *it] + return + buf = [] + maxsplit -= 1 + buf.append(item) + if buf: + yield buf + + +def split_after(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends with an + item where callable *pred* returns ``True``: + + >>> list(split_after('one1two2', lambda s: s.isdigit())) + [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']] + + >>> list(split_after(range(10), lambda n: n % 3 == 0)) + [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + buf.append(item) + if pred(item) and buf: + yield buf + if maxsplit == 1: + buf = list(it) + if buf: + yield buf + return + buf = [] + maxsplit -= 1 + if buf: + yield buf + + +def split_when(iterable, pred, maxsplit=-1): + """Split *iterable* into pieces based on the output of *pred*. + *pred* should be a function that takes successive pairs of items and + returns ``True`` if the iterable should be split in between them. + + For example, to find runs of increasing numbers, split the iterable when + element ``i`` is larger than element ``i + 1``: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) + [[1, 2, 3, 3], [2, 5], [2, 4], [2]] + + At most *maxsplit* splits are done. 
If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], + ... lambda x, y: x > y, maxsplit=2)) + [[1, 2, 3, 3], [2, 5], [2, 4, 2]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + it = iter(iterable) + try: + cur_item = next(it) + except StopIteration: + return + + buf = [cur_item] + for next_item in it: + if pred(cur_item, next_item): + yield buf + if maxsplit == 1: + yield [next_item, *it] + return + buf = [] + maxsplit -= 1 + + buf.append(next_item) + cur_item = next_item + + yield buf + + +def split_into(iterable, sizes): + """Yield a list of sequential items from *iterable* of length 'n' for each + integer 'n' in *sizes*. + + >>> list(split_into([1,2,3,4,5,6], [1,2,3])) + [[1], [2, 3], [4, 5, 6]] + + If the sum of *sizes* is smaller than the length of *iterable*, then the + remaining items of *iterable* will not be returned. + + >>> list(split_into([1,2,3,4,5,6], [2,3])) + [[1, 2], [3, 4, 5]] + + If the sum of *sizes* is larger than the length of *iterable*, fewer items + will be returned in the iteration that overruns the *iterable* and further + lists will be empty: + + >>> list(split_into([1,2,3,4], [1,2,3,4])) + [[1], [2, 3], [4], []] + + When a ``None`` object is encountered in *sizes*, the returned list will + contain items up to the end of *iterable* the same way that + :func:`itertools.slice` does: + + >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None])) + [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]] + + :func:`split_into` can be useful for grouping a series of items where the + sizes of the groups are not uniform. An example would be where in a row + from a table, multiple columns represent elements of the same feature + (e.g. a point represented by x,y,z) but, the format is not the same for + all columns. 
+ """ + # convert the iterable argument into an iterator so its contents can + # be consumed by islice in case it is a generator + it = iter(iterable) + + for size in sizes: + if size is None: + yield list(it) + return + else: + yield list(islice(it, size)) + + +def padded(iterable, fillvalue=None, n=None, next_multiple=False): + """Yield the elements from *iterable*, followed by *fillvalue*, such that + at least *n* items are emitted. + + >>> list(padded([1, 2, 3], '?', 5)) + [1, 2, 3, '?', '?'] + + If *next_multiple* is ``True``, *fillvalue* will be emitted until the + number of items emitted is a multiple of *n*: + + >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True)) + [1, 2, 3, 4, None, None] + + If *n* is ``None``, *fillvalue* will be emitted indefinitely. + + To create an *iterable* of exactly size *n*, you can truncate with + :func:`islice`. + + >>> list(islice(padded([1, 2, 3], '?'), 5)) + [1, 2, 3, '?', '?'] + >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5)) + [1, 2, 3, 4, 5] + + """ + iterator = iter(iterable) + iterator_with_repeat = chain(iterator, repeat(fillvalue)) + + if n is None: + return iterator_with_repeat + elif n < 1: + raise ValueError('n must be at least 1') + elif next_multiple: + + def slice_generator(): + for first in iterator: + yield (first,) + yield islice(iterator_with_repeat, n - 1) + + # While elements exist produce slices of size n + return chain.from_iterable(slice_generator()) + else: + # Ensure the first batch is at least size n then iterate + return chain(islice(iterator_with_repeat, n), iterator) + + +def repeat_each(iterable, n=2): + """Repeat each element in *iterable* *n* times. + + >>> list(repeat_each('ABC', 3)) + ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'] + """ + return chain.from_iterable(map(repeat, iterable, repeat(n))) + + +def repeat_last(iterable, default=None): + """After the *iterable* is exhausted, keep yielding its last element. 
+ + >>> list(islice(repeat_last(range(3)), 5)) + [0, 1, 2, 2, 2] + + If the iterable is empty, yield *default* forever:: + + >>> list(islice(repeat_last(range(0), 42), 5)) + [42, 42, 42, 42, 42] + + """ + item = _marker + for item in iterable: + yield item + final = default if item is _marker else item + yield from repeat(final) + + +def distribute(n, iterable): + """Distribute the items from *iterable* among *n* smaller iterables. + + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + + If the length of *iterable* is smaller than *n*, then the last returned + iterables will be empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function uses :func:`itertools.tee` and may require significant + storage. + + If you need the order items in the smaller iterables to match the + original iterable, see :func:`divide`. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + children = tee(iterable, n) + return [islice(it, index, None, n) for index, it in enumerate(children)] + + +def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): + """Yield tuples whose elements are offset from *iterable*. + The amount by which the `i`-th item in each tuple is offset is given by + the `i`-th item in *offsets*. + + >>> list(stagger([0, 1, 2, 3])) + [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + >>> list(stagger(range(8), offsets=(0, 2, 4))) + [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)] + + By default, the sequence will end when the final element of a tuple is the + last item in the iterable. 
To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +def zip_equal(*iterables): + """``zip`` the input *iterables* together but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + if hexversion >= 0x30A00A6: + warnings.warn( + ( + 'zip_equal will be removed in a future version of ' + 'more-itertools. Use the builtin zip function with ' + 'strict=True instead.' + ), + DeprecationWarning, + ) + + return _zip_equal(*iterables) + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which some series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. 
+ + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + if len(iterables) != len(offsets): + raise ValueError("Number of iterables and offsets didn't match") + + staggered = [] + for it, n in zip(iterables, offsets): + if n < 0: + staggered.append(chain(repeat(fillvalue, -n), it)) + elif n > 0: + staggered.append(islice(it, n, None)) + else: + staggered.append(it) + + if longest: + return zip_longest(*staggered, fillvalue=fillvalue) + + return zip(*staggered) + + +def sort_together( + iterables, key_list=(0,), key=None, reverse=False, strict=False +): + """Return the input iterables sorted together, with *key_list* as the + priority for sorting. All iterables are trimmed to the length of the + shortest one. + + This can be used like the sorting function in a spreadsheet. If each + iterable represents a column of data, the key list determines which + columns are used for sorting. + + By default, all iterables are sorted using the ``0``-th iterable:: + + >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')] + >>> sort_together(iterables) + [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] + + Set a different key list to sort according to another iterable. + Specifying multiple keys dictates how ties are broken:: + + >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] + >>> sort_together(iterables, key_list=(1, 2)) + [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + + To sort by a function of the elements of the iterable, pass a *key* + function. Its arguments are the elements of the iterables corresponding to + the key list:: + + >>> names = ('a', 'b', 'c') + >>> lengths = (1, 2, 3) + >>> widths = (5, 2, 1) + >>> def area(length, width): + ... 
return length * width + >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area) + [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)] + + Set *reverse* to ``True`` to sort in descending order. + + >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) + [(3, 2, 1), ('a', 'b', 'c')] + + If the *strict* keyword argument is ``True``, then + ``UnequalIterablesError`` will be raised if any of the iterables have + different lengths. + + """ + if key is None: + # if there is no key function, the key argument to sorted is an + # itemgetter + key_argument = itemgetter(*key_list) + else: + # if there is a key function, call it with the items at the offsets + # specified by the key function as arguments + key_list = list(key_list) + if len(key_list) == 1: + # if key_list contains a single item, pass the item at that offset + # as the only argument to the key function + key_offset = key_list[0] + key_argument = lambda zipped_items: key(zipped_items[key_offset]) + else: + # if key_list contains multiple items, use itemgetter to return a + # tuple of items, which we pass as *args to the key function + get_key_items = itemgetter(*key_list) + key_argument = lambda zipped_items: key( + *get_key_items(zipped_items) + ) + + zipper = zip_equal if strict else zip + return list( + zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse)) + ) + + +def unzip(iterable): + """The inverse of :func:`zip`, this function disaggregates the elements + of the zipped *iterable*. + + The ``i``-th iterable contains the ``i``-th element from each element + of the zipped iterable. The first element is used to determine the + length of the remaining elements. + + >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> letters, numbers = unzip(iterable) + >>> list(letters) + ['a', 'b', 'c', 'd'] + >>> list(numbers) + [1, 2, 3, 4] + + This is similar to using ``zip(*iterable)``, but it avoids reading + *iterable* into memory. 
Note, however, that this function uses + :func:`itertools.tee` and thus may require significant storage. + + """ + head, iterable = spy(iterable) + if not head: + # empty iterable, e.g. zip([], [], []) + return () + # spy returns a one-length iterable as head + head = head[0] + iterables = tee(iterable, len(head)) + + # If we have an iterable like iter([(1, 2, 3), (4, 5), (6,)]), + # the second unzipped iterable fails at the third tuple since + # it tries to access (6,)[1]. + # Same with the third unzipped iterable and the second tuple. + # To support these "improperly zipped" iterables, we suppress + # the IndexError, which just stops the unzipped iterables at + # first length mismatch. + return tuple( + iter_suppress(map(itemgetter(i), it), IndexError) + for i, it in enumerate(iterables) + ) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. + + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning. + If order is not important, see :func:`distribute`, which does not first + pull the iterable into memory. 
+ + """ + if n < 1: + raise ValueError('n must be at least 1') + + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + + q, r = divmod(len(seq), n) + + ret = [] + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(str, bytes)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. 
+ + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. + + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): + """An extension of :func:`itertools.groupby` that can apply transformations + to the grouped data. 
+ + * *keyfunc* is a function computing a key value for each item in *iterable* + * *valuefunc* is a function that transforms the individual items from + *iterable* after grouping + * *reducefunc* is a function that transforms each group of items + + >>> iterable = 'aAAbBBcCC' + >>> keyfunc = lambda k: k.upper() + >>> valuefunc = lambda v: v.lower() + >>> reducefunc = lambda g: ''.join(g) + >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) + [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] + + Each optional argument defaults to an identity function if not specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. + + """ + ret = groupby(iterable, keyfunc) + if valuefunc: + ret = ((k, map(valuefunc, g)) for k, g in ret) + if reducefunc: + ret = ((k, reducefunc(g)) for k, g in ret) + + return ret + + +class numeric_range(abc.Sequence, abc.Hashable): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. 
The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Shared hash for every empty range, mirroring built-in range behavior.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            # Derive a zero/one of the appropriate numeric type from the
            # arguments themselves, so mixed types (Decimal, datetime +
            # timedelta, ...) work.
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # Zero of the step's type, used for sign and remainder comparisons.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff start has not yet passed stop in the step direction.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires being inside the half-open interval AND being
        # an exact multiple of the step away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Equal iff same start, same step, and same final element.
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Slice steps compose multiplicatively with the range's own step.
            step = self._step if key.step is None else key.step * self._step

            # Clamp out-of-bounds slice endpoints, as built-in range does.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Mirrors __eq__: (start, last element, step) identifies the range.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        values = (self._start + (n * self._step) for n in count())
        # takewhile(partial(gt, stop), v) keeps values while stop > v
        # (and the mirror image for shrinking ranges).
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range, then use euclidean division; any
        # nonzero remainder means one extra (partial-step) element.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Pickle support: reconstruct from the three defining arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk backward from the last element to just past start.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Elements are unique, so the count is either 0 or 1.
        return int(value in self)

    def index(self, value):
        # Same interval/remainder test as __contains__, but the quotient
        # gives the position.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Negative indices count from the end, as with built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step


def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    seq = tuple(iterable)
    if not seq:
        return iter(())
    # Pair each cycle number with one full pass over the sequence.
    counter = count() if n is None else range(n)
    return zip(repeat_each(counter, len(seq)), cycle(seq))


def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.
+ + >>> list(mark_ends('ABC')) + [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] + + Use this when looping over an iterable to take special action on its first + and/or last items: + + >>> iterable = ['Header', 100, 200, 'Footer'] + >>> total = 0 + >>> for is_first, is_last, item in mark_ends(iterable): + ... if is_first: + ... continue # Skip the header + ... if is_last: + ... continue # Skip the footer + ... total += item + >>> print(total) + 300 + """ + it = iter(iterable) + for a in it: + first = True + for b in it: + yield first, False, a + a = b + first = False + yield first, True, a + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(locate([0, 1, 1, 0, 1, 0, 0])) + [1, 2, 4] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item. + + >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b')) + [1, 3] + + If *window_size* is given, then the *pred* function will be called with + that many items. 
This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(locate(iterable, pred=pred, window_size=3)) + [1, 5, 9] + + Use with :func:`seekable` to find indexes and then retrieve the associated + items: + + >>> from itertools import count + >>> from more_itertools import seekable + >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count()) + >>> it = seekable(source) + >>> pred = lambda x: x > 100 + >>> indexes = locate(it, pred=pred) + >>> i = next(indexes) + >>> it.seek(i) + >>> next(it) + 106 + + """ + if window_size is None: + return compress(count(), map(pred, iterable)) + + if window_size < 1: + raise ValueError('window size must be at least 1') + + it = windowed(iterable, window_size, fillvalue=_marker) + return compress(count(), starmap(pred, it)) + + +def longest_common_prefix(iterables): + """Yield elements of the longest common prefix among given *iterables*. + + >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf'])) + 'ab' + + """ + return (c[0] for c in takewhile(all_equal, zip(*iterables))) + + +def lstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the beginning + for which *pred* returns ``True``. + + For example, to remove a set of items from the start of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(lstrip(iterable, pred)) + [1, 2, None, 3, False, None] + + This function is analogous to to :func:`str.lstrip`, and is essentially + an wrapper for :func:`itertools.dropwhile`. + + """ + return dropwhile(pred, iterable) + + +def rstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the end + for which *pred* returns ``True``. 
+ + For example, to remove a set of items from the end of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(rstrip(iterable, pred)) + [None, False, None, 1, 2, None, 3] + + This function is analogous to :func:`str.rstrip`. + + """ + cache = [] + cache_append = cache.append + cache_clear = cache.clear + for x in iterable: + if pred(x): + cache_append(x) + else: + yield from cache + cache_clear() + yield x + + +def strip(iterable, pred): + """Yield the items from *iterable*, but strip any from the + beginning and end for which *pred* returns ``True``. + + For example, to remove a set of items from both ends of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(strip(iterable, pred)) + [1, 2, None, 3] + + This function is analogous to :func:`str.strip`. + + """ + return rstrip(lstrip(iterable, pred), pred) + + +class islice_extended: + """An extension of :func:`itertools.islice` that supports negative values + for *stop*, *start*, and *step*. + + >>> iterator = iter('abcdefgh') + >>> list(islice_extended(iterator, -4, -1)) + ['e', 'f', 'g'] + + Slices with negative values require some caching of *iterable*, but this + function takes care to minimize the amount of memory required. 
+ + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + You can also use slice notation directly: + + >>> iterator = map(str, count()) + >>> it = islice_extended(iterator)[10:20:2] + >>> list(it) + ['10', '12', '14', '16', '18'] + + """ + + def __init__(self, iterable, *args): + it = iter(iterable) + if args: + self._iterator = _islice_helper(it, slice(*args)) + else: + self._iterator = it + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterator) + + def __getitem__(self, key): + if isinstance(key, slice): + return islice_extended(_islice_helper(self._iterator, key)) + + raise TypeError('islice_extended.__getitem__ argument must be a slice') + + +def _islice_helper(it, s): + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + if step > 0: + start = 0 if (start is None) else start + + if start < 0: + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index in range(n): + if index % step == 0: + # pop and yield the item. 
+ # We don't want to use an intermediate variable + # it would extend the lifetime of the current item + yield cache.popleft()[1] + else: + # just pop and discard the item + cache.popleft() + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + if index % step == 0: + # pop and yield the item. + # We don't want to use an intermediate variable + # it would extend the lifetime of the current item + yield cache.popleft() + else: + # just pop and discard the item + cache.popleft() + cache.append(item) + else: + # When both start and stop are positive we have the normal case + yield from islice(it, start, stop, step) + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. + if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. 
+ else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + yield from cache[i::step] + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. If the iterable is not reversible, + this function will cache the remaining items in the iterable and + yield them in reverse order, which may require significant storage. + """ + try: + return reversed(iterable) + except TypeError: + return reversed(list(iterable)) + + +def consecutive_groups(iterable, ordering=None): + """Yield groups of consecutive items using :func:`itertools.groupby`. + The *ordering* function determines whether two items are adjacent by + returning their position. + + By default, the ordering function is the identity function. This is + suitable for finding runs of numbers: + + >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40] + >>> for group in consecutive_groups(iterable): + ... print(list(group)) + [1] + [10, 11, 12] + [20] + [30, 31, 32, 33] + [40] + + To find runs of adjacent letters, apply :func:`ord` function + to convert letters to ordinals. + + >>> iterable = 'abcdfgilmnop' + >>> ordering = ord + >>> for group in consecutive_groups(iterable, ordering): + ... print(list(group)) + ['a', 'b', 'c', 'd'] + ['f', 'g'] + ['i'] + ['l', 'm', 'n', 'o', 'p'] + + Each group of consecutive items is an iterator that shares it source with + *iterable*. When an an output group is advanced, the previous group is + no longer available unless its elements are copied (e.g., into a ``list``). + + >>> iterable = [1, 2, 11, 12, 21, 22] + >>> saved_groups = [] + >>> for group in consecutive_groups(iterable): + ... 
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items from *iterable*.

    Two items are adjacent when their positions, as reported by the
    *ordering* function, differ by one.  By default *ordering* is the
    identity, which suits runs of integers:

    >>> for group in consecutive_groups([1, 10, 11, 12, 20]):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]

    Pass ``ordering=ord`` to find runs of adjacent letters.

    Each group is an iterator that shares its source with *iterable*: once
    an output group is advanced, the previous group is no longer usable
    unless its elements were copied out (e.g. into a ``list``).
    """
    # Items in a consecutive run keep a constant (index - position)
    # difference, so that difference serves as the groupby key.
    if ordering is None:
        key = lambda pair: pair[0] - pair[1]
    else:
        key = lambda pair: pair[0] - ordering(pair[1])

    for _, grouped in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), grouped)


def difference(iterable, func=sub, *, initial=None):
    """Yield the running "differences" of *iterable* — the inverse of
    :func:`itertools.accumulate`.

    By default the difference is computed with :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> list(difference(accumulate([0, 1, 2, 3, 4])))
    [0, 1, 2, 3, 4]

    Any two-argument *func* may be supplied; it is applied as
    ``A, func(B, A), func(C, B), func(D, C), ...``.

    If the *initial* keyword is set, the first element is skipped when
    computing successive differences (matching
    ``accumulate(..., initial=initial)``):

    >>> list(difference([10, 11, 13, 16], initial=10))
    [1, 2, 3]

    """
    left, right = tee(iterable)
    try:
        head = [next(right)]
    except StopIteration:  # empty input
        return iter([])

    # With an explicit initial value, the leading element is not emitted.
    if initial is not None:
        head = []

    return chain(head, map(func, right, left))
class SequenceView(Sequence):
    """A dynamic, read-only view over the sequence *target*.

    Like the built-in dictionary views, a :class:`SequenceView` reflects
    changes made to the underlying sequence after the view was created:

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Indexing, slicing, and length queries are supported; assignment is
    not.  Views are a cheap alternative to copying, as they require (very
    little) extra storage.
    """

    def __init__(self, target):
        # Only true sequences can back a view.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'


class seekable:
    """Wrap an iterator so it can be rewound and fast-forwarded.

    Items from the source are cached as they are produced, so earlier
    positions can be revisited with :meth:`seek`:

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it)
    '0'

    Seeking past the end of the source is harmless (iteration simply
    ends).  :meth:`relative_seek` moves relative to the current position;
    :meth:`peek` looks one item ahead without advancing; :func:`bool`
    reports whether more items remain; :meth:`elements` exposes the cache
    as a live :class:`SequenceView`.

    The cache grows as the source progresses, so beware of wrapping very
    large or infinite iterables.  Supply *maxlen* to bound the cache size
    (which also limits how far back seeking can reach).
    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # Unbounded list cache by default; bounded deque when maxlen given.
        self._cache = [] if maxlen is None else deque([], maxlen)
        # ``None`` means "read from the source"; an int is a cache cursor.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        if self._index is not None:
            try:
                value = self._cache[self._index]
            except IndexError:
                # Cursor ran off the end of the cache; resume the source.
                self._index = None
            else:
                self._index += 1
                return value

        value = next(self._source)
        self._cache.append(value)
        return value

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        try:
            ahead = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        # Step the cursor back so the peeked item is produced again.
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return ahead

    def elements(self):
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # Seeking beyond the cache advances the source to fill the gap.
        shortfall = index - len(self._cache)
        if shortfall > 0:
            consume(self, shortfall)

    def relative_seek(self, count):
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """Run-length encoding and decoding.

    :func:`run_length.encode` compresses an iterable into
    ``(item, count)`` pairs:

    >>> list(run_length.encode('abbccc'))
    [('a', 1), ('b', 2), ('c', 3)]

    :func:`run_length.decode` reverses the transformation:

    >>> list(run_length.decode([('a', 1), ('b', 2), ('c', 3)]))
    ['a', 'b', 'b', 'c', 'c', 'c']

    """

    @staticmethod
    def encode(iterable):
        # Measure each group's length by draining the grouper.
        return (
            (value, sum(1 for _ in group)) for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        return chain.from_iterable(starmap(repeat, iterable))


def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly *n* items of *iterable* satisfy
    *predicate* (default :func:`bool`).

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The input is advanced until ``n + 1`` matches are seen, so avoid
    calling this on an infinite iterable.
    """
    # Take at most n + 1 matches: exactly n means success; n + 1 means
    # there were too many.
    matches = islice(filter(predicate, iterable), n + 1)
    return sum(1 for _ in matches) == n


def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    *steps* is the number of places to rotate to the left per shift (to
    the right if negative); it defaults to 1 and must be non-zero.
    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate backwards so the first yield is the unshifted ordering.
    ring.rotate(steps)
    shift = -steps
    # The sequence of shifts repeats after len / gcd(len, shift) steps.
    size = len(ring)
    total = size // math.gcd(size, shift) if size else 0

    for _ in range(total):
        ring.rotate(shift)
        yield tuple(ring)


def make_decorator(wrapping_func, result_index=0):
    """Turn *wrapping_func* — a function that transforms an iterable —
    into a decorator factory.  *result_index* is the position in
    *wrapping_func*'s signature where the decorated function's return
    value is inserted.

    This lets you apply itertools-style transforms at function definition
    time, without changing the function's code.  For example, to only let
    truthy items through:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                # Splice the wrapped function's result into the argument
                # list at the requested position.
                result = f(*args, **kwargs)
                spliced = list(wrapping_args)
                spliced.insert(result_index, result)
                return wrapping_func(*spliced, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator


def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Categorize the items of *iterable* with *keyfunc*, optionally
    transform each with *valuefunc*, and optionally summarize each
    category with *reducefunc*.

    >>> result = map_reduce('abbccc', lambda x: x.upper())
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    >>> result = map_reduce('abbccc', lambda x: x.upper(), lambda x: 1, sum)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    All items are gathered into lists before the summarization step, which
    may require significant storage.  The result is a
    :obj:`collections.defaultdict` with its ``default_factory`` set to
    ``None`` so it behaves like a normal dictionary.
    """
    groups = defaultdict(list)

    # Without a valuefunc, items are stored as-is.
    transform = (lambda item: item) if valuefunc is None else valuefunc
    for item in iterable:
        groups[keyfunc(item)].append(transform(item))

    if reducefunc is not None:
        for key, values in groups.items():
            groups[key] = reducefunc(values)

    # Disable default_factory so missing keys raise KeyError like a dict.
    groups.default_factory = None
    return groups


def rlocate(iterable, pred=bool, window_size=None):
    """Like :func:`locate`, but yield indexes from right to left.

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))
    [4, 2, 1]

    If *iterable* supports ``len`` and ``reversed``, the search runs from
    the right; otherwise the whole input is consumed first and the results
    are returned in reverse order — so this never returns for infinite
    iterables.  *window_size* behaves as in :func:`locate`.
    """
    if window_size is None:
        try:
            length = len(iterable)
            return (length - i - 1 for i in locate(reversed(iterable), pred))
        except TypeError:
            # Not a sized, reversible sequence; fall back to the eager path.
            pass

    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items of *iterable*, replacing each run of *window_size*
    consecutive items for which *pred* returns ``True`` with the items of
    *substitutes*.

    >>> list(replace([1, 1, 0, 1, 1, 0], lambda x: x == 0, (2, 3)))
    [1, 1, 2, 3, 1, 1, 2, 3]

    If *count* is given, the number of replacements is limited to it.
    *window_size* controls how many consecutive items are passed to
    *pred* at once, which allows locating and replacing sub-sequences.
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # *substitutes* is emitted repeatedly, so materialize it once.
    substitutes = tuple(substitutes)

    # Pad the tail so the number of windows matches the input length.
    padded = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(padded, window_size)

    replacements = 0
    for window in windows:
        # On a match (within the replacement budget), emit the substitutes
        # and skip the windows that overlap the matched span.
        if pred(*window):
            if (count is None) or (replacements < count):
                replacements += 1
                yield from substitutes
                consume(windows, window_size - 1)
                continue

        # No match (or budget exhausted): emit the window's first item,
        # unless it is end padding.
        if window and (window[0] is not _marker):
            yield window[0]


def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> for part in partitions('abc'):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.
    """
    sequence = list(iterable)
    n = len(sequence)
    # Every subset of the interior positions 1..n-1 is a set of cut
    # points; iterate subsets in increasing size (the power set).
    positions = range(1, n)
    cut_sets = chain.from_iterable(
        combinations(positions, r) for r in range(len(positions) + 1)
    )
    for cuts in cut_sets:
        yield [sequence[i:j] for i, j in zip((0,) + cuts, cuts + (n,))]


def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """Yield the set partitions of *iterable* into *k* parts.  Set
    partitions are not order-preserving.

    >>> for part in set_partitions('abc', 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']

    If *k* is not given, every set partition is generated.  If *min_size*
    and/or *max_size* are given, only partitions whose every block has a
    size within those bounds are yielded.
    """
    pool = list(iterable)
    n = len(pool)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = min_size if min_size is not None else 0
    max_size = max_size if max_size is not None else n
    if min_size > max_size:
        return

    def helper(items, groups):
        # Recursively build partitions of ``items`` into ``groups`` blocks.
        remaining = len(items)
        if groups == 1:
            yield [items]
        elif remaining == groups:
            yield [[item] for item in items]
        else:
            head, *rest = items
            # Either ``head`` forms a block of its own...
            for part in helper(rest, groups - 1):
                yield [[head], *part]
            # ...or it joins each block of a smaller partition in turn.
            for part in helper(rest, groups):
                for i in range(len(part)):
                    yield part[:i] + [[head] + part[i]] + part[i + 1 :]

    size_ok = lambda part: all(
        min_size <= len(block) <= max_size for block in part
    )

    if k is None:
        for k in range(1, n + 1):
            yield from filter(size_ok, helper(pool, k))
    else:
        yield from filter(size_ok, helper(pool, k))


class time_limited:
    """Yield items from *iterable* until *limit_seconds* have passed,
    setting :attr:`timed_out` to ``True`` if the limit expired before the
    source was exhausted.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> list(time_limited(0.1, generator()))
    [1, 2]

    The clock is checked before each item is yielded, so a slow source can
    overrun the limit by the cost of one item.  As a special case, a
    *limit_seconds* of zero never yields anything.
    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The deadline is measured from construction time.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # A zero limit never yields anything at all.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        item = next(self._iterator)
        if monotonic() - self._start_time > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item


def only(iterable, default=None, too_long=None):
    """If *iterable* has exactly one item, return it.  If it is empty,
    return *default*.  If it has more than one item, raise the exception
    given by *too_long* (``ValueError`` by default).

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1

    Note that the iterable is advanced twice to verify there is only one
    item; see :func:`spy` or :func:`peekable` for less destructive checks.
    """
    iterator = iter(iterable)
    sentinel = object()

    first = next(iterator, sentinel)
    if first is sentinel:
        return default

    second = next(iterator, sentinel)
    if second is not sentinel:
        msg = (
            f'Expected exactly one item in iterable, but got {first!r}, '
            f'{second!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return first
+ ) + raise too_long or ValueError(msg) + return first + return default + + +def _ichunk(iterator, n): + cache = deque() + chunk = islice(iterator, n) + + def generator(): + with suppress(StopIteration): + while True: + if cache: + yield cache.popleft() + else: + yield next(chunk) + + def materialize_next(n=1): + # if n not specified materialize everything + if n is None: + cache.extend(chunk) + return len(cache) + + to_cache = n - len(cache) + + # materialize up to n + if to_cache > 0: + cache.extend(islice(chunk, to_cache)) + + # return number materialized up to n + return min(n, len(cache)) + + return (generator(), materialize_next) + + +def ichunked(iterable, n): + """Break *iterable* into sub-iterables with *n* elements each. + :func:`ichunked` is like :func:`chunked`, but it yields iterables + instead of lists. + + If the sub-iterables are read in order, the elements of *iterable* + won't be stored in memory. + If they are read out of order, :func:`itertools.tee` is used to cache + elements as necessary. + + >>> from itertools import count + >>> all_chunks = ichunked(count(), 4) + >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks) + >>> list(c_2) # c_1's elements have been cached; c_3's haven't been + [4, 5, 6, 7] + >>> list(c_1) + [0, 1, 2, 3] + >>> list(c_3) + [8, 9, 10, 11] + + """ + iterator = iter(iterable) + while True: + # Create new chunk + chunk, materialize_next = _ichunk(iterator, n) + + # Check to see whether we're at the end of the source iterable + if not materialize_next(): + return + + yield chunk + + # Fill previous chunk's cache + materialize_next(None) + + +def iequals(*iterables): + """Return ``True`` if all given *iterables* are equal to each other, + which means that they contain the same elements in the same order. + + The function is useful for comparing iterables of different data types + or iterables that do not support equality checks. 
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # There is exactly one combination of length zero.
        yield ()
        return
    pool = tuple(iterable)
    # Stack of iterators, one per position in the combination under
    # construction; each yields (index, value) pairs with duplicate values
    # filtered out by unique_everseen.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # Current position is exhausted; backtrack to the previous one.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            yield tuple(current_combo)
        else:
            # Descend one level: candidates come only from positions after
            # cur_idx, again with duplicate values skipped.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
+ """ + for item in iterable: + try: + validator(item) + except exceptions: + pass + else: + yield item + + +def map_except(function, iterable, *exceptions): + """Transform each item from *iterable* with *function* and yield the + result, unless *function* raises one of the specified *exceptions*. + + *function* is called to transform each item in *iterable*. + It should accept one argument. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(map_except(int, iterable, ValueError, TypeError)) + [1, 2, 4] + + If an exception other than one given by *exceptions* is raised by + *function*, it is raised like normal. + """ + for item in iterable: + try: + yield function(item) + except exceptions: + pass + + +def map_if(iterable, pred, func, func_else=None): + """Evaluate each item from *iterable* using *pred*. If the result is + equivalent to ``True``, transform the item with *func* and yield it. + Otherwise, transform the item with *func_else* and yield it. + + *pred*, *func*, and *func_else* should each be functions that accept + one argument. By default, *func_else* is the identity function. + + >>> from math import sqrt + >>> iterable = list(range(-5, 5)) + >>> iterable + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig')) + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig'] + >>> list(map_if(iterable, lambda x: x >= 0, + ... lambda x: f'{sqrt(x):.2f}', lambda x: None)) + [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00'] + """ + + if func_else is None: + for item in iterable: + yield func(item) if pred(item) else item + + else: + for item in iterable: + yield func(item) if pred(item) else func_else(item) + + +def _sample_unweighted(iterator, k, strict): + # Algorithm L in the 1994 paper by Kim-Hung Li: + # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". 
def _sample_unweighted(iterator, k, strict):
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    # Fill the reservoir with the first k items.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    # W tracks the largest "key" currently in the reservoir (in the paper's
    # terms); it shrinks as more of the stream is seen.
    W = 1.0

    # Repeatedly skip ahead a geometrically distributed number of items and
    # replace a uniformly random reservoir slot, until the stream ends.
    with suppress(StopIteration):
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = next(islice(iterator, skip, None))
            reservoir[randrange(k)] = element

    # Shuffle so the output order carries no information about arrival order.
    shuffle(reservoir)
    return reservoir
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling (Algorithm L) over a population where element i is
    # treated as repeated counts[i] times, without materializing the repeats.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Fill the reservoir with the first k (virtual) items; stops early if
    # the expanded population is shorter than k.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Algorithm L skip logic, as in _sample_unweighted.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir
+ + >>> iterable = ['a', 'b'] + >>> counts = [3, 4] # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b' + >>> sample(iterable, k=3, counts=counts) # doctest: +SKIP + ['a', 'a', 'b'] + + An iterable with *weights* may be given: + + >>> iterable = range(100) + >>> weights = (i * i + 1 for i in range(100)) + >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP + [79, 67, 74, 66, 78] + + Weighted selections are made without replacement. + After an element is selected, it is removed from the pool and the + relative weights of the other elements increase (this + does not match the behavior of :func:`random.sample`'s *counts* + parameter). Note that *weights* may not be used with *counts*. + + If the length of *iterable* is less than *k*, + ``ValueError`` is raised if *strict* is ``True`` and + all elements are returned (in shuffled order) if *strict* is ``False``. + + By default, the `Algorithm L `__ reservoir sampling + technique is used. When *weights* are provided, + `Algorithm A-ExpJ `__ is used instead. + + Notes on reproducibility: + + * The algorithms rely on inexact floating-point functions provided + by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``). + Those functions can `produce slightly different results + `_ on + different builds. Accordingly, selections can vary across builds + even for the same seed. + + * The algorithms loop over the input and make selections based on + ordinal position, so selections from unordered collections (such as + sets) won't reproduce across sessions on the same platform using the + same seed. 
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    # Compare keyed values if a key was given, raw items otherwise.
    stream = iterable if key is None else map(key, iterable)
    iterator = iter(stream)

    try:
        previous = next(iterator)
    except StopIteration:
        # An empty iterable is trivially sorted.
        return True

    for current in iterator:
        # Orient each adjacent pair so (earlier, later) must be ascending.
        earlier, later = (current, previous) if reverse else (previous, current)
        # Strict requires earlier < later; non-strict only forbids later < earlier.
        if (not earlier < later) if strict else (later < earlier):
            return False
        previous = current

    return True
for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]: + ... if callback: + ... callback(i, c) + ... return 4 + + + Use ``with callback_iter(func)`` to get an iterator over the parameters + that are delivered to the callback. + + >>> with callback_iter(func) as it: + ... for args, kwargs in it: + ... print(args) + (1, 'a') + (2, 'b') + (3, 'c') + + The function will be called in a background thread. The ``done`` property + indicates whether it has completed execution. + + >>> it.done + True + + If it completes successfully, its return value will be available + in the ``result`` property. + + >>> it.result + 4 + + Notes: + + * If the function uses some keyword argument besides ``callback``, supply + *callback_kwd*. + * If it finished executing, but raised an exception, accessing the + ``result`` property will raise the same exception. + * If it hasn't finished executing, accessing the ``result`` + property from within the ``with`` block will raise ``RuntimeError``. + * If it hasn't finished executing, accessing the ``result`` property from + outside the ``with`` block will raise a + ``more_itertools.AbortThread`` exception. + * Provide *wait_seconds* to adjust how frequently the it is polled for + output. 
+ + """ + + def __init__(self, func, callback_kwd='callback', wait_seconds=0.1): + self._func = func + self._callback_kwd = callback_kwd + self._aborted = False + self._future = None + self._wait_seconds = wait_seconds + # Lazily import concurrent.future + self._executor = __import__( + 'concurrent.futures' + ).futures.ThreadPoolExecutor(max_workers=1) + self._iterator = self._reader() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._aborted = True + self._executor.shutdown() + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterator) + + @property + def done(self): + if self._future is None: + return False + return self._future.done() + + @property + def result(self): + if not self.done: + raise RuntimeError('Function has not yet completed') + + return self._future.result() + + def _reader(self): + q = Queue() + + def callback(*args, **kwargs): + if self._aborted: + raise AbortThread('canceled by user') + + q.put((args, kwargs)) + + self._future = self._executor.submit( + self._func, **{self._callback_kwd: callback} + ) + + while True: + try: + item = q.get(timeout=self._wait_seconds) + except Empty: + pass + else: + q.task_done() + yield item + + if self._future.done(): + break + + remaining = [] + while True: + try: + item = q.get_nowait() + except Empty: + break + else: + q.task_done() + remaining.append(item) + q.join() + yield from remaining + + +def windowed_complete(iterable, n): + """ + Yield ``(beginning, middle, end)`` tuples, where: + + * Each ``middle`` has *n* items from *iterable* + * Each ``beginning`` has the items before the ones in ``middle`` + * Each ``end`` has the items after the ones in ``middle`` + + >>> iterable = range(7) + >>> n = 3 + >>> for beginning, middle, end in windowed_complete(iterable, n): + ... 
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    # Hashable items get O(1) set membership; unhashable items fall back to
    # an O(n) list scan.
    hashable_seen = set()
    unhashable_seen = []

    items = iterable if key is None else map(key, iterable)
    for item in items:
        try:
            if item in hashable_seen:
                return False
            hashable_seen.add(item)
        except TypeError:
            # Unhashable: membership test raised, use the list instead.
            if item in unhashable_seen:
                return False
            unhashable_seen.append(item)

    return True
+ """ + pools = list(map(tuple, reversed(args))) + ns = list(map(len, pools)) + + c = reduce(mul, ns) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + result = [] + for pool, n in zip(pools, ns): + result.append(pool[index % n]) + index //= n + + return tuple(reversed(result)) + + +def nth_permutation(iterable, r, index): + """Equivalent to ``list(permutations(iterable, r))[index]``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`nth_permutation` + computes the subsequence at sort position *index* directly, without + computing the previous subsequences. + + >>> nth_permutation('ghijk', 2, 5) + ('h', 'i') + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = list(iterable) + n = len(pool) + + if r is None or r == n: + r, c = n, factorial(n) + elif not 0 <= r < n: + raise ValueError + else: + c = perm(n, r) + assert c > 0 # factorial(n)>0, and r>> nth_combination_with_replacement(range(5), 3, 5) + (0, 1, 1) + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = comb(n + r - 1, r) + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + i = 0 + while r: + r -= 1 + while n >= 0: + num_combs = comb(n + r - 1, r) + if index < num_combs: + break + n -= 1 + i += 1 + index -= num_combs + result.append(pool[i]) + + return tuple(result) + + +def value_chain(*args): + """Yield all arguments passed to the function in the same order in which + they were passed. If an argument itself is iterable then iterate over its + values. 
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    # Treat the product as a mixed-radix number: each pool contributes one
    # digit, weighted by the sizes of the pools that follow it.
    missing = object()
    index = 0

    for value, pool in zip_longest(element, args, fillvalue=missing):
        # Length mismatch between element and args means no such product.
        if value is missing or pool is missing:
            raise ValueError('element is not a product of args')

        choices = tuple(pool)
        index = index * len(choices) + choices.index(value)

    return index
+ """ + element = enumerate(element) + k, y = next(element, (None, None)) + if k is None: + return 0 + + indexes = [] + pool = enumerate(iterable) + for n, x in pool: + if x == y: + indexes.append(n) + tmp, y = next(element, (None, None)) + if tmp is None: + break + else: + k = tmp + else: + raise ValueError('element is not a combination of iterable') + + n, _ = last(pool, default=(n, None)) + + # Python versions below 3.8 don't have math.comb + index = 1 + for i, j in enumerate(reversed(indexes), start=1): + j = n - j + if i <= j: + index += comb(j, i) + + return comb(n + 1, k + 1) - index + + +def combination_with_replacement_index(element, iterable): + """Equivalent to + ``list(combinations_with_replacement(iterable, r)).index(element)`` + + The subsequences with repetition of *iterable* that are of length *r* can + be ordered lexicographically. :func:`combination_with_replacement_index` + computes the index of the first *element*, without computing the previous + combinations with replacement. + + >>> combination_with_replacement_index('adf', 'abcdefg') + 20 + + ``ValueError`` will be raised if the given *element* isn't one of the + combinations with replacement of *iterable*. 
+ """ + element = tuple(element) + l = len(element) + element = enumerate(element) + + k, y = next(element, (None, None)) + if k is None: + return 0 + + indexes = [] + pool = tuple(iterable) + for n, x in enumerate(pool): + while x == y: + indexes.append(n) + tmp, y = next(element, (None, None)) + if tmp is None: + break + else: + k = tmp + if y is None: + break + else: + raise ValueError( + 'element is not a combination with replacement of iterable' + ) + + n = len(pool) + occupations = [0] * n + for p in indexes: + occupations[p] += 1 + + index = 0 + cumulative_sum = 0 + for k in range(1, n): + cumulative_sum += occupations[k - 1] + j = l + n - 1 - k - cumulative_sum + i = n - k + if i <= j: + index += comb(j, i) + + return index + + +def permutation_index(element, iterable): + """Equivalent to ``list(permutations(iterable, r)).index(element)``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`permutation_index` + computes the index of the first *element* directly, without computing + the previous permutations. + + >>> permutation_index([1, 3, 2], range(5)) + 19 + + ``ValueError`` will be raised if the given *element* isn't one of the + permutations of *iterable*. + """ + index = 0 + pool = list(iterable) + for i, x in zip(range(len(pool), -1, -1), element): + r = pool.index(x) + index = index * i + r + del pool[r] + + return index + + +class countable: + """Wrap *iterable* and keep a count of how many items have been consumed. 
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk; the islice step of n
    # fires once per n appended items.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    """A version of :func:`zip` that "broadcasts" any scalar
    (i.e., non-iterable) items into output tuples.

    >>> iterable_1 = [1, 2, 3]
    >>> iterable_2 = ['a', 'b', 'c']
    >>> scalar = '_'
    >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
    [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]

    The *scalar_types* keyword argument determines what types are considered
    scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
    treat strings and byte strings as iterable:

    >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
    [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.
    """

    def is_scalar(obj):
        # Scalar if it's one of *scalar_types*, or simply not iterable.
        if scalar_types and isinstance(obj, scalar_types):
            return True
        try:
            iter(obj)
        except TypeError:
            return True
        else:
            return False

    size = len(objects)
    if not size:
        return

    # new_item holds the scalars permanently; the iterable slots are
    # refilled for every output tuple.
    new_item = [None] * size
    iterables, iterable_positions = [], []
    for i, obj in enumerate(objects):
        if is_scalar(obj):
            new_item[i] = obj
        else:
            iterables.append(iter(obj))
            iterable_positions.append(i)

    # All inputs were scalar: emit a single tuple of the inputs.
    if not iterables:
        yield tuple(objects)
        return

    zipper = _zip_equal if strict else zip
    for item in zipper(*iterables):
        # Write each zipped value back into its original position; the
        # loop target `new_item[i]` does the assignment itself.
        for i, new_item[i] in zip(iterable_positions, item):
            pass
        yield tuple(new_item)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys get O(1) set membership; unhashable keys fall back to
    # an O(n) list scan.
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            if k in hashable_seen:
                yield element
            else:
                hashable_seen.add(k)
        except TypeError:
            # Unhashable key: the membership test raised.
            if k in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(k)
+ + """ + return flatten(g for _, g in groupby(iterable, key) for _ in g) + + +def classify_unique(iterable, key=None): + """Classify each element in terms of its uniqueness. + + For each element in the input iterable, return a 3-tuple consisting of: + + 1. The element itself + 2. ``False`` if the element is equal to the one preceding it in the input, + ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`) + 3. ``False`` if this element has been seen anywhere in the input before, + ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`) + + >>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE + [('o', True, True), + ('t', True, True), + ('t', False, False), + ('o', True, False)] + + This function is analogous to :func:`unique_everseen` and is subject to + the same performance considerations. + + """ + seen_set = set() + seen_list = [] + use_key = key is not None + previous = None + + for i, element in enumerate(iterable): + k = key(element) if use_key else element + is_unique_justseen = not i or previous != k + previous = k + is_unique_everseen = False + try: + if k not in seen_set: + seen_set.add(k) + is_unique_everseen = True + except TypeError: + if k not in seen_list: + seen_list.append(k) + is_unique_everseen = True + yield element, is_unique_justseen, is_unique_everseen + + +def minmax(iterable_or_value, *others, key=None, default=_marker): + """Returns both the smallest and largest items from an iterable + or from two or more arguments. + + >>> minmax([3, 1, 5]) + (1, 5) + + >>> minmax(4, 2, 6) + (2, 6) + + If a *key* function is provided, it will be used to transform the input + items for comparison. + + >>> minmax([5, 30], key=str) # '30' sorts before '5' + (30, 5) + + If a *default* value is provided, it will be returned if there are no + input items. + + >>> minmax([], default=(0, 0)) + (0, 0) + + Otherwise ``ValueError`` is raised. 
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe `__ by
    Raymond Hettinger.
    """
    # Accept either a single iterable or two-plus positional values.
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Consume items pairwise (zip of the same iterator with itself; an
        # odd trailing item is padded with lo, which is harmless). Ordering
        # the pair first means only ~1.5 comparisons per item overall.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        # Same pairwise scheme, but comparisons use cached key values so
        # each item's key is computed exactly once.
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the batch before adding this item if doing so would exceed
        # either limit. An item is never emitted alone unless it must be.
        full = len(current) == max_count
        overflow = current_size + size > max_size
        if current and (overflow or full):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    # Emit whatever remains after the input runs out.
    if current:
        yield tuple(current)
+ """ + all_iterables = tuple(tuple(x) for x in iterables) + iterable_count = len(all_iterables) + for iterable in all_iterables: + if len(iterable) < 2: + raise ValueError("each iterable must have two or more items") + + # This is based on "Algorithm H" from section 7.2.1.1, page 20. + # a holds the indexes of the source iterables for the n-tuple to be yielded + # f is the array of "focus pointers" + # o is the array of "directions" + a = [0] * iterable_count + f = list(range(iterable_count + 1)) + o = [1] * iterable_count + while True: + yield tuple(all_iterables[i][a[i]] for i in range(iterable_count)) + j = f[0] + f[0] = 0 + if j == iterable_count: + break + a[j] = a[j] + o[j] + if a[j] == 0 or a[j] == len(all_iterables[j]) - 1: + o[j] = -o[j] + f[j] = f[j + 1] + f[j + 1] = j + 1 + + +def partial_product(*iterables): + """Yields tuples containing one item from each iterator, with subsequent + tuples changing a single item at a time by advancing each iterator until it + is exhausted. This sequence guarantees every value in each iterable is + output at least once without generating all possible combinations. + + This may be useful, for example, when testing an expensive function. + + >>> list(partial_product('AB', 'C', 'DEF')) + [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')] + """ + + iterators = list(map(iter, iterables)) + + try: + prod = [next(it) for it in iterators] + except StopIteration: + return + yield tuple(prod) + + for i, it in enumerate(iterators): + for prod[i] in it: + yield tuple(prod) + + +def takewhile_inclusive(predicate, iterable): + """A variant of :func:`takewhile` that yields one additional element. + + >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1])) + [1, 4, 6] + + :func:`takewhile` would return ``[1, 4]``. 
+ """ + for x in iterable: + yield x + if not predicate(x): + break + + +def outer_product(func, xs, ys, *args, **kwargs): + """A generalized outer product that applies a binary function to all + pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)`` + columns. + Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``. + + Multiplication table: + + >>> list(outer_product(mul, range(1, 4), range(1, 6))) + [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)] + + Cross tabulation: + + >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B'] + >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z'] + >>> pair_counts = Counter(zip(xs, ys)) + >>> count_rows = lambda x, y: pair_counts[x, y] + >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys)))) + [(2, 3, 0), (1, 0, 4)] + + Usage with ``*args`` and ``**kwargs``: + + >>> animals = ['cat', 'wolf', 'mouse'] + >>> list(outer_product(min, animals, animals, key=len)) + [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')] + """ + ys = tuple(ys) + return batched( + starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)), + n=len(ys), + ) + + +def iter_suppress(iterable, *exceptions): + """Yield each of the items from *iterable*. If the iteration raises one of + the specified *exceptions*, that exception will be suppressed and iteration + will stop. + + >>> from itertools import chain + >>> def breaks_at_five(x): + ... while True: + ... if x >= 5: + ... raise RuntimeError + ... yield x + ... x += 1 + >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError) + >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError) + >>> list(chain(it_1, it_2)) + [1, 2, 3, 4, 2, 3, 4] + """ + try: + yield from iterable + except exceptions: + return + + +def filter_map(func, iterable): + """Apply *func* to every element of *iterable*, yielding only those which + are not ``None``. 
+ + >>> elems = ['1', 'a', '2', 'b', '3'] + >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems)) + [1, 2, 3] + """ + for x in iterable: + y = func(x) + if y is not None: + yield y + + +def powerset_of_sets(iterable): + """Yields all possible subsets of the iterable. + + >>> list(powerset_of_sets([1, 2, 3])) # doctest: +SKIP + [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}] + >>> list(powerset_of_sets([1, 1, 0])) # doctest: +SKIP + [set(), {1}, {0}, {0, 1}] + + :func:`powerset_of_sets` takes care to minimize the number + of hash operations performed. + """ + sets = tuple(dict.fromkeys(map(frozenset, zip(iterable)))) + return chain.from_iterable( + starmap(set().union, combinations(sets, r)) + for r in range(len(sets) + 1) + ) + + +def join_mappings(**field_to_map): + """ + Joins multiple mappings together using their common keys. + + >>> user_scores = {'elliot': 50, 'claris': 60} + >>> user_times = {'elliot': 30, 'claris': 40} + >>> join_mappings(score=user_scores, time=user_times) + {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}} + """ + ret = defaultdict(dict) + + for field_name, mapping in field_to_map.items(): + for key, value in mapping.items(): + ret[key][field_name] = value + + return dict(ret) + + +def _complex_sumprod(v1, v2): + """High precision sumprod() for complex numbers. + Used by :func:`dft` and :func:`idft`. + """ + + real = attrgetter('real') + imag = attrgetter('imag') + r1 = chain(map(real, v1), map(neg, map(imag, v1))) + r2 = chain(map(real, v2), map(imag, v2)) + i1 = chain(map(real, v1), map(imag, v1)) + i2 = chain(map(imag, v2), map(real, v2)) + return complex(_fsumprod(r1, r2), _fsumprod(i1, i2)) + + +def dft(xarr): + """Discrete Fourier Transform. *xarr* is a sequence of complex numbers. + Yields the components of the corresponding transformed output vector. 
+ + >>> import cmath + >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain + >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain + >>> magnitudes, phases = zip(*map(cmath.polar, Xarr)) + >>> all(map(cmath.isclose, dft(xarr), Xarr)) + True + + Inputs are restricted to numeric types that can add and multiply + with a complex number. This includes int, float, complex, and + Fraction, but excludes Decimal. + + See :func:`idft` for the inverse Discrete Fourier Transform. + """ + N = len(xarr) + roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)] + for k in range(N): + coeffs = [roots_of_unity[k * n % N] for n in range(N)] + yield _complex_sumprod(xarr, coeffs) + + +def idft(Xarr): + """Inverse Discrete Fourier Transform. *Xarr* is a sequence of + complex numbers. Yields the components of the corresponding + inverse-transformed output vector. + + >>> import cmath + >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain + >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain + >>> all(map(cmath.isclose, idft(Xarr), xarr)) + True + + Inputs are restricted to numeric types that can add and multiply + with a complex number. This includes int, float, complex, and + Fraction, but excludes Decimal. + + See :func:`dft` for the Discrete Fourier Transform. + """ + N = len(Xarr) + roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)] + for k in range(N): + coeffs = [roots_of_unity[k * n % N] for n in range(N)] + yield _complex_sumprod(Xarr, coeffs) / N + + +def doublestarmap(func, iterable): + """Apply *func* to every item of *iterable* by dictionary unpacking + the item into *func*. + + The difference between :func:`itertools.starmap` and :func:`doublestarmap` + parallels the distinction between ``func(*a)`` and ``func(**a)``. 
+ + >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}] + >>> list(doublestarmap(lambda a, b: a + b, iterable)) + [3, 100] + + ``TypeError`` will be raised if *func*'s signature doesn't match the + mapping contained in *iterable* or if *iterable* does not contain mappings. + """ + for item in iterable: + yield func(**item) + + +def _nth_prime_bounds(n): + """Bounds for the nth prime (counting from 1): lb < p_n < ub.""" + # At and above 688,383, the lb/ub spread is under 0.003 * p_n. + + if n < 1: + raise ValueError + + if n < 6: + return (n, 2.25 * n) + + # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities + upper_bound = n * log(n * log(n)) + lower_bound = upper_bound - n + if n >= 688_383: + upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n)) + + return lower_bound, upper_bound + + +def nth_prime(n, *, approximate=False): + """Return the nth prime (counting from 0). + + >>> nth_prime(0) + 2 + >>> nth_prime(100) + 547 + + If *approximate* is set to True, will return a prime close + to the nth prime. The estimation is much faster than computing + an exact result. + + >>> nth_prime(200_000_000, approximate=True) # Exact result is 4222234763 + 4217820427 + + """ + lb, ub = _nth_prime_bounds(n + 1) + + if not approximate or n <= 1_000_000: + return nth(sieve(ceil(ub)), n) + + # Search from the midpoint and return the first odd prime + odd = floor((lb + ub) / 2) | 1 + return first_true(count(odd, step=2), pred=is_prime) + + +def argmin(iterable, *, key=None): + """ + Index of the first occurrence of a minimum value in an iterable. + + >>> argmin('efghabcdijkl') + 4 + >>> argmin([3, 2, 1, 0, 4, 2, 1, 0]) + 3 + + For example, look up a label corresponding to the position + of a value that minimizes a cost function:: + + >>> def cost(x): + ... "Days for a wound to heal given a subject's age." + ... return x**2 - 20*x + 150 + ... 
+ >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie'] + >>> ages = [ 35, 30, 10, 9, 1 ] + + # Fastest healing family member + >>> labels[argmin(ages, key=cost)] + 'bart' + + # Age with fastest healing + >>> min(ages, key=cost) + 10 + + """ + if key is not None: + iterable = map(key, iterable) + return min(enumerate(iterable), key=itemgetter(1))[0] + + +def argmax(iterable, *, key=None): + """ + Index of the first occurrence of a maximum value in an iterable. + + >>> argmax('abcdefghabcd') + 7 + >>> argmax([0, 1, 2, 3, 3, 2, 1, 0]) + 3 + + For example, identify the best machine learning model:: + + >>> models = ['svm', 'random forest', 'knn', 'naïve bayes'] + >>> accuracy = [ 68, 61, 84, 72 ] + + # Most accurate model + >>> models[argmax(accuracy)] + 'knn' + + # Best accuracy + >>> max(accuracy) + 84 + + """ + if key is not None: + iterable = map(key, iterable) + return max(enumerate(iterable), key=itemgetter(1))[0] + + +def extract(iterable, indices): + """Yield values at the specified indices. + + Example: + + >>> data = 'abcdefghijklmnopqrstuvwxyz' + >>> list(extract(data, [7, 4, 11, 11, 14])) + ['h', 'e', 'l', 'l', 'o'] + + The *iterable* is consumed lazily and can be infinite. + The *indices* are consumed immediately and must be finite. + + Raises ``IndexError`` if an index lies beyond the iterable. + Raises ``ValueError`` for negative indices. 
+ """ + + iterator = iter(iterable) + index_and_position = sorted(zip(indices, count())) + + if index_and_position and index_and_position[0][0] < 0: + raise ValueError('Indices must be non-negative') + + buffer = {} + iterator_position = -1 + next_to_emit = 0 + + for index, order in index_and_position: + advance = index - iterator_position + if advance: + try: + value = next(islice(iterator, advance - 1, None)) + except StopIteration: + raise IndexError(index) + iterator_position = index + + buffer[order] = value + + while next_to_emit in buffer: + yield buffer.pop(next_to_emit) + next_to_emit += 1 diff --git a/lib/more_itertools/more.pyi b/lib/more_itertools/more.pyi new file mode 100644 index 0000000..b5e33f8 --- /dev/null +++ b/lib/more_itertools/more.pyi @@ -0,0 +1,949 @@ +"""Stubs for more_itertools.more""" + +from __future__ import annotations + +import sys +import types + +from collections.abc import ( + Container, + Hashable, + Iterable, + Iterator, + Mapping, + Reversible, + Sequence, + Sized, +) +from contextlib import AbstractContextManager +from typing import ( + Any, + Callable, + Generic, + TypeVar, + overload, + type_check_only, +) +from typing_extensions import Protocol + +__all__ = [ + 'AbortThread', + 'SequenceView', + 'UnequalIterablesError', + 'adjacent', + 'all_unique', + 'always_iterable', + 'always_reversible', + 'argmax', + 'argmin', + 'bucket', + 'callback_iter', + 'chunked', + 'chunked_even', + 'circular_shifts', + 'collapse', + 'combination_index', + 'combination_with_replacement_index', + 'consecutive_groups', + 'constrained_batches', + 'consumer', + 'count_cycle', + 'countable', + 'derangements', + 'dft', + 'difference', + 'distinct_combinations', + 'distinct_permutations', + 'distribute', + 'divide', + 'doublestarmap', + 'duplicates_everseen', + 'duplicates_justseen', + 'classify_unique', + 'exactly_n', + 'extract', + 'filter_except', + 'filter_map', + 'first', + 'gray_product', + 'groupby_transform', + 'ichunked', + 'iequals', + 
'idft', + 'ilen', + 'interleave', + 'interleave_evenly', + 'interleave_longest', + 'interleave_randomly', + 'intersperse', + 'is_sorted', + 'islice_extended', + 'iterate', + 'iter_suppress', + 'join_mappings', + 'last', + 'locate', + 'longest_common_prefix', + 'lstrip', + 'make_decorator', + 'map_except', + 'map_if', + 'map_reduce', + 'mark_ends', + 'minmax', + 'nth_or_last', + 'nth_permutation', + 'nth_prime', + 'nth_product', + 'nth_combination_with_replacement', + 'numeric_range', + 'one', + 'only', + 'outer_product', + 'padded', + 'partial_product', + 'partitions', + 'peekable', + 'permutation_index', + 'powerset_of_sets', + 'product_index', + 'raise_', + 'repeat_each', + 'repeat_last', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'sample', + 'seekable', + 'set_partitions', + 'side_effect', + 'sliced', + 'sort_together', + 'split_after', + 'split_at', + 'split_before', + 'split_into', + 'split_when', + 'spy', + 'stagger', + 'strip', + 'strictly_n', + 'substrings', + 'substrings_indexes', + 'takewhile_inclusive', + 'time_limited', + 'unique_in_window', + 'unique_to_each', + 'unzip', + 'value_chain', + 'windowed', + 'windowed_complete', + 'with_iter', + 'zip_broadcast', + 'zip_equal', + 'zip_offset', +] + +# Type and type variable definitions +_T = TypeVar('_T') +_T1 = TypeVar('_T1') +_T2 = TypeVar('_T2') +_T3 = TypeVar('_T3') +_T4 = TypeVar('_T4') +_T5 = TypeVar('_T5') +_U = TypeVar('_U') +_V = TypeVar('_V') +_W = TypeVar('_W') +_T_co = TypeVar('_T_co', covariant=True) +_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]]) +_Raisable = BaseException | type[BaseException] + +# The type of isinstance's second argument (from typeshed builtins) +if sys.version_info >= (3, 10): + _ClassInfo = type | types.UnionType | tuple[_ClassInfo, ...] +else: + _ClassInfo = type | tuple[_ClassInfo, ...] + +@type_check_only +class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... 
+ +@type_check_only +class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ... + +@type_check_only +class _SupportsSlicing(Protocol[_T_co]): + def __getitem__(self, __k: slice) -> _T_co: ... + +def chunked( + iterable: Iterable[_T], n: int | None, strict: bool = ... +) -> Iterator[list[_T]]: ... +@overload +def first(iterable: Iterable[_T]) -> _T: ... +@overload +def first(iterable: Iterable[_T], default: _U) -> _T | _U: ... +@overload +def last(iterable: Iterable[_T]) -> _T: ... +@overload +def last(iterable: Iterable[_T], default: _U) -> _T | _U: ... +@overload +def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ... +@overload +def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ... + +class peekable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> peekable[_T]: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> _T | _U: ... + def prepend(self, *items: _T) -> None: ... + def __next__(self) -> _T: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> list[_T]: ... + +def consumer(func: _GenFn) -> _GenFn: ... +def ilen(iterable: Iterable[_T]) -> int: ... +def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... +def with_iter( + context_manager: AbstractContextManager[Iterable[_T]], +) -> Iterator[_T]: ... +def one( + iterable: Iterable[_T], + too_short: _Raisable | None = ..., + too_long: _Raisable | None = ..., +) -> _T: ... +def raise_(exception: _Raisable, *args: Any) -> None: ... +def strictly_n( + iterable: Iterable[_T], + n: int, + too_short: _GenFn | None = ..., + too_long: _GenFn | None = ..., +) -> list[_T]: ... +def distinct_permutations( + iterable: Iterable[_T], r: int | None = ... +) -> Iterator[tuple[_T, ...]]: ... 
+def derangements( + iterable: Iterable[_T], r: int | None = None +) -> Iterator[tuple[_T, ...]]: ... +def intersperse( + e: _U, iterable: Iterable[_T], n: int = ... +) -> Iterator[_T | _U]: ... +def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, *, step: int = ... +) -> Iterator[tuple[_T | None, ...]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, fillvalue: _U, step: int = ... +) -> Iterator[tuple[_T | _U, ...]]: ... +def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ... +def substrings_indexes( + seq: Sequence[_T], reverse: bool = ... +) -> Iterator[tuple[Sequence[_T], int, int]]: ... + +class bucket(Generic[_T, _U], Container[_U]): + def __init__( + self, + iterable: Iterable[_T], + key: Callable[[_T], _U], + validator: Callable[[_U], object] | None = ..., + ) -> None: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_U]: ... + def __getitem__(self, value: object) -> Iterator[_T]: ... + +def spy( + iterable: Iterable[_T], n: int = ... +) -> tuple[list[_T], Iterator[_T]]: ... +def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_evenly( + iterables: list[Iterable[_T]], lengths: list[int] | None = ... +) -> Iterator[_T]: ... +def interleave_randomly(*iterables: Iterable[_T]) -> Iterable[_T]: ... +def collapse( + iterable: Iterable[Any], + base_type: _ClassInfo | None = ..., + levels: int | None = ..., +) -> Iterator[Any]: ... +@overload +def side_effect( + func: Callable[[_T], object], + iterable: Iterable[_T], + chunk_size: None = ..., + before: Callable[[], object] | None = ..., + after: Callable[[], object] | None = ..., +) -> Iterator[_T]: ... 
+@overload +def side_effect( + func: Callable[[list[_T]], object], + iterable: Iterable[_T], + chunk_size: int, + before: Callable[[], object] | None = ..., + after: Callable[[], object] | None = ..., +) -> Iterator[_T]: ... +def sliced( + seq: _SupportsSlicing[_T], n: int, strict: bool = ... +) -> Iterator[_T]: ... +def split_at( + iterable: Iterable[_T], + pred: Callable[[_T], object], + maxsplit: int = ..., + keep_separator: bool = ..., +) -> Iterator[list[_T]]: ... +def split_before( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[list[_T]]: ... +def split_after( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[list[_T]]: ... +def split_when( + iterable: Iterable[_T], + pred: Callable[[_T, _T], object], + maxsplit: int = ..., +) -> Iterator[list[_T]]: ... +def split_into( + iterable: Iterable[_T], sizes: Iterable[int | None] +) -> Iterator[list[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + *, + n: int | None = ..., + next_multiple: bool = ..., +) -> Iterator[_T | None]: ... +@overload +def padded( + iterable: Iterable[_T], + fillvalue: _U, + n: int | None = ..., + next_multiple: bool = ..., +) -> Iterator[_T | _U]: ... +@overload +def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ... +@overload +def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ... +def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., +) -> Iterator[tuple[_T | None, ...]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., + fillvalue: _U = ..., +) -> Iterator[tuple[_T | _U, ...]]: ... + +class UnequalIterablesError(ValueError): + def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ... 
+ +# zip_equal +@overload +def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], __iter2: Iterable[_T2] +) -> Iterator[tuple[_T1, _T2]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], __iter2: Iterable[_T2], __iter3: Iterable[_T3] +) -> Iterator[tuple[_T1, _T2, _T3]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], +) -> Iterator[tuple[_T1, _T2, _T3, _T4]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], + __iter5: Iterable[_T5], +) -> Iterator[tuple[_T1, _T2, _T3, _T4, _T5]]: ... +@overload +def zip_equal( + __iter1: Iterable[Any], + __iter2: Iterable[Any], + __iter3: Iterable[Any], + __iter4: Iterable[Any], + __iter5: Iterable[Any], + __iter6: Iterable[Any], + *iterables: Iterable[Any], +) -> Iterator[tuple[Any, ...]]: ... + +# zip_offset +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[tuple[_T1 | None]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[tuple[_T1 | None, _T2 | None]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[tuple[_T | None, ...]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[tuple[_T1 | _U]]: ... 
+@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[tuple[_T | _U, ...]]: ... +def sort_together( + iterables: Iterable[Iterable[_T]], + key_list: Iterable[int] = ..., + key: Callable[..., Any] | None = ..., + reverse: bool = ..., + strict: bool = ..., +) -> list[tuple[_T, ...]]: ... +def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ... +def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ... +def always_iterable( + obj: object, + base_type: _ClassInfo | None = ..., +) -> Iterator[Any]: ... +def adjacent( + predicate: Callable[[_T], bool], + iterable: Iterable[_T], + distance: int = ..., +) -> Iterator[tuple[bool, _T]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None = None, + valuefunc: None = None, + reducefunc: None = None, +) -> Iterator[tuple[_T, Iterator[_T]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None, + reducefunc: None, +) -> Iterator[tuple[_U, Iterator[_T]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: Callable[[_T], _V], + reducefunc: None, +) -> Iterator[tuple[_T, Iterator[_V]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: None, +) -> Iterator[tuple[_U, Iterator[_V]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: None, + reducefunc: Callable[[Iterator[_T]], _W], +) -> Iterator[tuple[_T, _W]]: ... 
+@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None, + reducefunc: Callable[[Iterator[_T]], _W], +) -> Iterator[tuple[_U, _W]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[Iterator[_V]], _W], +) -> Iterator[tuple[_T, _W]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[Iterator[_V]], _W], +) -> Iterator[tuple[_U, _W]]: ... + +class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]): + @overload + def __init__(self, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ... + def __bool__(self) -> bool: ... + def __contains__(self, elem: object) -> bool: ... + def __eq__(self, other: object) -> bool: ... + @overload + def __getitem__(self, key: int) -> _T: ... + @overload + def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ... + def __hash__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def __reduce__( + self, + ) -> tuple[type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ... + def __repr__(self) -> str: ... + def __reversed__(self) -> Iterator[_T]: ... + def count(self, value: _T) -> int: ... + def index(self, value: _T) -> int: ... # type: ignore + +def count_cycle( + iterable: Iterable[_T], n: int | None = ... +) -> Iterable[tuple[int, _T]]: ... +def mark_ends( + iterable: Iterable[_T], +) -> Iterable[tuple[bool, bool, _T]]: ... +def locate( + iterable: Iterable[_T], + pred: Callable[..., Any] = ..., + window_size: int | None = ..., +) -> Iterator[int]: ... +def lstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... 
+def rstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def strip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... + +class islice_extended(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... + def __getitem__(self, index: slice) -> islice_extended[_T]: ... + +def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ... +def consecutive_groups( + iterable: Iterable[_T], ordering: None | Callable[[_T], int] = ... +) -> Iterator[Iterator[_T]]: ... +@overload +def difference( + iterable: Iterable[_T], + func: Callable[[_T, _T], _U] = ..., + *, + initial: None = ..., +) -> Iterator[_T | _U]: ... +@overload +def difference( + iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U +) -> Iterator[_U]: ... + +class SequenceView(Generic[_T], Sequence[_T]): + def __init__(self, target: Sequence[_T]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> Sequence[_T]: ... + def __len__(self) -> int: ... + +class seekable(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], maxlen: int | None = ... + ) -> None: ... + def __iter__(self) -> seekable[_T]: ... + def __next__(self) -> _T: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> _T | _U: ... + def elements(self) -> SequenceView[_T]: ... + def seek(self, index: int) -> None: ... + def relative_seek(self, count: int) -> None: ... + +class run_length: + @staticmethod + def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ... + @staticmethod + def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ... + +def exactly_n( + iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... +) -> bool: ... 
+def circular_shifts( + iterable: Iterable[_T], steps: int = 1 +) -> list[tuple[_T, ...]]: ... +def make_decorator( + wrapping_func: Callable[..., _U], result_index: int = ... +) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: None = ..., +) -> dict[_U, list[_T]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: None = ..., +) -> dict[_U, list[_V]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: Callable[[list[_T]], _W] = ..., +) -> dict[_U, _W]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[list[_V]], _W], +) -> dict[_U, _W]: ... +def rlocate( + iterable: Iterable[_T], + pred: Callable[..., object] = ..., + window_size: int | None = ..., +) -> Iterator[int]: ... +def replace( + iterable: Iterable[_T], + pred: Callable[..., object], + substitutes: Iterable[_U], + count: int | None = ..., + window_size: int = ..., +) -> Iterator[_T | _U]: ... +def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ... +def set_partitions( + iterable: Iterable[_T], + k: int | None = ..., + min_size: int | None = ..., + max_size: int | None = ..., +) -> Iterator[list[list[_T]]]: ... + +class time_limited(Generic[_T], Iterator[_T]): + def __init__( + self, limit_seconds: float, iterable: Iterable[_T] + ) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... + +@overload +def only( + iterable: Iterable[_T], *, too_long: _Raisable | None = ... +) -> _T | None: ... +@overload +def only( + iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ... +) -> _T | _U: ... 
+def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ... +def distinct_combinations( + iterable: Iterable[_T], r: int +) -> Iterator[tuple[_T, ...]]: ... +def filter_except( + validator: Callable[[Any], object], + iterable: Iterable[_T], + *exceptions: type[BaseException], +) -> Iterator[_T]: ... +def map_except( + function: Callable[[Any], _U], + iterable: Iterable[_T], + *exceptions: type[BaseException], +) -> Iterator[_U]: ... +def map_if( + iterable: Iterable[Any], + pred: Callable[[Any], bool], + func: Callable[[Any], Any], + func_else: Callable[[Any], Any] | None = ..., +) -> Iterator[Any]: ... +def _sample_unweighted( + iterator: Iterator[_T], k: int, strict: bool +) -> list[_T]: ... +def _sample_counted( + population: Iterator[_T], k: int, counts: Iterable[int], strict: bool +) -> list[_T]: ... +def _sample_weighted( + iterator: Iterator[_T], k: int, weights: Iterator[float], strict: bool +) -> list[_T]: ... +def sample( + iterable: Iterable[_T], + k: int, + weights: Iterable[float] | None = ..., + *, + counts: Iterable[int] | None = ..., + strict: bool = False, +) -> list[_T]: ... +def is_sorted( + iterable: Iterable[_T], + key: Callable[[_T], _U] | None = ..., + reverse: bool = False, + strict: bool = False, +) -> bool: ... + +class AbortThread(BaseException): + pass + +class callback_iter(Generic[_T], Iterator[_T]): + def __init__( + self, + func: Callable[..., Any], + callback_kwd: str = ..., + wait_seconds: float = ..., + ) -> None: ... + def __enter__(self) -> callback_iter[_T]: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: ... + def __iter__(self) -> callback_iter[_T]: ... + def __next__(self) -> _T: ... + def _reader(self) -> Iterator[_T]: ... + @property + def done(self) -> bool: ... + @property + def result(self) -> Any: ... 
+ +def windowed_complete( + iterable: Iterable[_T], n: int +) -> Iterator[tuple[tuple[_T, ...], tuple[_T, ...], tuple[_T, ...]]]: ... +def all_unique( + iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... +) -> bool: ... +def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ... +def nth_combination_with_replacement( + iterable: Iterable[_T], r: int, index: int +) -> tuple[_T, ...]: ... +def nth_permutation( + iterable: Iterable[_T], r: int, index: int +) -> tuple[_T, ...]: ... +def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ... +def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ... +def combination_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def combination_with_replacement_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def permutation_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ... + +class countable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> countable[_T]: ... + def __next__(self) -> _T: ... + items_seen: int + +def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + __obj3: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... 
# --- Stubs: remaining zip_broadcast overloads, duplicate detection, minmax. ---
@overload
def zip_broadcast(
    __obj1: _T | Iterable[_T],
    __obj2: _T | Iterable[_T],
    __obj3: _T | Iterable[_T],
    __obj4: _T | Iterable[_T],
    *,
    scalar_types: _ClassInfo | None = ...,
    strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
    __obj1: _T | Iterable[_T],
    __obj2: _T | Iterable[_T],
    __obj3: _T | Iterable[_T],
    __obj4: _T | Iterable[_T],
    __obj5: _T | Iterable[_T],
    *,
    scalar_types: _ClassInfo | None = ...,
    strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
# Catch-all overload for six or more positional arguments.
@overload
def zip_broadcast(
    __obj1: _T | Iterable[_T],
    __obj2: _T | Iterable[_T],
    __obj3: _T | Iterable[_T],
    __obj4: _T | Iterable[_T],
    __obj5: _T | Iterable[_T],
    __obj6: _T | Iterable[_T],
    *objects: _T | Iterable[_T],
    scalar_types: _ClassInfo | None = ...,
    strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
def unique_in_window(
    iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_everseen(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_justseen(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def classify_unique(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[tuple[_T, bool, bool]]: ...

# Structural protocol used to type minmax's comparable-element overloads.
class _SupportsLessThan(Protocol):
    def __lt__(self, __other: Any) -> bool: ...

_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)

# minmax: overloads for (iterable), (iterable, key=...), and defaulted forms.
@overload
def minmax(
    iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
    iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
) -> tuple[_T, _T]: ...
@overload
def minmax(
    iterable_or_value: Iterable[_SupportsLessThanT],
    *,
    key: None = None,
    default: _U,
) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
# --- Stubs: minmax varargs overloads plus the 4.x-era utility additions. ---
@overload
def minmax(
    iterable_or_value: Iterable[_T],
    *,
    key: Callable[[_T], _SupportsLessThan],
    default: _U,
) -> _U | tuple[_T, _T]: ...
@overload
def minmax(
    iterable_or_value: _SupportsLessThanT,
    __other: _SupportsLessThanT,
    *others: _SupportsLessThanT,
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
    iterable_or_value: _T,
    __other: _T,
    *others: _T,
    key: Callable[[_T], _SupportsLessThan],
) -> tuple[_T, _T]: ...
def longest_common_prefix(
    iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[Any]) -> bool: ...
def constrained_batches(
    iterable: Iterable[_T],
    max_size: int,
    max_count: int | None = ...,
    get_len: Callable[[_T], object] = ...,
    strict: bool = ...,
) -> Iterator[tuple[_T]]: ...
def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
def partial_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
def takewhile_inclusive(
    predicate: Callable[[_T], bool], iterable: Iterable[_T]
) -> Iterator[_T]: ...
def outer_product(
    func: Callable[[_T, _U], _V],
    xs: Iterable[_T],
    ys: Iterable[_U],
    *args: Any,
    **kwargs: Any,
) -> Iterator[tuple[_V, ...]]: ...
def iter_suppress(
    iterable: Iterable[_T],
    *exceptions: type[BaseException],
) -> Iterator[_T]: ...
def filter_map(
    func: Callable[[_T], _V | None],
    iterable: Iterable[_T],
) -> Iterator[_V]: ...
def powerset_of_sets(iterable: Iterable[_T]) -> Iterator[set[_T]]: ...
def join_mappings(
    **field_to_map: Mapping[_T, _V],
) -> dict[_T, dict[str, _V]]: ...
def doublestarmap(
    func: Callable[..., _T],
    iterable: Iterable[Mapping[str, Any]],
) -> Iterator[_T]: ...

# Discrete Fourier transform pair and prime utilities.
def dft(xarr: Sequence[complex]) -> Iterator[complex]: ...
def idft(Xarr: Sequence[complex]) -> Iterator[complex]: ...
def _nth_prime_ub(n: int) -> float: ...
def nth_prime(n: int, *, approximate: bool = ...) -> int: ...
+def argmin( + iterable: Iterable[_T], *, key: Callable[[_T], _U] | None = ... +) -> int: ... +def argmax( + iterable: Iterable[_T], *, key: Callable[[_T], _U] | None = ... +) -> int: ... +def extract( + iterable: Iterable[_T], indices: Iterable[int] +) -> Iterator[_T]: ... diff --git a/lib/more_itertools/py.typed b/lib/more_itertools/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/more_itertools/recipes.py b/lib/more_itertools/recipes.py new file mode 100644 index 0000000..dacf614 --- /dev/null +++ b/lib/more_itertools/recipes.py @@ -0,0 +1,1471 @@ +"""Imported from the recipes section of the itertools documentation. + +All functions taken from the recipes section of the itertools library docs +[1]_. +Some backward-compatible usability improvements have been made. + +.. [1] http://docs.python.org/library/itertools.html#recipes + +""" + +import random + +from bisect import bisect_left, insort +from collections import deque +from contextlib import suppress +from functools import lru_cache, partial, reduce +from heapq import heappush, heappushpop +from itertools import ( + accumulate, + chain, + combinations, + compress, + count, + cycle, + groupby, + islice, + product, + repeat, + starmap, + takewhile, + tee, + zip_longest, +) +from math import prod, comb, isqrt, gcd +from operator import mul, not_, itemgetter, getitem, index +from random import randrange, sample, choice +from sys import hexversion + +__all__ = [ + 'all_equal', + 'batched', + 'before_and_after', + 'consume', + 'convolve', + 'dotproduct', + 'first_true', + 'factor', + 'flatten', + 'grouper', + 'is_prime', + 'iter_except', + 'iter_index', + 'loops', + 'matmul', + 'multinomial', + 'ncycles', + 'nth', + 'nth_combination', + 'padnone', + 'pad_none', + 'pairwise', + 'partition', + 'polynomial_eval', + 'polynomial_from_roots', + 'polynomial_derivative', + 'powerset', + 'prepend', + 'quantify', + 'reshape', + 'random_combination_with_replacement', + 'random_combination', + 
'random_permutation', + 'random_product', + 'repeatfunc', + 'roundrobin', + 'running_median', + 'sieve', + 'sliding_window', + 'subslices', + 'sum_of_squares', + 'tabulate', + 'tail', + 'take', + 'totient', + 'transpose', + 'triplewise', + 'unique', + 'unique_everseen', + 'unique_justseen', +] + +_marker = object() + + +# zip with strict is available for Python 3.10+ +try: + zip(strict=True) +except TypeError: # pragma: no cover + _zip_strict = zip +else: # pragma: no cover + _zip_strict = partial(zip, strict=True) + + +# math.sumprod is available for Python 3.12+ +try: + from math import sumprod as _sumprod +except ImportError: # pragma: no cover + _sumprod = lambda x, y: dotproduct(x, y) + + +# heapq max-heap functions are available for Python 3.14+ +try: + from heapq import heappush_max, heappushpop_max +except ImportError: # pragma: no cover + _max_heap_available = False +else: # pragma: no cover + _max_heap_available = True + + +def take(n, iterable): + """Return first *n* items of the *iterable* as a list. + + >>> take(3, range(10)) + [0, 1, 2] + + If there are fewer than *n* items in the iterable, all of them are + returned. + + >>> take(10, range(3)) + [0, 1, 2] + + """ + return list(islice(iterable, n)) + + +def tabulate(function, start=0): + """Return an iterator over the results of ``func(start)``, + ``func(start + 1)``, ``func(start + 2)``... + + *func* should be a function that accepts one integer argument. + + If *start* is not specified it defaults to 0. It will be incremented each + time the iterator is advanced. + + >>> square = lambda x: x ** 2 + >>> iterator = tabulate(square, -3) + >>> take(4, iterator) + [9, 4, 1, 0] + + """ + return map(function, count(start)) + + +def tail(n, iterable): + """Return an iterator over the last *n* items of *iterable*. 
+ + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] + + """ + try: + size = len(iterable) + except TypeError: + return iter(deque(iterable, maxlen=n)) + else: + return islice(iterable, max(0, size - n), None) + + +def consume(iterator, n=None): + """Advance *iterable* by *n* steps. If *n* is ``None``, consume it + entirely. + + Efficiently exhausts an iterator without returning values. Defaults to + consuming the whole iterator, but an optional second argument may be + provided to limit consumption. + + >>> i = (x for x in range(10)) + >>> next(i) + 0 + >>> consume(i, 3) + >>> next(i) + 4 + >>> consume(i) + >>> next(i) + Traceback (most recent call last): + File "", line 1, in + StopIteration + + If the iterator has fewer items remaining than the provided limit, the + whole iterator will be consumed. + + >>> i = (x for x in range(3)) + >>> consume(i, 5) + >>> next(i) + Traceback (most recent call last): + File "", line 1, in + StopIteration + + """ + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + + +def nth(iterable, n, default=None): + """Returns the nth item or a default value. + + >>> l = range(10) + >>> nth(l, 3) + 3 + >>> nth(l, 20, "zebra") + 'zebra' + + """ + return next(islice(iterable, n, None), default) + + +def all_equal(iterable, key=None): + """ + Returns ``True`` if all the elements are equal to each other. 
+ + >>> all_equal('aaaa') + True + >>> all_equal('aaab') + False + + A function that accepts a single argument and returns a transformed version + of each input item can be specified with *key*: + + >>> all_equal('AaaA', key=str.casefold) + True + >>> all_equal([1, 2, 3], key=lambda x: x < 10) + True + + """ + iterator = groupby(iterable, key) + for first in iterator: + for second in iterator: + return False + return True + return True + + +def quantify(iterable, pred=bool): + """Return the how many times the predicate is true. + + >>> quantify([True, False, True]) + 2 + + """ + return sum(map(pred, iterable)) + + +def pad_none(iterable): + """Returns the sequence of elements and then returns ``None`` indefinitely. + + >>> take(5, pad_none(range(3))) + [0, 1, 2, None, None] + + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. + + """ + return chain(iterable, repeat(None)) + + +padnone = pad_none + + +def ncycles(iterable, n): + """Returns the sequence elements *n* times + + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] + + """ + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def dotproduct(vec1, vec2): + """Returns the dot product of the two iterables. + + >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25]) + 33.5 + >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25 + 33.5 + + In Python 3.12 and later, use ``math.sumprod()`` instead. + """ + return sum(map(mul, vec1, vec2)) + + +def flatten(listOfLists): + """Return an iterator flattening one level of nesting in a list of lists. + + >>> list(flatten([[0, 1], [2, 3]])) + [0, 1, 2, 3] + + See also :func:`collapse`, which can flatten multiple levels of nesting. + + """ + return chain.from_iterable(listOfLists) + + +def repeatfunc(func, times=None, *args): + """Call *func* with *args* repeatedly, returning an iterable over the + results. 
+ + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] + + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def _pairwise(iterable): + """Returns an iterator of paired items, overlapping, from the original + + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. + + """ + a, b = tee(iterable) + next(b, None) + return zip(a, b) + + +try: + from itertools import pairwise as itertools_pairwise +except ImportError: # pragma: no cover + pairwise = _pairwise +else: # pragma: no cover + + def pairwise(iterable): + return itertools_pairwise(iterable) + + pairwise.__doc__ = _pairwise.__doc__ + + +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += (': index 0 has length {}; index {} has length {}').format( + *details + ) + + super().__init__(msg) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def _zip_equal(*iterables): + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + raise UnequalIterablesError(details=(first_size, i, size)) + # All sizes are equal, we can use the built-in zip. 
+ return zip(*iterables) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def grouper(iterable, n, incomplete='fill', fillvalue=None): + """Group elements from *iterable* into fixed-length groups of length *n*. + + >>> list(grouper('ABCDEF', 3)) + [('A', 'B', 'C'), ('D', 'E', 'F')] + + The keyword arguments *incomplete* and *fillvalue* control what happens for + iterables whose length is not a multiple of *n*. + + When *incomplete* is `'fill'`, the last group will contain instances of + *fillvalue*. + + >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + When *incomplete* is `'ignore'`, the last group will not be emitted. + + >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x')) + [('A', 'B', 'C'), ('D', 'E', 'F')] + + When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised. + + >>> iterator = grouper('ABCDEFG', 3, incomplete='strict') + >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + UnequalIterablesError + + """ + iterators = [iter(iterable)] * n + if incomplete == 'fill': + return zip_longest(*iterators, fillvalue=fillvalue) + if incomplete == 'strict': + return _zip_equal(*iterators) + if incomplete == 'ignore': + return zip(*iterators) + else: + raise ValueError('Expected fill, strict, or ignore') + + +def roundrobin(*iterables): + """Visit input iterables in a cycle until each is exhausted. + + >>> list(roundrobin('ABC', 'D', 'EF')) + ['A', 'D', 'E', 'B', 'F', 'C'] + + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). 
+ + """ + # Algorithm credited to George Sakkis + iterators = map(iter, iterables) + for num_active in range(len(iterables), 0, -1): + iterators = cycle(islice(iterators, num_active)) + yield from map(next, iterators) + + +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. + The second yields the items that have ``pred(item) == True``. + + >>> is_odd = lambda x: x % 2 != 0 + >>> iterable = range(10) + >>> even_items, odd_items = partition(is_odd, iterable) + >>> list(even_items), list(odd_items) + ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + + If *pred* is None, :func:`bool` is used. + + >>> iterable = [0, 1, False, True, '', ' '] + >>> false_items, true_items = partition(None, iterable) + >>> list(false_items), list(true_items) + ([0, False, ''], [1, True, ' ']) + + """ + if pred is None: + pred = bool + + t1, t2, p = tee(iterable, 3) + p1, p2 = tee(map(pred, p)) + return (compress(t1, map(not_, p1)), compress(t2, p2)) + + +def powerset(iterable): + """Yields all possible subsets of the iterable. + + >>> list(powerset([1, 2, 3])) + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + + :func:`powerset` will operate on iterables that aren't :class:`set` + instances, so repeated elements in the input will produce repeated elements + in the output. + + >>> seq = [1, 1, 0] + >>> list(powerset(seq)) + [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)] + + For a variant that efficiently yields actual :class:`set` instances, see + :func:`powerset_of_sets`. + """ + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def unique_everseen(iterable, key=None): + """ + Yield unique elements, preserving order. 
+ + >>> list(unique_everseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D'] + >>> list(unique_everseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'D'] + + Sequences with a mix of hashable and unhashable items can be used. + The function will be slower (i.e., `O(n^2)`) for unhashable items. + + Remember that ``list`` objects are unhashable - you can use the *key* + parameter to transform the list to a tuple (which is hashable) to + avoid a slowdown. + + >>> iterable = ([1, 2], [2, 3], [1, 2]) + >>> list(unique_everseen(iterable)) # Slow + [[1, 2], [2, 3]] + >>> list(unique_everseen(iterable, key=tuple)) # Faster + [[1, 2], [2, 3]] + + Similarly, you may want to convert unhashable ``set`` objects with + ``key=frozenset``. For ``dict`` objects, + ``key=lambda x: frozenset(x.items())`` can be used. + + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + use_key = key is not None + + for element in iterable: + k = key(element) if use_key else element + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element + + +def unique_justseen(iterable, key=None): + """Yields elements in order, ignoring serial duplicates + + >>> list(unique_justseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D', 'A', 'B'] + >>> list(unique_justseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'A', 'D'] + + """ + if key is None: + return map(itemgetter(0), groupby(iterable)) + + return map(next, map(itemgetter(1), groupby(iterable, key))) + + +def unique(iterable, key=None, reverse=False): + """Yields unique elements in sorted order. + + >>> list(unique([[1, 2], [3, 4], [1, 2]])) + [[1, 2], [3, 4]] + + *key* and *reverse* are passed to :func:`sorted`. 
+ + >>> list(unique('ABBcCAD', str.casefold)) + ['A', 'B', 'c', 'D'] + >>> list(unique('ABBcCAD', str.casefold, reverse=True)) + ['D', 'c', 'B', 'A'] + + The elements in *iterable* need not be hashable, but they must be + comparable for sorting to work. + """ + sequenced = sorted(iterable, key=key, reverse=reverse) + return unique_justseen(sequenced, key=key) + + +def iter_except(func, exception, first=None): + """Yields results from a function repeatedly until an exception is raised. + + Converts a call-until-exception interface to an iterator interface. + Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel + to end the loop. + + >>> l = [0, 1, 2] + >>> list(iter_except(l.pop, IndexError)) + [2, 1, 0] + + Multiple exceptions can be specified as a stopping condition: + + >>> l = [1, 2, 3, '...', 4, 5, 6] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [7, 6, 5] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [4, 3, 2] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [] + + """ + with suppress(exception): + if first is not None: + yield first() + while True: + yield func() + + +def first_true(iterable, default=None, pred=None): + """ + Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item for which + ``pred(item) == True`` . + + >>> first_true(range(10)) + 1 + >>> first_true(range(10), pred=lambda x: x > 5) + 6 + >>> first_true(range(10), default='missing', pred=lambda x: x > 9) + 'missing' + + """ + return next(filter(pred, iterable), default) + + +def random_product(*args, repeat=1): + """Draw an item at random from each of the input iterables. + + >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP + ('c', 3, 'Z') + + If *repeat* is provided as a keyword argument, that many items will be + drawn from each iterable. 
+ + >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP + ('a', 2, 'd', 3) + + This equivalent to taking a random selection from + ``itertools.product(*args, repeat=repeat)``. + + """ + pools = [tuple(pool) for pool in args] * repeat + return tuple(choice(pool) for pool in pools) + + +def random_permutation(iterable, r=None): + """Return a random *r* length permutation of the elements in *iterable*. + + If *r* is not specified or is ``None``, then *r* defaults to the length of + *iterable*. + + >>> random_permutation(range(5)) # doctest:+SKIP + (3, 4, 0, 1, 2) + + This equivalent to taking a random selection from + ``itertools.permutations(iterable, r)``. + + """ + pool = tuple(iterable) + r = len(pool) if r is None else r + return tuple(sample(pool, r)) + + +def random_combination(iterable, r): + """Return a random *r* length subsequence of the elements in *iterable*. + + >>> random_combination(range(5), 3) # doctest:+SKIP + (2, 3, 4) + + This equivalent to taking a random selection from + ``itertools.combinations(iterable, r)``. + + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(sample(range(n), r)) + return tuple(pool[i] for i in indices) + + +def random_combination_with_replacement(iterable, r): + """Return a random *r* length subsequence of elements in *iterable*, + allowing individual elements to be repeated. + + >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP + (0, 0, 1, 2, 2) + + This equivalent to taking a random selection from + ``itertools.combinations_with_replacement(iterable, r)``. + + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(randrange(n) for i in range(r)) + return tuple(pool[i] for i in indices) + + +def nth_combination(iterable, r, index): + """Equivalent to ``list(combinations(iterable, r))[index]``. + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. 
:func:`nth_combination` computes the subsequence at + sort position *index* directly, without computing the previous + subsequences. + + >>> nth_combination(range(5), 3, 5) + (0, 3, 4) + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = 1 + k = min(r, n - r) + for i in range(1, k + 1): + c = c * (n - k + i) // i + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + + return tuple(result) + + +def prepend(value, iterator): + """Yield *value*, followed by the elements in *iterator*. + + >>> value = '0' + >>> iterator = ['1', '2', '3'] + >>> list(prepend(value, iterator)) + ['0', '1', '2', '3'] + + To prepend multiple values, see :func:`itertools.chain` + or :func:`value_chain`. + + """ + return chain([value], iterator) + + +def convolve(signal, kernel): + """Discrete linear convolution of two iterables. + Equivalent to polynomial multiplication. + + For example, multiplying ``(x² -x - 20)`` by ``(x - 3)`` + gives ``(x³ -4x² -17x + 60)``. + + >>> list(convolve([1, -1, -20], [1, -3])) + [1, -4, -17, 60] + + Examples of popular kinds of kernels: + + * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average. + For image data, this blurs the image and reduces noise. + * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of + a function evaluated at evenly spaced inputs. + * The kernel ``[1, -2, 1]`` estimates the second derivative of a + function evaluated at evenly spaced inputs. + + Convolutions are mathematically commutative; however, the inputs are + evaluated differently. The signal is consumed lazily and can be + infinite. 
The kernel is fully consumed before the calculations begin. + + Supports all numeric types: int, float, complex, Decimal, Fraction. + + References: + + * Article: https://betterexplained.com/articles/intuitive-convolution/ + * Video by 3Blue1Brown: https://www.youtube.com/watch?v=KuXjwB4LzSA + + """ + # This implementation comes from an older version of the itertools + # documentation. While the newer implementation is a bit clearer, + # this one was kept because the inlined window logic is faster + # and it avoids an unnecessary deque-to-tuple conversion. + kernel = tuple(kernel)[::-1] + n = len(kernel) + window = deque([0], maxlen=n) * n + for x in chain(signal, repeat(0, n - 1)): + window.append(x) + yield _sumprod(kernel, window) + + +def before_and_after(predicate, it): + """A variant of :func:`takewhile` that allows complete access to the + remainder of the iterator. + + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = before_and_after(str.isupper, it) + >>> ''.join(all_upper) + 'ABC' + >>> ''.join(remainder) # takewhile() would lose the 'd' + 'dEfGhI' + + Note that the first iterator must be fully consumed before the second + iterator can generate valid results. + """ + trues, after = tee(it) + trues = compress(takewhile(predicate, trues), zip(after)) + return trues, after + + +def triplewise(iterable): + """Return overlapping triplets from *iterable*. + + >>> list(triplewise('ABCDE')) + [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] + + """ + # This deviates from the itertools documentation recipe - see + # https://github.com/more-itertools/more-itertools/issues/889 + t1, t2, t3 = tee(iterable, 3) + next(t3, None) + next(t3, None) + next(t2, None) + return zip(t1, t2, t3) + + +def _sliding_window_islice(iterable, n): + # Fast path for small, non-zero values of n. 
+ iterators = tee(iterable, n) + for i, iterator in enumerate(iterators): + next(islice(iterator, i, i), None) + return zip(*iterators) + + +def _sliding_window_deque(iterable, n): + # Normal path for other values of n. + iterator = iter(iterable) + window = deque(islice(iterator, n - 1), maxlen=n) + for x in iterator: + window.append(x) + yield tuple(window) + + +def sliding_window(iterable, n): + """Return a sliding window of width *n* over *iterable*. + + >>> list(sliding_window(range(6), 4)) + [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)] + + If *iterable* has fewer than *n* items, then nothing is yielded: + + >>> list(sliding_window(range(3), 4)) + [] + + For a variant with more features, see :func:`windowed`. + """ + if n > 20: + return _sliding_window_deque(iterable, n) + elif n > 2: + return _sliding_window_islice(iterable, n) + elif n == 2: + return pairwise(iterable) + elif n == 1: + return zip(iterable) + else: + raise ValueError(f'n should be at least one, not {n}') + + +def subslices(iterable): + """Return all contiguous non-empty subslices of *iterable*. + + >>> list(subslices('ABC')) + [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']] + + This is similar to :func:`substrings`, but emits items in a different + order. + """ + seq = list(iterable) + slices = starmap(slice, combinations(range(len(seq) + 1), 2)) + return map(getitem, repeat(seq), slices) + + +def polynomial_from_roots(roots): + """Compute a polynomial's coefficients from its roots. + + >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3) + >>> polynomial_from_roots(roots) # x³ - 4 x² - 17 x + 60 + [1, -4, -17, 60] + + Note that polynomial coefficients are specified in descending power order. + + Supports all numeric types: int, float, complex, Decimal, Fraction. + """ + + # This recipe differs from the one in itertools docs in that it + # applies list() after each call to convolve(). This avoids + # hitting stack limits with nested generators. 
+ + poly = [1] + for root in roots: + poly = list(convolve(poly, (1, -root))) + return poly + + +def iter_index(iterable, value, start=0, stop=None): + """Yield the index of each place in *iterable* that *value* occurs, + beginning with index *start* and ending before index *stop*. + + + >>> list(iter_index('AABCADEAF', 'A')) + [0, 1, 4, 7] + >>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive + [1, 4, 7] + >>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive + [1, 4] + + The behavior for non-scalar *values* matches the built-in Python types. + + >>> list(iter_index('ABCDABCD', 'AB')) + [0, 4] + >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1])) + [] + >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1])) + [0, 2] + + See :func:`locate` for a more general means of finding the indexes + associated with particular values. + + """ + seq_index = getattr(iterable, 'index', None) + if seq_index is None: + # Slow path for general iterables + iterator = islice(iterable, start, stop) + for i, element in enumerate(iterator, start): + if element is value or element == value: + yield i + else: + # Fast path for sequences + stop = len(iterable) if stop is None else stop + i = start - 1 + with suppress(ValueError): + while True: + yield (i := seq_index(value, i + 1, stop)) + + +def sieve(n): + """Yield the primes less than n. + + >>> list(sieve(30)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + """ + # This implementation comes from an older version of the itertools + # documentation. The newer implementation is easier to read but is + # less lazy. 
+ if n > 2: + yield 2 + start = 3 + data = bytearray((0, 1)) * (n // 2) + for p in iter_index(data, 1, start, stop=isqrt(n) + 1): + yield from iter_index(data, 1, start, p * p) + data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) + start = p * p + yield from iter_index(data, 1, start) + + +def _batched(iterable, n, *, strict=False): # pragma: no cover + """Batch data into tuples of length *n*. If the number of items in + *iterable* is not divisible by *n*: + * The last batch will be shorter if *strict* is ``False``. + * :exc:`ValueError` will be raised if *strict* is ``True``. + + >>> list(batched('ABCDEFG', 3)) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] + + On Python 3.13 and above, this is an alias for :func:`itertools.batched`. + """ + if n < 1: + raise ValueError('n must be at least one') + iterator = iter(iterable) + while batch := tuple(islice(iterator, n)): + if strict and len(batch) != n: + raise ValueError('batched(): incomplete batch') + yield batch + + +if hexversion >= 0x30D00A2: # pragma: no cover + from itertools import batched as itertools_batched + + def batched(iterable, n, *, strict=False): + return itertools_batched(iterable, n, strict=strict) + + batched.__doc__ = _batched.__doc__ +else: # pragma: no cover + batched = _batched + + +def transpose(it): + """Swap the rows and columns of the input matrix. + + >>> list(transpose([(1, 2, 3), (11, 22, 33)])) + [(1, 11), (2, 22), (3, 33)] + + The caller should ensure that the dimensions of the input are compatible. + If the input is empty, no output will be produced. + """ + return _zip_strict(*it) + + +def _is_scalar(value, stringlike=(str, bytes)): + "Scalars are bytes, strings, and non-iterables." + try: + iter(value) + except TypeError: + return True + return isinstance(value, stringlike) + + +def _flatten_tensor(tensor): + "Depth-first iterator over scalars in a tensor." 
+ iterator = iter(tensor) + while True: + try: + value = next(iterator) + except StopIteration: + return iterator + iterator = chain((value,), iterator) + if _is_scalar(value): + return iterator + iterator = chain.from_iterable(iterator) + + +def reshape(matrix, shape): + """Change the shape of a *matrix*. + + If *shape* is an integer, the matrix must be two dimensional + and the shape is interpreted as the desired number of columns: + + >>> matrix = [(0, 1), (2, 3), (4, 5)] + >>> cols = 3 + >>> list(reshape(matrix, cols)) + [(0, 1, 2), (3, 4, 5)] + + If *shape* is a tuple (or other iterable), the input matrix can have + any number of dimensions. It will first be flattened and then rebuilt + to the desired shape which can also be multidimensional: + + >>> matrix = [(0, 1), (2, 3), (4, 5)] # Start with a 3 x 2 matrix + + >>> list(reshape(matrix, (2, 3))) # Make a 2 x 3 matrix + [(0, 1, 2), (3, 4, 5)] + + >>> list(reshape(matrix, (6,))) # Make a vector of length six + [0, 1, 2, 3, 4, 5] + + >>> list(reshape(matrix, (2, 1, 3, 1))) # Make 2 x 1 x 3 x 1 tensor + [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)] + + Each dimension is assumed to be uniform, either all arrays or all scalars. + Flattening stops when the first value in a dimension is a scalar. + Scalars are bytes, strings, and non-iterables. + The reshape iterator stops when the requested shape is complete + or when the input is exhausted, whichever comes first. + + """ + if isinstance(shape, int): + return batched(chain.from_iterable(matrix), shape) + first_dim, *dims = shape + scalar_stream = _flatten_tensor(matrix) + reshaped = reduce(batched, reversed(dims), scalar_stream) + return islice(reshaped, first_dim) + + +def matmul(m1, m2): + """Multiply two matrices. + + >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)])) + [(49, 80), (41, 60)] + + The caller should ensure that the dimensions of the input matrices are + compatible with each other. 
+ + Supports all numeric types: int, float, complex, Decimal, Fraction. + """ + n = len(m2[0]) + return batched(starmap(_sumprod, product(m1, transpose(m2))), n) + + +def _factor_pollard(n): + # Return a factor of n using Pollard's rho algorithm. + # Efficient when n is odd and composite. + for b in range(1, n): + x = y = 2 + d = 1 + while d == 1: + x = (x * x + b) % n + y = (y * y + b) % n + y = (y * y + b) % n + d = gcd(x - y, n) + if d != n: + return d + raise ValueError('prime or under 5') # pragma: no cover + + +_primes_below_211 = tuple(sieve(211)) + + +def factor(n): + """Yield the prime factors of n. + + >>> list(factor(360)) + [2, 2, 2, 3, 3, 5] + + Finds small factors with trial division. Larger factors are + either verified as prime with ``is_prime`` or split into + smaller factors with Pollard's rho algorithm. + """ + + # Corner case reduction + if n < 2: + return + + # Trial division reduction + for prime in _primes_below_211: + while not n % prime: + yield prime + n //= prime + + # Pollard's rho reduction + primes = [] + todo = [n] if n > 1 else [] + for n in todo: + if n < 211**2 or is_prime(n): + primes.append(n) + else: + fact = _factor_pollard(n) + todo += (fact, n // fact) + yield from sorted(primes) + + +def polynomial_eval(coefficients, x): + """Evaluate a polynomial at a specific value. + + Computes with better numeric stability than Horner's method. + + Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``: + + >>> coefficients = [1, -4, -17, 60] + >>> x = 2.5 + >>> polynomial_eval(coefficients, x) + 8.125 + + Note that polynomial coefficients are specified in descending power order. + + Supports all numeric types: int, float, complex, Decimal, Fraction. + """ + n = len(coefficients) + if n == 0: + return type(x)(0) + powers = map(pow, repeat(x), reversed(range(n))) + return _sumprod(coefficients, powers) + + +def sum_of_squares(it): + """Return the sum of the squares of the input values. 
+ + >>> sum_of_squares([10, 20, 30]) + 1400 + + Supports all numeric types: int, float, complex, Decimal, Fraction. + """ + return _sumprod(*tee(it)) + + +def polynomial_derivative(coefficients): + """Compute the first derivative of a polynomial. + + Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``: + + >>> coefficients = [1, -4, -17, 60] + >>> derivative_coefficients = polynomial_derivative(coefficients) + >>> derivative_coefficients + [3, -8, -17] + + Note that polynomial coefficients are specified in descending power order. + + Supports all numeric types: int, float, complex, Decimal, Fraction. + """ + n = len(coefficients) + powers = reversed(range(1, n)) + return list(map(mul, coefficients, powers)) + + +def totient(n): + """Return the count of natural numbers up to *n* that are coprime with *n*. + + Euler's totient function φ(n) gives the number of totatives. + Totative are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1. + + >>> n = 9 + >>> totient(n) + 6 + + >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1] + >>> totatives + [1, 2, 4, 5, 7, 8] + >>> len(totatives) + 6 + + Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function + + """ + for prime in set(factor(n)): + n -= n // prime + return n + + +# Miller–Rabin primality test: https://oeis.org/A014233 +_perfect_tests = [ + (2047, (2,)), + (9080191, (31, 73)), + (4759123141, (2, 7, 61)), + (1122004669633, (2, 13, 23, 1662803)), + (2152302898747, (2, 3, 5, 7, 11)), + (3474749660383, (2, 3, 5, 7, 11, 13)), + (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)), + ( + 3317044064679887385961981, + (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41), + ), +] + + +@lru_cache +def _shift_to_odd(n): + 'Return s, d such that 2**s * d == n' + s = ((n - 1) ^ n).bit_length() - 1 + d = n >> s + assert (1 << s) * d == n and d & 1 and s >= 0 + return s, d + + +def _strong_probable_prime(n, base): + assert (n > 2) and (n & 1) and (2 <= base < n) + + s, d = 
_shift_to_odd(n - 1) + + x = pow(base, d, n) + if x == 1 or x == n - 1: + return True + + for _ in range(s - 1): + x = x * x % n + if x == n - 1: + return True + + return False + + +# Separate instance of Random() that doesn't share state +# with the default user instance of Random(). +_private_randrange = random.Random().randrange + + +def is_prime(n): + """Return ``True`` if *n* is prime and ``False`` otherwise. + + Basic examples: + + >>> is_prime(37) + True + >>> is_prime(3 * 13) + False + >>> is_prime(18_446_744_073_709_551_557) + True + + Find the next prime over one billion: + + >>> next(filter(is_prime, count(10**9))) + 1000000007 + + Generate random primes up to 200 bits and up to 60 decimal digits: + + >>> from random import seed, randrange, getrandbits + >>> seed(18675309) + + >>> next(filter(is_prime, map(getrandbits, repeat(200)))) + 893303929355758292373272075469392561129886005037663238028407 + + >>> next(filter(is_prime, map(randrange, repeat(10**60)))) + 269638077304026462407872868003560484232362454342414618963649 + + This function is exact for values of *n* below 10**24. For larger inputs, + the probabilistic Miller-Rabin primality test has a less than 1 in 2**128 + chance of a false positive. + """ + + if n < 17: + return n in {2, 3, 5, 7, 11, 13} + + if not (n & 1 and n % 3 and n % 5 and n % 7 and n % 11 and n % 13): + return False + + for limit, bases in _perfect_tests: + if n < limit: + break + else: + bases = (_private_randrange(2, n - 1) for i in range(64)) + + return all(_strong_probable_prime(n, base) for base in bases) + + +def loops(n): + """Returns an iterable with *n* elements for efficient looping. + Like ``range(n)`` but doesn't create integers. + + >>> i = 0 + >>> for _ in loops(5): + ... i += 1 + >>> i + 5 + + """ + return repeat(None, n) + + +def multinomial(*counts): + """Number of distinct arrangements of a multiset. 
+ + The expression ``multinomial(3, 4, 2)`` has several equivalent + interpretations: + + * In the expansion of ``(a + b + c)⁹``, the coefficient of the + ``a³b⁴c²`` term is 1260. + + * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4 + greens, and 2 blues. + + * There are 1260 unique ways to place 9 distinct objects into three bins + with sizes 3, 4, and 2. + + The :func:`multinomial` function computes the length of + :func:`distinct_permutations`. For example, there are 83,160 distinct + anagrams of the word "abracadabra": + + >>> from more_itertools import distinct_permutations, ilen + >>> ilen(distinct_permutations('abracadabra')) + 83160 + + This can be computed directly from the letter counts, 5a 2b 2r 1c 1d: + + >>> from collections import Counter + >>> list(Counter('abracadabra').values()) + [5, 2, 2, 1, 1] + >>> multinomial(5, 2, 2, 1, 1) + 83160 + + A binomial coefficient is a special case of multinomial where there are + only two categories. For example, the number of ways to arrange 12 balls + with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``. + + Likewise, factorial is a special case of multinomial where + the multiplicities are all just 1 so that + ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``. + + Reference: https://en.wikipedia.org/wiki/Multinomial_theorem + + """ + return prod(map(comb, accumulate(counts), counts)) + + +def _running_median_minheap_and_maxheap(iterator): # pragma: no cover + "Non-windowed running_median() for Python 3.14+" + + read = iterator.__next__ + lo = [] # max-heap + hi = [] # min-heap (same size as or one smaller than lo) + + with suppress(StopIteration): + while True: + heappush_max(lo, heappushpop(hi, read())) + yield lo[0] + + heappush(hi, heappushpop_max(lo, read())) + yield (lo[0] + hi[0]) / 2 + + +def _running_median_minheap_only(iterator): # pragma: no cover + "Backport of non-windowed running_median() for Python 3.13 and prior." 
+ + read = iterator.__next__ + lo = [] # max-heap (actually a minheap with negated values) + hi = [] # min-heap (same size as or one smaller than lo) + + with suppress(StopIteration): + while True: + heappush(lo, -heappushpop(hi, read())) + yield -lo[0] + + heappush(hi, -heappushpop(lo, -read())) + yield (hi[0] - lo[0]) / 2 + + +def _running_median_windowed(iterator, maxlen): + "Yield median of values in a sliding window." + + window = deque() + ordered = [] + + for x in iterator: + window.append(x) + insort(ordered, x) + + if len(ordered) > maxlen: + i = bisect_left(ordered, window.popleft()) + del ordered[i] + + n = len(ordered) + m = n // 2 + yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2 + + +def running_median(iterable, *, maxlen=None): + """Cumulative median of values seen so far or values in a sliding window. + + Set *maxlen* to a positive integer to specify the maximum size + of the sliding window. The default of *None* is equivalent to + an unbounded window. + + For example: + + >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0])) + [5.0, 7.0, 5.0, 7.0, 8.0, 8.5] + >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3)) + [5.0, 7.0, 5.0, 9.0, 8.0, 9.0] + + Supports numeric types such as int, float, Decimal, and Fraction, + but not complex numbers which are unorderable. + + On version Python 3.13 and prior, max-heaps are simulated with + negative values. The negation causes Decimal inputs to apply context + rounding, making the results slightly different than that obtained + by statistics.median(). 
+ """ + + iterator = iter(iterable) + + if maxlen is not None: + maxlen = index(maxlen) + if maxlen <= 0: + raise ValueError('Window size should be positive') + return _running_median_windowed(iterator, maxlen) + + if not _max_heap_available: + return _running_median_minheap_only(iterator) # pragma: no cover + + return _running_median_minheap_and_maxheap(iterator) # pragma: no cover diff --git a/lib/more_itertools/recipes.pyi b/lib/more_itertools/recipes.pyi new file mode 100644 index 0000000..de3d0a1 --- /dev/null +++ b/lib/more_itertools/recipes.pyi @@ -0,0 +1,205 @@ +"""Stubs for more_itertools.recipes""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Sequence +from decimal import Decimal +from fractions import Fraction +from typing import ( + Any, + Callable, + TypeVar, + overload, +) + +__all__ = [ + 'all_equal', + 'batched', + 'before_and_after', + 'consume', + 'convolve', + 'dotproduct', + 'first_true', + 'factor', + 'flatten', + 'grouper', + 'is_prime', + 'iter_except', + 'iter_index', + 'loops', + 'matmul', + 'multinomial', + 'ncycles', + 'nth', + 'nth_combination', + 'padnone', + 'pad_none', + 'pairwise', + 'partition', + 'polynomial_eval', + 'polynomial_from_roots', + 'polynomial_derivative', + 'powerset', + 'prepend', + 'quantify', + 'reshape', + 'random_combination_with_replacement', + 'random_combination', + 'random_permutation', + 'random_product', + 'repeatfunc', + 'roundrobin', + 'running_median', + 'sieve', + 'sliding_window', + 'subslices', + 'sum_of_squares', + 'tabulate', + 'tail', + 'take', + 'totient', + 'transpose', + 'triplewise', + 'unique', + 'unique_everseen', + 'unique_justseen', +] + +# Type and type variable definitions +_T = TypeVar('_T') +_T1 = TypeVar('_T1') +_T2 = TypeVar('_T2') +_U = TypeVar('_U') +_NumberT = TypeVar("_NumberT", float, Decimal, Fraction) + +def take(n: int, iterable: Iterable[_T]) -> list[_T]: ... +def tabulate( + function: Callable[[int], _T], start: int = ... 
+) -> Iterator[_T]: ... +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ... +def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ... +@overload +def nth(iterable: Iterable[_T], n: int) -> _T | None: ... +@overload +def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ... +def all_equal( + iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... +) -> bool: ... +def quantify( + iterable: Iterable[_T], pred: Callable[[_T], bool] = ... +) -> int: ... +def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ... +def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ... +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... +def dotproduct(vec1: Iterable[_T1], vec2: Iterable[_T2]) -> Any: ... +def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... +def repeatfunc( + func: Callable[..., _U], times: int | None = ..., *args: Any +) -> Iterator[_U]: ... +def pairwise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T]]: ... +def grouper( + iterable: Iterable[_T], + n: int, + incomplete: str = ..., + fillvalue: _U = ..., +) -> Iterator[tuple[_T | _U, ...]]: ... +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def partition( + pred: Callable[[_T], object] | None, iterable: Iterable[_T] +) -> tuple[Iterator[_T], Iterator[_T]]: ... +def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ... +def unique_everseen( + iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... +) -> Iterator[_T]: ... +def unique_justseen( + iterable: Iterable[_T], key: Callable[[_T], object] | None = ... +) -> Iterator[_T]: ... +def unique( + iterable: Iterable[_T], + key: Callable[[_T], object] | None = ..., + reverse: bool = False, +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: type[BaseException] | tuple[type[BaseException], ...], + first: None = ..., +) -> Iterator[_T]: ... 
+@overload +def iter_except( + func: Callable[[], _T], + exception: type[BaseException] | tuple[type[BaseException], ...], + first: Callable[[], _U], +) -> Iterator[_T | _U]: ... +@overload +def first_true( + iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ... +) -> _T | None: ... +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Callable[[_T], object] | None = ..., +) -> _T | _U: ... +def random_product( + *args: Iterable[_T], repeat: int = ... +) -> tuple[_T, ...]: ... +def random_permutation( + iterable: Iterable[_T], r: int | None = ... +) -> tuple[_T, ...]: ... +def random_combination(iterable: Iterable[_T], r: int) -> tuple[_T, ...]: ... +def random_combination_with_replacement( + iterable: Iterable[_T], r: int +) -> tuple[_T, ...]: ... +def nth_combination( + iterable: Iterable[_T], r: int, index: int +) -> tuple[_T, ...]: ... +def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[_T | _U]: ... +def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ... +def before_and_after( + predicate: Callable[[_T], bool], it: Iterable[_T] +) -> tuple[Iterator[_T], Iterator[_T]]: ... +def triplewise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T, _T]]: ... +def sliding_window( + iterable: Iterable[_T], n: int +) -> Iterator[tuple[_T, ...]]: ... +def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ... +def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ... +def iter_index( + iterable: Iterable[_T], + value: Any, + start: int | None = ..., + stop: int | None = ..., +) -> Iterator[int]: ... +def sieve(n: int) -> Iterator[int]: ... +def _batched( + iterable: Iterable[_T], n: int, *, strict: bool = False +) -> Iterator[tuple[_T, ...]]: ... + +batched = _batched + +def transpose( + it: Iterable[Iterable[_T]], +) -> Iterator[tuple[_T, ...]]: ... +@overload +def reshape( + matrix: Iterable[Iterable[_T]], shape: int +) -> Iterator[tuple[_T, ...]]: ... 
+@overload +def reshape(matrix: Iterable[Any], shape: Iterable[int]) -> Iterator[Any]: ... +def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ... +def _factor_trial(n: int) -> Iterator[int]: ... +def _factor_pollard(n: int) -> int: ... +def factor(n: int) -> Iterator[int]: ... +def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ... +def sum_of_squares(it: Iterable[_T]) -> _T: ... +def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ... +def totient(n: int) -> int: ... +def _shift_to_odd(n: int) -> tuple[int, int]: ... +def _strong_probable_prime(n: int, base: int) -> bool: ... +def is_prime(n: int) -> bool: ... +def loops(n: int) -> Iterator[None]: ... +def multinomial(*counts: int) -> int: ... +def running_median( + iterable: Iterable[_NumberT], *, maxlen: int | None = ... +) -> Iterator[_NumberT]: ... diff --git a/lib/secretstorage-3.5.0.dist-info/INSTALLER b/lib/secretstorage-3.5.0.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/lib/secretstorage-3.5.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/lib/secretstorage-3.5.0.dist-info/METADATA b/lib/secretstorage-3.5.0.dist-info/METADATA new file mode 100644 index 0000000..2a2d24f --- /dev/null +++ b/lib/secretstorage-3.5.0.dist-info/METADATA @@ -0,0 +1,109 @@ +Metadata-Version: 2.4 +Name: SecretStorage +Version: 3.5.0 +Summary: Python bindings to FreeDesktop.org Secret Service API +Author-email: Dmitry Shachnev +License-Expression: BSD-3-Clause +Project-URL: Homepage, https://github.com/mitya57/secretstorage +Project-URL: Documentation, https://secretstorage.readthedocs.io/en/latest/ +Project-URL: Issue Tracker, https://github.com/mitya57/secretstorage/issues/ +Platform: Linux +Classifier: Development Status :: 5 - Production/Stable +Classifier: Operating System :: POSIX +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.10 
+Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3.14 +Classifier: Topic :: Security +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.10 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: cryptography>=2.0 +Requires-Dist: jeepney>=0.6 +Dynamic: license-file + +.. image:: https://github.com/mitya57/secretstorage/workflows/tests/badge.svg + :target: https://github.com/mitya57/secretstorage/actions + :alt: GitHub Actions status +.. image:: https://codecov.io/gh/mitya57/secretstorage/branch/master/graph/badge.svg + :target: https://codecov.io/gh/mitya57/secretstorage + :alt: Coverage status +.. image:: https://readthedocs.org/projects/secretstorage/badge/?version=latest + :target: https://secretstorage.readthedocs.io/en/latest/ + :alt: ReadTheDocs status + +Module description +================== + +This module provides a way for securely storing passwords and other secrets. + +It uses D-Bus `Secret Service`_ API that is supported by GNOME Keyring, +KWallet (since version 5.97) and KeePassXC. + +The main classes provided are ``secretstorage.Item``, representing a secret +item (that has a *label*, a *secret* and some *attributes*) and +``secretstorage.Collection``, a place items are stored in. + +SecretStorage supports most of the functions provided by Secret Service, +including creating and deleting items and collections, editing items, +locking and unlocking collections. + +The documentation can be found on `secretstorage.readthedocs.io`_. + +.. _`Secret Service`: https://specifications.freedesktop.org/secret-service/ +.. 
_`secretstorage.readthedocs.io`: https://secretstorage.readthedocs.io/en/latest/ + +Building the module +=================== + +SecretStorage requires Python ≥ 3.10 and these packages to work: + +* Jeepney_ +* `python-cryptography`_ + +To build SecretStorage, use this command:: + + python3 -m build + +If you have Sphinx_ installed, you can also build the documentation:: + + python3 -m sphinx docs build/sphinx/html + +.. _Jeepney: https://pypi.org/project/jeepney/ +.. _`python-cryptography`: https://pypi.org/project/cryptography/ +.. _Sphinx: https://www.sphinx-doc.org/en/master/ + +Testing the module +================== + +First, make sure that you have the Secret Service daemon installed. +The `GNOME Keyring`_ is the reference server-side implementation for the +Secret Service specification. + +.. _`GNOME Keyring`: https://download.gnome.org/sources/gnome-keyring/ + +Then, start the daemon and unlock the ``default`` collection, if needed. +The testsuite will fail to run if the ``default`` collection exists and is +locked. If it does not exist, the testsuite can also use the temporary +``session`` collection, as provided by the GNOME Keyring. + +Then, run the Python unittest module:: + + python3 -m unittest discover -s tests + +If you want to run the tests in an isolated or headless environment, run +this command in a D-Bus session:: + + dbus-run-session -- python3 -m unittest discover -s tests + +Get the code +============ + +SecretStorage is available under BSD license. The source code can be found +on GitHub_. + +.. 
_GitHub: https://github.com/mitya57/secretstorage diff --git a/lib/secretstorage-3.5.0.dist-info/RECORD b/lib/secretstorage-3.5.0.dist-info/RECORD new file mode 100644 index 0000000..f1c9798 --- /dev/null +++ b/lib/secretstorage-3.5.0.dist-info/RECORD @@ -0,0 +1,21 @@ +secretstorage-3.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +secretstorage-3.5.0.dist-info/METADATA,sha256=uY49xC_36bG6zOYNrCZizZll-tbkRp_JSDyeuZmYw2c,3974 +secretstorage-3.5.0.dist-info/RECORD,, +secretstorage-3.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 +secretstorage-3.5.0.dist-info/licenses/LICENSE,sha256=5wA1dZK7k3BMsEDOSmZjqFKScWpAQbaiZ_WY8hZYRm8,1504 +secretstorage-3.5.0.dist-info/top_level.txt,sha256=hveSi1OWGaEt3kEVbjmZ0M-ASPxi6y-nTPVa-d3c0B4,14 +secretstorage/__init__.py,sha256=_Y34co2i9G-0ljfDUZE2CdLMYa1HV8jQRc6Xtrf-dEQ,3402 +secretstorage/__pycache__/__init__.cpython-314.pyc,, +secretstorage/__pycache__/collection.cpython-314.pyc,, +secretstorage/__pycache__/defines.cpython-314.pyc,, +secretstorage/__pycache__/dhcrypto.cpython-314.pyc,, +secretstorage/__pycache__/exceptions.cpython-314.pyc,, +secretstorage/__pycache__/item.cpython-314.pyc,, +secretstorage/__pycache__/util.cpython-314.pyc,, +secretstorage/collection.py,sha256=5g-aqHnPbHoaX41J839oaRFaUd8Mr4Bf7KSatfCykoo,9829 +secretstorage/defines.py,sha256=OacsZ_i7F7E-BBqKOdSTBJJUmPB7CHRNLgNjIQnjVM0,872 +secretstorage/dhcrypto.py,sha256=VdTb-rxwJcOlsWFT4a0vrFudDMdhgasTSD4qTr3bHn8,2274 +secretstorage/exceptions.py,sha256=1uUZXTua4jRZf4PKDIT2SVWcSKP2lP97s8r3eJZudio,1655 +secretstorage/item.py,sha256=b7FlGcWEmDC1zfECiM55rJ1we2h684NsVZkWUmV_zDc,6150 +secretstorage/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +secretstorage/util.py,sha256=J_tZYXJTVv75ESFuvKYjltidNrt7fbRBhwM0C08qbLg,7487 diff --git a/lib/secretstorage-3.5.0.dist-info/WHEEL b/lib/secretstorage-3.5.0.dist-info/WHEEL new file mode 100644 index 0000000..e7fa31b --- /dev/null +++ 
b/lib/secretstorage-3.5.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/lib/secretstorage-3.5.0.dist-info/licenses/LICENSE b/lib/secretstorage-3.5.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000..076e642 --- /dev/null +++ b/lib/secretstorage-3.5.0.dist-info/licenses/LICENSE @@ -0,0 +1,25 @@ +Copyright 2012-2025 Dmitry Shachnev +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +3. Neither the name of the University nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/lib/secretstorage-3.5.0.dist-info/top_level.txt b/lib/secretstorage-3.5.0.dist-info/top_level.txt new file mode 100644 index 0000000..0ec6ae8 --- /dev/null +++ b/lib/secretstorage-3.5.0.dist-info/top_level.txt @@ -0,0 +1 @@ +secretstorage diff --git a/lib/secretstorage/__init__.py b/lib/secretstorage/__init__.py new file mode 100644 index 0000000..b0e51ae --- /dev/null +++ b/lib/secretstorage/__init__.py @@ -0,0 +1,103 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2013-2020 +# License: 3-clause BSD, see LICENSE file + +"""This file provides quick access to all SecretStorage API. Please +refer to documentation of individual modules for API details. +""" + +from jeepney.bus_messages import message_bus +from jeepney.io.blocking import DBusConnection, Proxy, open_dbus_connection + +from secretstorage.collection import ( + Collection, + create_collection, + get_all_collections, + get_any_collection, + get_collection_by_alias, + get_default_collection, + search_items, +) +from secretstorage.exceptions import ( + ItemNotFoundException, + LockedException, + PromptDismissedException, + SecretServiceNotAvailableException, + SecretStorageException, +) +from secretstorage.item import Item +from secretstorage.util import add_match_rules + +__version_tuple__ = (3, 5, 0) +__version__ = '.'.join(map(str, __version_tuple__)) + +__all__ = [ + 'Collection', + 'Item', + 'ItemNotFoundException', + 'LockedException', + 'PromptDismissedException', + 'SecretServiceNotAvailableException', + 'SecretStorageException', + 'check_service_availability', + 'create_collection', + 'dbus_init', + 'get_all_collections', + 'get_any_collection', + 'get_collection_by_alias', + 'get_default_collection', + 'search_items', +] + + +def dbus_init() -> DBusConnection: + """Returns a new connection to the session bus, instance of + jeepney's :class:`DBusConnection` class. 
This connection can + then be passed to various SecretStorage functions, such as + :func:`~secretstorage.collection.get_default_collection`. + + .. warning:: + The D-Bus socket will not be closed automatically. You can + close it manually using the :meth:`DBusConnection.close` method, + or you can use the :class:`contextlib.closing` context manager: + + .. code-block:: python + + from contextlib import closing + with closing(dbus_init()) as conn: + collection = secretstorage.get_default_collection(conn) + items = collection.search_items({'application': 'myapp'}) + + However, you will not be able to call any methods on the objects + created within the context after you leave it. + + .. versionchanged:: 3.0 + Before the port to Jeepney, this function returned an + instance of :class:`dbus.SessionBus` class. + + .. versionchanged:: 3.1 + This function no longer accepts any arguments. + """ + try: + connection = open_dbus_connection() + add_match_rules(connection) + return connection + except KeyError as ex: + # os.environ['DBUS_SESSION_BUS_ADDRESS'] may raise it + reason = f"Environment variable {ex.args[0]} is unset" + raise SecretServiceNotAvailableException(reason) from ex + except (ConnectionError, ValueError) as ex: + raise SecretServiceNotAvailableException(str(ex)) from ex + + +def check_service_availability(connection: DBusConnection) -> bool: + """Returns True if the Secret Service daemon is either running or + available for activation via D-Bus, False otherwise. + + .. 
versionadded:: 3.2 + """ + from secretstorage.util import BUS_NAME + proxy = Proxy(message_bus, connection) + return (proxy.NameHasOwner(BUS_NAME)[0] == 1 + or BUS_NAME in proxy.ListActivatableNames()[0]) diff --git a/lib/secretstorage/__pycache__/__init__.cpython-314.pyc b/lib/secretstorage/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..4eacfa0 Binary files /dev/null and b/lib/secretstorage/__pycache__/__init__.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/collection.cpython-314.pyc b/lib/secretstorage/__pycache__/collection.cpython-314.pyc new file mode 100644 index 0000000..0140160 Binary files /dev/null and b/lib/secretstorage/__pycache__/collection.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/defines.cpython-314.pyc b/lib/secretstorage/__pycache__/defines.cpython-314.pyc new file mode 100644 index 0000000..8b194ef Binary files /dev/null and b/lib/secretstorage/__pycache__/defines.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/dhcrypto.cpython-314.pyc b/lib/secretstorage/__pycache__/dhcrypto.cpython-314.pyc new file mode 100644 index 0000000..e40054c Binary files /dev/null and b/lib/secretstorage/__pycache__/dhcrypto.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/exceptions.cpython-314.pyc b/lib/secretstorage/__pycache__/exceptions.cpython-314.pyc new file mode 100644 index 0000000..b05a840 Binary files /dev/null and b/lib/secretstorage/__pycache__/exceptions.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/item.cpython-314.pyc b/lib/secretstorage/__pycache__/item.cpython-314.pyc new file mode 100644 index 0000000..04a8aaa Binary files /dev/null and b/lib/secretstorage/__pycache__/item.cpython-314.pyc differ diff --git a/lib/secretstorage/__pycache__/util.cpython-314.pyc b/lib/secretstorage/__pycache__/util.cpython-314.pyc new file mode 100644 index 0000000..56a14b2 Binary files /dev/null and 
b/lib/secretstorage/__pycache__/util.cpython-314.pyc differ diff --git a/lib/secretstorage/collection.py b/lib/secretstorage/collection.py new file mode 100644 index 0000000..4db2b41 --- /dev/null +++ b/lib/secretstorage/collection.py @@ -0,0 +1,244 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2013-2025 +# License: 3-clause BSD, see LICENSE file + +"""Collection is a place where secret items are stored. Normally, only +the default collection should be used, but this module allows to use any +registered collection. Use :func:`get_default_collection` to get the +default collection (and create it, if necessary). + +Collections are usually automatically unlocked when user logs in, but +collections can also be locked and unlocked using :meth:`Collection.lock` +and :meth:`Collection.unlock` methods (unlocking requires showing the +unlocking prompt to user and blocks until user accepts or declines it). +Creating new items and editing existing ones is possible only in unlocked +collections. 
+""" + +from collections.abc import Iterator + +from jeepney.io.blocking import DBusConnection + +from secretstorage.defines import SS_PATH, SS_PREFIX +from secretstorage.dhcrypto import Session +from secretstorage.exceptions import ( + ItemNotFoundException, + LockedException, + PromptDismissedException, +) +from secretstorage.item import Item +from secretstorage.util import ( + DBusAddressWrapper, + exec_prompt, + format_secret, + open_session, + unlock_objects, +) + +COLLECTION_IFACE = SS_PREFIX + 'Collection' +SERVICE_IFACE = SS_PREFIX + 'Service' +DEFAULT_COLLECTION = '/org/freedesktop/secrets/aliases/default' +SESSION_COLLECTION = '/org/freedesktop/secrets/collection/session' + + +class Collection: + """Represents a collection.""" + + def __init__(self, connection: DBusConnection, + collection_path: str = DEFAULT_COLLECTION, + session: Session | None = None) -> None: + self.connection = connection + self.session = session + self.collection_path = collection_path + self._collection = DBusAddressWrapper( + collection_path, COLLECTION_IFACE, connection) + self._collection.get_property('Label') + + def is_locked(self) -> bool: + """Returns :const:`True` if item is locked, otherwise + :const:`False`.""" + return bool(self._collection.get_property('Locked')) + + def ensure_not_locked(self) -> None: + """If collection is locked, raises + :exc:`~secretstorage.exceptions.LockedException`.""" + if self.is_locked(): + raise LockedException('Collection is locked!') + + def unlock(self, timeout: float | None = None) -> bool: + """Requests unlocking the collection. + + Returns a boolean representing whether the prompt has been + dismissed; that means :const:`False` on successful unlocking + and :const:`True` if it has been dismissed. + + :raises: ``TimeoutError`` if `timeout` (in seconds) passed + and the prompt was neither accepted nor dismissed. + + .. versionchanged:: 3.0 + No longer accepts the ``callback`` argument. + + .. 
versionchanged:: 3.5 + Added ``timeout`` argument. + """ + return unlock_objects(self.connection, [self.collection_path], timeout=timeout) + + def lock(self) -> None: + """Locks the collection.""" + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, self.connection) + service.call('Lock', 'ao', [self.collection_path]) + + def delete(self) -> None: + """Deletes the collection and all items inside it.""" + self.ensure_not_locked() + prompt, = self._collection.call('Delete', '') + if prompt != "/": + dismissed, _result = exec_prompt(self.connection, prompt) + if dismissed: + raise PromptDismissedException('Prompt dismissed.') + + def get_all_items(self) -> Iterator[Item]: + """Returns a generator of all items in the collection.""" + for item_path in self._collection.get_property('Items'): + yield Item(self.connection, item_path, self.session) + + def search_items(self, attributes: dict[str, str]) -> Iterator[Item]: + """Returns a generator of items with the given attributes. + `attributes` should be a dictionary.""" + result, = self._collection.call('SearchItems', 'a{ss}', attributes) + for item_path in result: + yield Item(self.connection, item_path, self.session) + + def get_label(self) -> str: + """Returns the collection label.""" + label = self._collection.get_property('Label') + assert isinstance(label, str) + return label + + def set_label(self, label: str) -> None: + """Sets collection label to `label`.""" + self.ensure_not_locked() + self._collection.set_property('Label', 's', label) + + def create_item(self, label: str, attributes: dict[str, str], + secret: bytes, replace: bool = False, + content_type: str = 'text/plain') -> Item: + """Creates a new :class:`~secretstorage.item.Item` with given + `label` (unicode string), `attributes` (dictionary) and `secret` + (bytestring). If `replace` is :const:`True`, replaces the existing + item with the same attributes. If `content_type` is given, also + sets the content type of the secret (``text/plain`` by default). 
+ Returns the created item.""" + self.ensure_not_locked() + if not self.session: + self.session = open_session(self.connection) + _secret = format_secret(self.session, secret, content_type) + properties = { + SS_PREFIX + 'Item.Label': ('s', label), + SS_PREFIX + 'Item.Attributes': ('a{ss}', attributes), + } + item_path, prompt = self._collection.call( + 'CreateItem', + 'a{sv}(oayays)b', + properties, + _secret, + replace + ) + if len(item_path) > 1: + return Item(self.connection, item_path, self.session) + dismissed, result = exec_prompt(self.connection, prompt) + if dismissed: + raise PromptDismissedException('Prompt dismissed.') + signature, item_path = result + assert signature == 'o' + return Item(self.connection, item_path, self.session) + + def __repr__(self) -> str: + return f"<Collection at {self.collection_path}>" + + +def create_collection(connection: DBusConnection, label: str, alias: str = '', + session: Session | None = None) -> Collection: + """Creates a new :class:`Collection` with the given `label` and `alias` + and returns it. This action requires prompting. + + :raises: :exc:`~secretstorage.exceptions.PromptDismissedException` + if the prompt is dismissed. 
+ """ + if not session: + session = open_session(connection) + properties = {SS_PREFIX + 'Collection.Label': ('s', label)} + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + collection_path, prompt = service.call('CreateCollection', 'a{sv}s', + properties, alias) + if len(collection_path) > 1: + return Collection(connection, collection_path, session=session) + dismissed, result = exec_prompt(connection, prompt) + if dismissed: + raise PromptDismissedException('Prompt dismissed.') + signature, collection_path = result + assert signature == 'o' + return Collection(connection, collection_path, session=session) + + +def get_all_collections(connection: DBusConnection) -> Iterator[Collection]: + """Returns a generator of all available collections.""" + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + for collection_path in service.get_property('Collections'): + yield Collection(connection, collection_path) + + +def get_default_collection(connection: DBusConnection, + session: Session | None = None) -> Collection: + """Returns the default collection. If it doesn't exist, + creates it.""" + try: + return Collection(connection) + except ItemNotFoundException: + return create_collection(connection, 'Default', 'default', session) + + +def get_any_collection(connection: DBusConnection) -> Collection: + """Returns any collection, in the following order of preference: + + - The default collection; + - The "session" collection (usually temporary); + - The first collection in the collections list.""" + try: + return Collection(connection) + except ItemNotFoundException: + pass + try: + # GNOME Keyring provides session collection where items + # are stored in process memory. 
+ return Collection(connection, SESSION_COLLECTION) + except ItemNotFoundException: + pass + collections = list(get_all_collections(connection)) + if collections: + return collections[0] + else: + raise ItemNotFoundException('No collections found.') + + +def get_collection_by_alias(connection: DBusConnection, + alias: str) -> Collection: + """Returns the collection with the given `alias`. If there is no + such collection, raises + :exc:`~secretstorage.exceptions.ItemNotFoundException`.""" + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + collection_path, = service.call('ReadAlias', 's', alias) + if len(collection_path) <= 1: + raise ItemNotFoundException('No collection with such alias.') + return Collection(connection, collection_path) + + +def search_items(connection: DBusConnection, + attributes: dict[str, str]) -> Iterator[Item]: + """Returns a generator of items in all collections with the given + attributes. `attributes` should be a dictionary.""" + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + locked, unlocked = service.call('SearchItems', 'a{ss}', attributes) + for item_path in locked + unlocked: + yield Item(connection, item_path) diff --git a/lib/secretstorage/defines.py b/lib/secretstorage/defines.py new file mode 100644 index 0000000..59c7286 --- /dev/null +++ b/lib/secretstorage/defines.py @@ -0,0 +1,21 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2013-2016 +# License: 3-clause BSD, see LICENSE file + +# This file contains some common defines. + +SS_PREFIX = 'org.freedesktop.Secret.' 
+SS_PATH = '/org/freedesktop/secrets' + +DBUS_UNKNOWN_METHOD = 'org.freedesktop.DBus.Error.UnknownMethod' +DBUS_ACCESS_DENIED = 'org.freedesktop.DBus.Error.AccessDenied' +DBUS_SERVICE_UNKNOWN = 'org.freedesktop.DBus.Error.ServiceUnknown' +DBUS_EXEC_FAILED = 'org.freedesktop.DBus.Error.Spawn.ExecFailed' +DBUS_NO_REPLY = 'org.freedesktop.DBus.Error.NoReply' +DBUS_NOT_SUPPORTED = 'org.freedesktop.DBus.Error.NotSupported' +DBUS_NO_SUCH_OBJECT = 'org.freedesktop.Secret.Error.NoSuchObject' +DBUS_UNKNOWN_OBJECT = 'org.freedesktop.DBus.Error.UnknownObject' + +ALGORITHM_PLAIN = 'plain' +ALGORITHM_DH = 'dh-ietf1024-sha256-aes128-cbc-pkcs7' diff --git a/lib/secretstorage/dhcrypto.py b/lib/secretstorage/dhcrypto.py new file mode 100644 index 0000000..31516cb --- /dev/null +++ b/lib/secretstorage/dhcrypto.py @@ -0,0 +1,50 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2014-2025 +# License: 3-clause BSD, see LICENSE file + +'''This module contains needed classes, functions and constants +to implement dh-ietf1024-sha256-aes128-cbc-pkcs7 secret encryption +algorithm.''' + +import hmac +import os +from hashlib import sha256 + +# A standard 1024 bits (128 bytes) prime number for use in Diffie-Hellman exchange +DH_PRIME_1024_BYTES = ( + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2, 0x21, 0x68, + 0xC2, 0x34, 0xC4, 0xC6, 0x62, 0x8B, 0x80, 0xDC, 0x1C, 0xD1, 0x29, 0x02, 0x4E, 0x08, + 0x8A, 0x67, 0xCC, 0x74, 0x02, 0x0B, 0xBE, 0xA6, 0x3B, 0x13, 0x9B, 0x22, 0x51, 0x4A, + 0x08, 0x79, 0x8E, 0x34, 0x04, 0xDD, 0xEF, 0x95, 0x19, 0xB3, 0xCD, 0x3A, 0x43, 0x1B, + 0x30, 0x2B, 0x0A, 0x6D, 0xF2, 0x5F, 0x14, 0x37, 0x4F, 0xE1, 0x35, 0x6D, 0x6D, 0x51, + 0xC2, 0x45, 0xE4, 0x85, 0xB5, 0x76, 0x62, 0x5E, 0x7E, 0xC6, 0xF4, 0x4C, 0x42, 0xE9, + 0xA6, 0x37, 0xED, 0x6B, 0x0B, 0xFF, 0x5C, 0xB6, 0xF4, 0x06, 0xB7, 0xED, 0xEE, 0x38, + 0x6B, 0xFB, 0x5A, 0x89, 0x9F, 0xA5, 0xAE, 0x9F, 0x24, 0x11, 0x7C, 0x4B, 0x1F, 0xE6, + 
0x49, 0x28, 0x66, 0x51, 0xEC, 0xE6, 0x53, 0x81, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF +) + + +DH_PRIME_1024 = int.from_bytes(DH_PRIME_1024_BYTES, 'big') + + +class Session: + def __init__(self) -> None: + self.object_path: str | None = None + self.aes_key: bytes | None = None + self.encrypted = True + # 128-bytes-long strong random number + self.my_private_key = int.from_bytes(os.urandom(0x80), 'big') + self.my_public_key = pow(2, self.my_private_key, DH_PRIME_1024) + + def set_server_public_key(self, server_public_key: int) -> None: + common_secret_int = pow(server_public_key, self.my_private_key, + DH_PRIME_1024) + common_secret = common_secret_int.to_bytes(128, 'big') + # HKDF with null salt, empty info and SHA-256 hash + salt = b'\x00' * 0x20 + pseudo_random_key = hmac.new(salt, common_secret, sha256).digest() + output_block = hmac.new(pseudo_random_key, b'\x01', sha256).digest() + # Resulting AES key should be 128-bit + self.aes_key = output_block[:0x10] diff --git a/lib/secretstorage/exceptions.py b/lib/secretstorage/exceptions.py new file mode 100644 index 0000000..c8c19d6 --- /dev/null +++ b/lib/secretstorage/exceptions.py @@ -0,0 +1,50 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2012-2018 +# License: 3-clause BSD, see LICENSE file + +"""All secretstorage functions may raise various exceptions when +something goes wrong. 
All exceptions derive from base +:exc:`SecretStorageException` class.""" + + +class SecretStorageException(Exception): + """All exceptions derive from this class.""" + + +class SecretServiceNotAvailableException(SecretStorageException): + """Raised by :class:`~secretstorage.item.Item` or + :class:`~secretstorage.collection.Collection` constructors, or by + other functions in the :mod:`secretstorage.collection` module, when + the Secret Service API is not available.""" + + +class LockedException(SecretStorageException): + """Raised when an action cannot be performed because the collection + is locked. Use :meth:`~secretstorage.collection.Collection.is_locked` + to check if the collection is locked, and + :meth:`~secretstorage.collection.Collection.unlock` to unlock it. + """ + + +class ItemNotFoundException(SecretStorageException): + """Raised when an item does not exist or has been deleted. Example of + handling: + + >>> import secretstorage + >>> connection = secretstorage.dbus_init() + >>> item_path = '/not/existing/path' + >>> try: + ... item = secretstorage.Item(connection, item_path) + ... except secretstorage.ItemNotFoundException: + ... print('Item not found!') + ... + Item not found! + """ + + +class PromptDismissedException(ItemNotFoundException): + """Raised when a prompt was dismissed by the user. + + .. versionadded:: 3.1 + """ diff --git a/lib/secretstorage/item.py b/lib/secretstorage/item.py new file mode 100644 index 0000000..ad3a8b0 --- /dev/null +++ b/lib/secretstorage/item.py @@ -0,0 +1,159 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2013-2025 +# License: 3-clause BSD, see LICENSE file + +"""SecretStorage item contains a *secret*, some *attributes* and a +*label* visible to user. Editing all these properties and reading the +secret is possible only when the :doc:`collection <collection>` storing +the item is unlocked. 
The collection can be unlocked using collection's +:meth:`~secretstorage.collection.Collection.unlock` method.""" + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from jeepney.io.blocking import DBusConnection + +from secretstorage.defines import SS_PREFIX +from secretstorage.dhcrypto import Session +from secretstorage.exceptions import LockedException, PromptDismissedException +from secretstorage.util import ( + DBusAddressWrapper, + exec_prompt, + format_secret, + open_session, + unlock_objects, +) + +ITEM_IFACE = SS_PREFIX + 'Item' + + +class Item: + """Represents a secret item.""" + + def __init__(self, connection: DBusConnection, + item_path: str, session: Session | None = None) -> None: + self.item_path = item_path + self._item = DBusAddressWrapper(item_path, ITEM_IFACE, connection) + self._item.get_property('Label') + self.session = session + self.connection = connection + + def __eq__(self, other: "DBusConnection") -> bool: + assert isinstance(other.item_path, str) + return self.item_path == other.item_path + + def is_locked(self) -> bool: + """Returns :const:`True` if item is locked, otherwise + :const:`False`.""" + return bool(self._item.get_property('Locked')) + + def ensure_not_locked(self) -> None: + """If collection is locked, raises + :exc:`~secretstorage.exceptions.LockedException`.""" + if self.is_locked(): + raise LockedException('Item is locked!') + + def unlock(self, timeout: float | None = None) -> bool: + """Requests unlocking the item. Usually, this means that the + whole collection containing this item will be unlocked. + + Returns a boolean representing whether the prompt has been + dismissed; that means :const:`False` on successful unlocking + and :const:`True` if it has been dismissed. + + :raises: ``TimeoutError`` if `timeout` (in seconds) passed + and the prompt was neither accepted nor dismissed. + + .. versionadded:: 2.1.2 + + .. 
versionchanged:: 3.0 + No longer accepts the ``callback`` argument. + + .. versionchanged:: 3.5 + Added ``timeout`` argument. + """ + return unlock_objects(self.connection, [self.item_path], timeout=timeout) + + def get_attributes(self) -> dict[str, str]: + """Returns item attributes (dictionary).""" + attrs = self._item.get_property('Attributes') + return dict(attrs) + + def set_attributes(self, attributes: dict[str, str]) -> None: + """Sets item attributes to `attributes` (dictionary).""" + self._item.set_property('Attributes', 'a{ss}', attributes) + + def get_label(self) -> str: + """Returns item label (unicode string).""" + label = self._item.get_property('Label') + assert isinstance(label, str) + return label + + def set_label(self, label: str) -> None: + """Sets item label to `label`.""" + self.ensure_not_locked() + self._item.set_property('Label', 's', label) + + def delete(self) -> None: + """Deletes the item.""" + self.ensure_not_locked() + prompt, = self._item.call('Delete', '') + if prompt != "/": + dismissed, _result = exec_prompt(self.connection, prompt) + if dismissed: + raise PromptDismissedException('Prompt dismissed.') + + def get_secret(self) -> bytes: + """Returns item secret (bytestring).""" + self.ensure_not_locked() + if not self.session: + self.session = open_session(self.connection) + secret, = self._item.call('GetSecret', 'o', self.session.object_path) + if not self.session.encrypted: + return bytes(secret[2]) + assert self.session.aes_key is not None + aes = algorithms.AES(self.session.aes_key) + aes_iv = bytes(secret[1]) + decryptor = Cipher(aes, modes.CBC(aes_iv), default_backend()).decryptor() + encrypted_secret = secret[2] + padded_secret = decryptor.update(bytes(encrypted_secret)) + decryptor.finalize() + assert isinstance(padded_secret, bytes) + return padded_secret[:-padded_secret[-1]] + + def get_secret_content_type(self) -> str: + """Returns content type of item secret (string).""" + self.ensure_not_locked() + if not self.session: 
+ self.session = open_session(self.connection) + secret, = self._item.call('GetSecret', 'o', self.session.object_path) + return str(secret[3]) + + def set_secret(self, secret: bytes, + content_type: str = 'text/plain') -> None: + """Sets item secret to `secret`. If `content_type` is given, + also sets the content type of the secret (``text/plain`` by + default).""" + self.ensure_not_locked() + if not self.session: + self.session = open_session(self.connection) + _secret = format_secret(self.session, secret, content_type) + self._item.call('SetSecret', '(oayays)', _secret) + + def get_created(self) -> int: + """Returns UNIX timestamp (integer) representing the time + when the item was created. + + .. versionadded:: 1.1""" + created = self._item.get_property('Created') + assert isinstance(created, int) + return created + + def get_modified(self) -> int: + """Returns UNIX timestamp (integer) representing the time + when the item was last modified.""" + modified = self._item.get_property('Modified') + assert isinstance(modified, int) + return modified + + def __repr__(self) -> str: + return f"<Item at {self.item_path}>" diff --git a/lib/secretstorage/py.typed b/lib/secretstorage/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/lib/secretstorage/util.py b/lib/secretstorage/util.py new file mode 100644 index 0000000..a7caf2c --- /dev/null +++ b/lib/secretstorage/util.py @@ -0,0 +1,227 @@ +# SecretStorage module for Python +# Access passwords using the SecretService DBus API +# Author: Dmitry Shachnev, 2013-2025 +# License: 3-clause BSD, see LICENSE file + +"""This module provides some utility functions, but these shouldn't +normally be used by external applications.""" + +import os +from typing import Any + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from jeepney import ( + DBusAddress, + DBusErrorResponse, + MatchRule, + Message, + MessageType, + Properties, + new_method_call, +) +from 
jeepney.io.blocking import DBusConnection + +from secretstorage.defines import ( + ALGORITHM_DH, + ALGORITHM_PLAIN, + DBUS_EXEC_FAILED, + DBUS_NO_REPLY, + DBUS_NO_SUCH_OBJECT, + DBUS_NOT_SUPPORTED, + DBUS_SERVICE_UNKNOWN, + DBUS_UNKNOWN_METHOD, + DBUS_UNKNOWN_OBJECT, + SS_PATH, + SS_PREFIX, +) +from secretstorage.dhcrypto import Session +from secretstorage.exceptions import ( + ItemNotFoundException, + SecretServiceNotAvailableException, +) + +BUS_NAME = 'org.freedesktop.secrets' +SERVICE_IFACE = SS_PREFIX + 'Service' +PROMPT_IFACE = SS_PREFIX + 'Prompt' + + +class DBusAddressWrapper(DBusAddress): # type: ignore + """A wrapper class around :class:`jeepney.wrappers.DBusAddress` + that adds some additional methods for calling and working with + properties, and converts error responses to SecretStorage + exceptions. + + .. versionadded:: 3.0 + """ + def __init__(self, path: str, interface: str, + connection: DBusConnection) -> None: + DBusAddress.__init__(self, path, BUS_NAME, interface) + self._connection = connection + + def send_and_get_reply(self, msg: Message) -> Any: + try: + resp_msg: Message = self._connection.send_and_get_reply(msg) + if resp_msg.header.message_type == MessageType.error: + raise DBusErrorResponse(resp_msg) + return resp_msg.body + except DBusErrorResponse as resp: + if resp.name in ( + DBUS_UNKNOWN_METHOD, + DBUS_NO_SUCH_OBJECT, + DBUS_UNKNOWN_OBJECT, + ): + raise ItemNotFoundException('Item does not exist!') from resp + elif resp.name in (DBUS_SERVICE_UNKNOWN, DBUS_EXEC_FAILED, + DBUS_NO_REPLY): + data = resp.data + if isinstance(data, tuple): + data = data[0] + raise SecretServiceNotAvailableException(data) from resp + raise + + def call(self, method: str, signature: str, *body: Any) -> Any: + msg = new_method_call(self, method, signature, body) + return self.send_and_get_reply(msg) + + def get_property(self, name: str) -> Any: + msg = Properties(self).get(name) + (signature, value), = self.send_and_get_reply(msg) + return value + + def 
set_property(self, name: str, signature: str, value: Any) -> None: + msg = Properties(self).set(name, signature, value) + self.send_and_get_reply(msg) + + +def open_session(connection: DBusConnection) -> Session: + """Returns a new Secret Service session.""" + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + session = Session() + try: + output, result = service.call( + 'OpenSession', 'sv', + ALGORITHM_DH, + ('ay', session.my_public_key.to_bytes(128, 'big'))) + except DBusErrorResponse as resp: + if resp.name != DBUS_NOT_SUPPORTED: + raise + output, result = service.call( + 'OpenSession', 'sv', + ALGORITHM_PLAIN, + ('s', '')) + session.encrypted = False + else: + signature, value = output + assert signature == 'ay' + key = int.from_bytes(value, 'big') + session.set_server_public_key(key) + session.object_path = result + return session + + +def format_secret(session: Session, secret: bytes, + content_type: str) -> tuple[str, bytes, bytes, str]: + """Formats `secret` to make possible to pass it to the + Secret Service API.""" + if isinstance(secret, str): + secret = secret.encode('utf-8') + elif not isinstance(secret, bytes): + raise TypeError('secret must be bytes') + assert session.object_path is not None + if not session.encrypted: + return (session.object_path, b'', secret, content_type) + assert session.aes_key is not None + # PKCS-7 style padding + padding = 0x10 - (len(secret) & 0xf) + secret += bytes((padding,) * padding) + aes_iv = os.urandom(0x10) + aes = algorithms.AES(session.aes_key) + encryptor = Cipher(aes, modes.CBC(aes_iv), default_backend()).encryptor() + encrypted_secret = encryptor.update(secret) + encryptor.finalize() + return ( + session.object_path, + aes_iv, + encrypted_secret, + content_type + ) + + +def exec_prompt( + connection: DBusConnection, + prompt_path: str, + *, + timeout: float | None = None, +) -> tuple[bool, tuple[str, Any]]: + """Executes the prompt in a blocking mode. 
+ + :returns: a two-element tuple: + + - The first element is a boolean value indicating whether the operation was + dismissed. + - The second element is a (signature, result) tuple. For creating items and + collections, ``signature`` is ``'o'`` and ``result`` is a single object + path. For unlocking, ``signature`` is ``'ao'`` and ``result`` is a list of + object paths. + + .. versionchanged:: 3.5 + Added ``timeout`` keyword argument. + """ + prompt = DBusAddressWrapper(prompt_path, PROMPT_IFACE, connection) + rule = MatchRule( + path=prompt_path, + interface=PROMPT_IFACE, + member='Completed', + type=MessageType.signal, + ) + with connection.filter(rule) as signals: + prompt.call('Prompt', 's', '') + message = connection.recv_until_filtered(signals, timeout=timeout) + dismissed, result = message.body + assert dismissed is not None + assert result is not None + return dismissed, result + + +def unlock_objects( + connection: DBusConnection, + paths: list[str], + *, + timeout: float | None = None, +) -> bool: + """Requests unlocking objects specified in `paths`. + Returns a boolean representing whether the operation was dismissed. + + .. versionadded:: 2.1.2 + + .. versionchanged:: 3.5 + Added ``timeout`` keyword argument. + """ + service = DBusAddressWrapper(SS_PATH, SERVICE_IFACE, connection) + unlocked_paths, prompt = service.call('Unlock', 'ao', paths) + if len(prompt) > 1: + dismissed, (signature, unlocked) = exec_prompt( + connection, + prompt, + timeout=timeout, + ) + assert signature == 'ao' + return dismissed + return False + + +def add_match_rules(connection: DBusConnection) -> None: + """Adds match rules for the given connection. + + Currently it matches all messages from the Prompt interface, as the + mock service (unlike GNOME Keyring) does not specify the signal + destination. + + .. 
versionadded:: 3.1 + """ + rule = MatchRule(sender=BUS_NAME, interface=PROMPT_IFACE) + dbus = DBusAddressWrapper(path='/org/freedesktop/DBus', + interface='org.freedesktop.DBus', + connection=connection) + dbus.bus_name = 'org.freedesktop.DBus' + dbus.call('AddMatch', 's', rule.serialise())