diff --git a/.gitignore b/.gitignore index 378fbcf..89c1f6e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ __pycache__/ __pypackages__/ .mypy_cache/ .pytest_cache/ -*.py[cod] +*.py[cdio] *$py.class # BUILD ARTIFACTS diff --git a/CHANGELOG.md b/CHANGELOG.md index c22d009..9029e44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,22 @@ #
Changelog
+ + +## ... `v1.9.6` + +* The compiled version of the library now includes the type stub files (`.pyi`), so type checkers can properly check types. +* Made all type hints in the whole library way more strict and accurate. +* Removed leftover unnecessary runtime type-checks in several methods throughout the whole library. +* Renamed the `Spinner` class from the `console` module to `Throbber`, since that name is closer to what it's actually used for. + + ## 25.01.2026 `v1.9.5` -* Add new class property `Console.encoding`, which returns the encoding used by the console (*e.g.* `utf-8`*,* `cp1252`*, …*). -* Add multiple new class properties to the `System` class: +* Added a new class property `Console.encoding`, which returns the encoding used by the console (*e.g.* `utf-8`*,* `cp1252`*, …*). +* Added multiple new class properties to the `System` class: - `is_linux` Whether the current OS is Linux or not. - `is_mac` Whether the current OS is macOS or not. - `is_unix` Whether the current OS is a Unix-like OS (Linux, macOS, BSD, …) or not. diff --git a/pyproject.toml b/pyproject.toml index 28acabb..a164c7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ build-backend = "setuptools.build_meta" [project] name = "xulbux" -version = "1.9.5" +version = "1.9.6" description = "A Python library to simplify common programming tasks." readme = "README.md" authors = [{ name = "XulbuX", email = "xulbux.real@gmail.com" }] @@ -130,6 +130,9 @@ package-dir = { "" = "src" } [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +xulbux = ["py.typed", "*.pyi", "**/*.pyi"] + [tool.pytest.ini_options] minversion = "7.0" addopts = "-ra -q" diff --git a/setup.py b/setup.py index 5761a66..bd0ce5d 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ from setuptools import setup from pathlib import Path +import subprocess +import sys import os @@ -10,14 +12,65 @@ def find_python_files(directory: str) -> list[str]: return python_files -# OPTIONALLY USE MYPYC COMPILATION +def generate_stubs_for_package(): + print("\nGenerating stub files with stubgen...\n") + + try: + skip_stubgen = { + Path("src/xulbux/base/types.py"), # COMPLEX TYPE DEFINITIONS + } + + src_dir = Path("src/xulbux") + generated_count = 0 + skipped_count = 0 + + for py_file in src_dir.rglob("*.py"): + pyi_file = py_file.with_suffix(".pyi") + rel_path = py_file.relative_to(src_dir.parent) + + if py_file in skip_stubgen: + pyi_file.write_text(py_file.read_text(encoding="utf-8"), encoding="utf-8") + print(f" copied {rel_path.with_suffix('.pyi')} (preserving type definitions)") + skipped_count += 1 + continue + + result = subprocess.run( + [sys.executable, "-m", "mypy.stubgen", + str(py_file), + "-o", "src", + "--include-private", + "--export-less"], + capture_output=True, + text=True + ) + + if result.returncode == 0: + print(f" generated {rel_path.with_suffix('.pyi')}") + generated_count += 1 + else: + print(f" failed {rel_path}") + if result.stderr: + print(f" {result.stderr.strip()}") + + print(f"\nStub generation complete! ({generated_count} generated, {skipped_count} copied)\n") + + except Exception as e: + fmt_error = "\n ".join(str(e).splitlines()) + print(f"[WARNING] Could not generate stubs:\n {fmt_error}\n") + + ext_modules = [] + +# OPTIONALLY USE MYPYC COMPILATION if os.environ.get("XULBUX_USE_MYPYC", "1") == "1": try: from mypyc.build import mypycify + print("\nCompiling with mypyc...\n") source_files = find_python_files("src/xulbux") - ext_modules = mypycify(source_files) + ext_modules = mypycify(source_files, opt_level="3") + + generate_stubs_for_package() except (ImportError, Exception) as e: fmt_error = "\n ".join(str(e).splitlines()) diff --git a/src/xulbux/__init__.py b/src/xulbux/__init__.py index 780b1c3..6e7086c 100644 --- a/src/xulbux/__init__.py +++ b/src/xulbux/__init__.py @@ -1,5 +1,5 @@ __package_name__ = "xulbux" -__version__ = "1.9.5" +__version__ = "1.9.6" __description__ = "A Python library to simplify common programming tasks." __status__ = "Production/Stable" diff --git a/src/xulbux/base/types.py b/src/xulbux/base/types.py index abd1562..c48e9fa 100644 --- a/src/xulbux/base/types.py +++ b/src/xulbux/base/types.py @@ -27,48 +27,48 @@ # ################################################## TypeAlias ################################################## -PathsList: TypeAlias = Union[list[Path], list[str], list[Path | str]] +PathsList: TypeAlias = Union[list[Path], list[str], list[Union[Path, str]]] """Union of all supported list types for a list of paths.""" -DataStructure: TypeAlias = Union[list, tuple, set, frozenset, dict] +DataStructure: TypeAlias = Union[list[Any], tuple[Any, ...], set[Any], frozenset[Any], dict[Any, Any]] """Union of supported data structures used in the `data` module.""" DataStructureTypes = (list, tuple, set, frozenset, dict) """Tuple of supported data structures used in the `data` module.""" -IndexIterable: TypeAlias = Union[list, tuple, set, frozenset] +IndexIterable: TypeAlias = Union[list[Any], tuple[Any, ...], set[Any], frozenset[Any]] """Union of all iterable types that support indexing operations.""" IndexIterableTypes = (list, tuple, set, frozenset) """Tuple of all iterable types that support indexing operations.""" Rgba: TypeAlias = Union[ tuple[Int_0_255, Int_0_255, Int_0_255], - tuple[Int_0_255, Int_0_255, Int_0_255, Float_0_1], + tuple[Int_0_255, Int_0_255, Int_0_255, Optional[Float_0_1]], list[Int_0_255], - list[Union[Int_0_255, Float_0_1]], - dict[str, Union[int, float]], + list[Union[Int_0_255, Optional[Float_0_1]]], + "RgbaDict", "rgba", str, ] """Matches all supported RGBA color value formats.""" Hsla: TypeAlias = Union[ tuple[Int_0_360, Int_0_100, Int_0_100], - tuple[Int_0_360, Int_0_100, Int_0_100, Float_0_1], + tuple[Int_0_360, Int_0_100, Int_0_100, Optional[Float_0_1]], list[Union[Int_0_360, Int_0_100]], - list[Union[Int_0_360, Int_0_100, Float_0_1]], - dict[str, Union[int, float]], + list[Union[Int_0_360, Int_0_100, Optional[Float_0_1]]], + "HslaDict", "hsla", str, ] """Matches all supported HSLA color value formats.""" Hexa: TypeAlias = Union[str, int, "hexa"] -"""Matches all supported hexadecimal color value formats.""" +"""Matches all supported HEXA color value formats.""" AnyRgba: TypeAlias = Any -"""Generic type alias for RGBA color values in any supported format (type checking disabled).""" +"""Generic type alias for RGBA color values in any format (type checking disabled).""" AnyHsla: TypeAlias = Any -"""Generic type alias for HSLA color values in any supported format (type checking disabled).""" +"""Generic type alias for HSLA color values in any format (type checking disabled).""" AnyHexa: TypeAlias = Any -"""Generic type alias for hexadecimal color values in any supported format (type checking disabled).""" +"""Generic type alias for HEXA color values in any format (type checking disabled).""" ArgParseConfig: TypeAlias = Union[set[str], "ArgConfigWithDefault", Literal["before", "after"]] """Matches the command-line-parsing configuration of a single argument.""" @@ -92,7 +92,6 @@ class ArgConfigWithDefault(TypedDict): flags: set[str] default: str - class ArgData(TypedDict): """Schema for the resulting data of parsing a single command-line argument.""" exists: bool @@ -101,6 +100,28 @@ class ArgData(TypedDict): flag: Optional[str] +class RgbaDict(TypedDict): + """Dictionary schema for RGBA color components.""" + r: Int_0_255 + g: Int_0_255 + b: Int_0_255 + a: Optional[Float_0_1] + +class HslaDict(TypedDict): + """Dictionary schema for HSLA color components.""" + h: Int_0_360 + s: Int_0_100 + l: Int_0_100 + a: Optional[Float_0_1] + +class HexaDict(TypedDict): + """Dictionary schema for HEXA color components.""" + r: str + g: str + b: str + a: Optional[str] + + class MissingLibsMsgs(TypedDict): """Configuration schema for custom messages in `System.check_libs()` when checking library dependencies.""" found_missing: str diff --git a/src/xulbux/cli/help.py b/src/xulbux/cli/help.py index 5828a16..0e80b48 100644 --- a/src/xulbux/cli/help.py +++ b/src/xulbux/cli/help.py @@ -72,6 +72,6 @@ def is_latest_version() -> Optional[bool]: def show_help() -> None: - FormatCodes._config_console() + FormatCodes._config_console() # type: ignore[protected-access] print(CLI_HELP) Console.pause_exit(pause=True, prompt=" [dim](Press any key to exit...)\n\n") diff --git a/src/xulbux/code.py b/src/xulbux/code.py index 1971176..431856e 100644 --- a/src/xulbux/code.py +++ b/src/xulbux/code.py @@ -6,6 +6,7 @@ from .regex import Regex from .data import Data +from typing import Any import regex as _rx @@ -48,7 +49,7 @@ def change_tab_size(cls, code: str, new_tab_size: int, remove_empty_lines: bool return "\n".join(code_lines) return code - result = [] + result: list[str] = [] for line in code_lines: indent_level = (len(line) - len(stripped := line.lstrip())) // tab_spaces result.append((" " * (indent_level * new_tab_size)) + stripped) @@ -56,11 +57,11 @@ def change_tab_size(cls, code: str, new_tab_size: int, remove_empty_lines: bool return "\n".join(result) @classmethod - def get_func_calls(cls, code: str) -> list: + def get_func_calls(cls, code: str) -> list[list[Any]]: """Will try to get all function calls and return them as a list.\n ------------------------------------------------------------------- - `code` -⠀the code to analyze""" - nested_func_calls = [] + nested_func_calls: list[list[Any]] = [] for _, func_attrs in (funcs := _rx.findall(r"(?i)" + Regex.func_call(), code)): if (nested_calls := _rx.findall(r"(?i)" + Regex.func_call(), func_attrs)): diff --git a/src/xulbux/color.py b/src/xulbux/color.py index 2a7eb4d..110439f 100644 --- a/src/xulbux/color.py +++ b/src/xulbux/color.py @@ -6,10 +6,10 @@ includes methods to work with colors in various formats. """ -from .base.types import AnyRgba, AnyHsla, AnyHexa, Rgba, Hsla, Hexa +from .base.types import RgbaDict, HslaDict, HexaDict, AnyRgba, AnyHsla, AnyHexa, Rgba, Hsla, Hexa from .regex import Regex -from typing import Iterator, Optional, Literal, cast +from typing import Iterator, Optional, Literal, Any, overload, cast import re as _re @@ -69,10 +69,18 @@ def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[int | Optional[float]]: return iter((self.r, self.g, self.b) + (() if self.a is None else (self.a, ))) - def __getitem__(self, index: int) -> int | float: + @overload + def __getitem__(self, index: Literal[0, 1, 2]) -> int: + ... + + @overload + def __getitem__(self, index: Literal[3]) -> Optional[float]: + ... + + def __getitem__(self, index: int) -> int | Optional[float]: return ((self.r, self.g, self.b) + (() if self.a is None else (self.a, )))[index] def __eq__(self, other: object) -> bool: @@ -91,20 +99,20 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def dict(self) -> dict: + def dict(self) -> RgbaDict: """Returns the color components as a dictionary with keys `"r"`, `"g"`, `"b"` and optionally `"a"`.""" - return dict(r=self.r, g=self.g, b=self.b) if self.a is None else dict(r=self.r, g=self.g, b=self.b, a=self.a) + return {"r": self.r, "g": self.g, "b": self.b, "a": self.a} - def values(self) -> tuple: + def values(self) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `r, g, b, a`.""" return self.r, self.g, self.b, self.a - def to_hsla(self) -> "hsla": + def to_hsla(self) -> hsla: """Returns the color as `hsla()` color object.""" h, s, l = self._rgb_to_hsl(self.r, self.g, self.b) return hsla(h, s, l, self.a, _validate=False) - def to_hexa(self) -> "hexa": + def to_hexa(self) -> hexa: """Returns the color as `hexa()` color object.""" return hexa("", self.r, self.g, self.b, self.a) @@ -112,66 +120,51 @@ def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "rgba": + def lighten(self, amount: float) -> rgba: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().lighten(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def darken(self, amount: float) -> "rgba": + def darken(self, amount: float) -> rgba: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().darken(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def saturate(self, amount: float) -> "rgba": + def saturate(self, amount: float) -> rgba: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().saturate(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def desaturate(self, amount: float) -> "rgba": + def desaturate(self, amount: float) -> rgba: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().desaturate(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def rotate(self, degrees: int) -> "rgba": + def rotate(self, degrees: int) -> rgba: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - self.r, self.g, self.b, self.a = self.to_hsla().rotate(degrees).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def invert(self, invert_alpha: bool = False) -> "rgba": + def invert(self, invert_alpha: bool = False) -> rgba: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - self.r, self.g, self.b = 255 - self.r, 255 - self.g, 255 - self.b if invert_alpha and self.a is not None: self.a = 1 - self.a - return rgba(self.r, self.g, self.b, self.a, _validate=False) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "rgba": + def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> rgba: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -183,7 +176,7 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag self.r = self.g = self.b = int(Color.luminance(self.r, self.g, self.b, method=method)) return rgba(self.r, self.g, self.b, self.a, _validate=False) - def blend(self, other: Rgba, ratio: float = 0.5, additive_alpha: bool = False) -> "rgba": + def blend(self, other: Rgba, ratio: float = 0.5, additive_alpha: bool = False) -> rgba: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other RGBA color to blend with @@ -192,27 +185,20 @@ def blend(self, other: Rgba, ratio: float = 0.5, additive_alpha: bool = False) - * if `ratio` is `0.5` it means 50% of both colors (1:1 mixture) * if `ratio` is `1.0` it means 0% of the current color and 100% of the `other` color (0:2 mixture) - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" - if not isinstance(other, rgba): - if Color.is_valid_rgba(other): - other = Color.to_rgba(other) - else: - raise TypeError(f"The 'other' parameter must be a valid RGBA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") + + other_rgba = Color.to_rgba(other) ratio *= 2 - self.r = int(max(0, min(255, int(round((self.r * (2 - ratio)) + (other.r * ratio)))))) - self.g = int(max(0, min(255, int(round((self.g * (2 - ratio)) + (other.g * ratio)))))) - self.b = int(max(0, min(255, int(round((self.b * (2 - ratio)) + (other.b * ratio)))))) - none_alpha = self.a is None and (len(other) <= 3 or other[3] is None) + self.r = int(max(0, min(255, int((self.r * (2 - ratio)) + (other_rgba.r * ratio) + 0.5)))) + self.g = int(max(0, min(255, int((self.g * (2 - ratio)) + (other_rgba.g * ratio) + 0.5)))) + self.b = int(max(0, min(255, int((self.b * (2 - ratio)) + (other_rgba.b * ratio) + 0.5)))) + none_alpha = self.a is None and (len(other_rgba) <= 3 or other_rgba[3] is None) if not none_alpha: - self_a = 1 if self.a is None else self.a - other_a = (other[3] if other[3] is not None else 1) if len(other) > 3 else 1 + self_a: float = 1.0 if self.a is None else self.a + other_a: float = cast(float, 1.0 if other_rgba[3] is None else other_rgba[3]) if len(other_rgba) > 3 else 1.0 if additive_alpha: self.a = max(0, min(1, (self_a * (2 - ratio)) + (other_a * ratio))) @@ -240,21 +226,19 @@ def is_opaque(self) -> bool: """Returns `True` if the color has no transparency.""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "rgba": + def with_alpha(self, alpha: float) -> rgba: """Returns a new color with the specified alpha value.""" - if not isinstance(alpha, float): - raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") - elif not (0.0 <= alpha <= 1.0): + if not (0.0 <= alpha <= 1.0): raise ValueError(f"The 'alpha' parameter must be in range [0.0, 1.0] inclusive, got {alpha!r}") return rgba(self.r, self.g, self.b, alpha, _validate=False) - def complementary(self) -> "rgba": + def complementary(self) -> rgba: """Returns the complementary color (180 degrees on the color wheel).""" return self.to_hsla().complementary().to_rgba() @staticmethod - def _rgb_to_hsl(r: int, g: int, b: int) -> tuple: + def _rgb_to_hsl(r: int, g: int, b: int) -> tuple[int, int, int]: """Internal method to convert RGB to HSL color space.""" _r, _g, _b = r / 255.0, g / 255.0, b / 255.0 max_c, min_c = max(_r, _g, _b), min(_r, _g, _b) @@ -333,10 +317,18 @@ def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[int | Optional[float]]: return iter((self.h, self.s, self.l) + (() if self.a is None else (self.a, ))) - def __getitem__(self, index: int) -> int | float: + @overload + def __getitem__(self, index: Literal[0, 1, 2]) -> int: + ... + + @overload + def __getitem__(self, index: Literal[3]) -> Optional[float]: + ... + + def __getitem__(self, index: int) -> int | Optional[float]: return ((self.h, self.s, self.l) + (() if self.a is None else (self.a, )))[index] def __eq__(self, other: object) -> bool: @@ -355,20 +347,20 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def dict(self) -> dict: + def dict(self) -> HslaDict: """Returns the color components as a dictionary with keys `"h"`, `"s"`, `"l"` and optionally `"a"`.""" - return dict(h=self.h, s=self.s, l=self.l) if self.a is None else dict(h=self.h, s=self.s, l=self.l, a=self.a) + return {"h": self.h, "s": self.s, "l": self.l, "a": self.a} - def values(self) -> tuple: + def values(self) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `h, s, l, a`.""" return self.h, self.s, self.l, self.a - def to_rgba(self) -> "rgba": + def to_rgba(self) -> rgba: """Returns the color as `rgba()` color object.""" r, g, b = self._hsl_to_rgb(self.h, self.s, self.l) return rgba(r, g, b, self.a, _validate=False) - def to_hexa(self) -> "hexa": + def to_hexa(self) -> hexa: """Returns the color as `hexa()` color object.""" r, g, b = self._hsl_to_rgb(self.h, self.s, self.l) return hexa("", r, g, b, self.a) @@ -377,59 +369,45 @@ def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "hsla": + def lighten(self, amount: float) -> hsla: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.l = int(min(100, self.l + (100 - self.l) * amount)) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def darken(self, amount: float) -> "hsla": + def darken(self, amount: float) -> hsla: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.l = int(max(0, self.l * (1 - amount))) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def saturate(self, amount: float) -> "hsla": + def saturate(self, amount: float) -> hsla: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.s = int(min(100, self.s + (100 - self.s) * amount)) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def desaturate(self, amount: float) -> "hsla": + def desaturate(self, amount: float) -> hsla: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.s = int(max(0, self.s * (1 - amount))) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def rotate(self, degrees: int) -> "hsla": + def rotate(self, degrees: int) -> hsla: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - self.h = (self.h + degrees) % 360 return hsla(self.h, self.s, self.l, self.a, _validate=False) - def invert(self, invert_alpha: bool = False) -> "hsla": + def invert(self, invert_alpha: bool = False) -> hsla: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - self.h = (self.h + 180) % 360 self.l = 100 - self.l if invert_alpha and self.a is not None: @@ -437,7 +415,7 @@ def invert(self, invert_alpha: bool = False) -> "hsla": return hsla(self.h, self.s, self.l, self.a, _validate=False) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "hsla": + def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> hsla: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -451,7 +429,7 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag self.h, self.s, self.l, _ = rgba(l, l, l, _validate=False).to_hsla().values() return hsla(self.h, self.s, self.l, self.a, _validate=False) - def blend(self, other: Hsla, ratio: float = 0.5, additive_alpha: bool = False) -> "hsla": + def blend(self, other: Hsla, ratio: float = 0.5, additive_alpha: bool = False) -> hsla: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other HSLA color to blend with @@ -462,12 +440,8 @@ def blend(self, other: Hsla, ratio: float = 0.5, additive_alpha: bool = False) - - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" if not Color.is_valid_hsla(other): raise TypeError(f"The 'other' parameter must be a valid HSLA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") self.h, self.s, self.l, self.a = self.to_rgba().blend(Color.to_rgba(other), ratio, additive_alpha).to_hsla().values() return hsla(self.h, self.s, self.l, self.a, _validate=False) @@ -488,7 +462,7 @@ def is_opaque(self) -> bool: """Returns `True` if the color has no transparency.""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "hsla": + def with_alpha(self, alpha: float) -> hsla: """Returns a new color with the specified alpha value.""" if not isinstance(alpha, float): raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") @@ -497,12 +471,12 @@ def with_alpha(self, alpha: float) -> "hsla": return hsla(self.h, self.s, self.l, alpha, _validate=False) - def complementary(self) -> "hsla": + def complementary(self) -> hsla: """Returns the complementary color (180 degrees on the color wheel).""" return hsla((self.h + 180) % 360, self.s, self.l, self.a, _validate=False) @classmethod - def _hsl_to_rgb(cls, h: int, s: int, l: int) -> tuple: + def _hsl_to_rgb(cls, h: int, s: int, l: int) -> tuple[int, int, int]: """Internal method to convert HSL to RGB color space.""" _h, _s, _l = h / 360, s / 100, l / 100 @@ -621,20 +595,18 @@ def __init__( else: raise ValueError(f"Invalid HEXA color string '{color}'. Must be in formats RGB, RGBA, RRGGBB or RRGGBBAA.") - elif isinstance(color, int): - self.r, self.g, self.b, self.a = Color.hex_int_to_rgba(color).values() else: - raise TypeError(f"The 'color' parameter must be a string or integer, got {type(color)}") + self.r, self.g, self.b, self.a = Color.hex_int_to_rgba(color).values() def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[str]: return iter((f"{self.r:02X}", f"{self.g:02X}", f"{self.b:02X}") + (() if self.a is None else (f"{int(self.a * 255):02X}", ))) - def __getitem__(self, index: int) -> str | int: + def __getitem__(self, index: int) -> str: return ((f"{self.r:02X}", f"{self.g:02X}", f"{self.b:02X}") \ + (() if self.a is None else (f"{int(self.a * 255):02X}", )))[index] @@ -654,29 +626,16 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"#{self.r:02X}{self.g:02X}{self.b:02X}{'' if self.a is None else f'{int(self.a * 255):02X}'}" - def dict(self) -> dict: + def dict(self) -> HexaDict: """Returns the color components as a dictionary with hex string values for keys `"r"`, `"g"`, `"b"` and optionally `"a"`.""" - return ( - dict(r=f"{self.r:02X}", g=f"{self.g:02X}", b=f"{self.b:02X}") if self.a is None else dict( - r=f"{self.r:02X}", - g=f"{self.g:02X}", - b=f"{self.b:02X}", - a=f"{int(self.a * 255):02X}", - ) - ) + return {"r": f"{self.r:02X}", "g": f"{self.g:02X}", "b": f"{self.b:02X}", "a": None if self.a is None else f"{int(self.a * 255):02X}"} - def values(self, round_alpha: bool = True) -> tuple: + def values(self, round_alpha: bool = True) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `r, g, b, a`.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - return self.r, self.g, self.b, None if self.a is None else (round(self.a, 2) if round_alpha else self.a) - def to_rgba(self, round_alpha: bool = True) -> "rgba": + def to_rgba(self, round_alpha: bool = True) -> rgba: """Returns the color as `rgba()` color object.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - return rgba( self.r, self.g, @@ -685,77 +644,60 @@ def to_rgba(self, round_alpha: bool = True) -> "rgba": _validate=False, ) - def to_hsla(self, round_alpha: bool = True) -> "hsla": + def to_hsla(self, round_alpha: bool = True) -> hsla: """Returns the color as `hsla()` color object.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - return self.to_rgba(round_alpha).to_hsla() def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "hexa": + def lighten(self, amount: float) -> hexa: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_rgba(False).lighten(amount).values() return hexa("", self.r, self.g, self.b, self.a) - def darken(self, amount: float) -> "hexa": + def darken(self, amount: float) -> hexa: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_rgba(False).darken(amount).values() return hexa("", self.r, self.g, self.b, self.a) - def saturate(self, amount: float) -> "hexa": + def saturate(self, amount: float) -> hexa: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_rgba(False).saturate(amount).values() return hexa("", self.r, self.g, self.b, self.a) - def desaturate(self, amount: float) -> "hexa": + def desaturate(self, amount: float) -> hexa: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_rgba(False).desaturate(amount).values() return hexa("", self.r, self.g, self.b, self.a) - def rotate(self, degrees: int) -> "hexa": + def rotate(self, degrees: int) -> hexa: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - self.r, self.g, self.b, self.a = self.to_rgba(False).rotate(degrees).values() return hexa("", self.r, self.g, self.b, self.a) - def invert(self, invert_alpha: bool = False) -> "hexa": + def invert(self, invert_alpha: bool = False) -> hexa: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - self.r, self.g, self.b, self.a = self.to_rgba(False).invert().values() if invert_alpha and self.a is not None: self.a = 1 - self.a return hexa("", self.r, self.g, self.b, self.a) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "hexa": + def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> hexa: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -767,7 +709,7 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag self.r = self.g = self.b = int(Color.luminance(self.r, self.g, self.b, method=method)) return hexa("", self.r, self.g, self.b, self.a) - def blend(self, other: Hexa, ratio: float = 0.5, additive_alpha: bool = False) -> "hexa": + def blend(self, other: Hexa, ratio: float = 0.5, additive_alpha: bool = False) -> hexa: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other HEXA color to blend with @@ -778,12 +720,8 @@ def blend(self, other: Hexa, ratio: float = 0.5, additive_alpha: bool = False) - - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" if not Color.is_valid_hexa(other): raise TypeError(f"The 'other' parameter must be a valid HEXA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") self.r, self.g, self.b, self.a = self.to_rgba(False).blend(Color.to_rgba(other), ratio, additive_alpha).values() return hexa("", self.r, self.g, self.b, self.a) @@ -804,7 +742,7 @@ def is_opaque(self) -> bool: """Returns `True` if the color has no transparency (`alpha == 1.0`).""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "hexa": + def with_alpha(self, alpha: float) -> hexa: """Returns a new color with the specified alpha value.""" if not isinstance(alpha, float): raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") @@ -813,7 +751,7 @@ def with_alpha(self, alpha: float) -> "hexa": return hexa("", self.r, self.g, self.b, alpha) - def complementary(self) -> "hexa": + def complementary(self) -> hexa: """Returns the complementary color (180 degrees on the color wheel).""" return self.to_hsla(False).complementary().to_hexa() @@ -827,32 +765,41 @@ def is_valid_rgba(cls, color: AnyRgba, allow_alpha: bool = True) -> bool: ----------------------------------------------------------------- - `color` -⠀the color to check (can be in any supported format) - `allow_alpha` -⠀whether to allow alpha channel in the color""" - if not isinstance(allow_alpha, bool): - raise TypeError(f"The 'new_tab_size' parameter must be an boolean, got {type(allow_alpha)}") - try: if isinstance(color, rgba): return True elif isinstance(color, (list, tuple)): - if allow_alpha and cls.has_alpha(color): + array_color = cast(list[Any] | tuple[Any, ...], color) + + if (allow_alpha \ + and len(array_color) == 4 + and all(isinstance(val, int) for val in array_color[:3]) + and isinstance(array_color[3], (float, type(None))) + ): return ( - 0 <= color[0] <= 255 and 0 <= color[1] <= 255 and 0 <= color[2] <= 255 - and (0 <= color[3] <= 1 or color[3] is None) + 0 <= array_color[0] <= 255 and 0 <= array_color[1] <= 255 and 0 <= array_color[2] <= 255 + and (array_color[3] is None or 0 <= array_color[3] <= 1) ) - elif len(color) == 3: - return 0 <= color[0] <= 255 and 0 <= color[1] <= 255 and 0 <= color[2] <= 255 + elif len(array_color) == 3 and all(isinstance(val, int) for val in array_color): + return 0 <= array_color[0] <= 255 and 0 <= array_color[1] <= 255 and 0 <= array_color[2] <= 255 else: return False elif isinstance(color, dict): - if allow_alpha and cls.has_alpha(color): + dict_color = cast(dict[str, Any], color) + + if (allow_alpha \ + and len(dict_color) == 4 + and all(isinstance(dict_color.get(ch), int) for ch in ("r", "g", "b")) + and isinstance(dict_color.get("a", "no alpha"), (float, type(None))) + ): return ( - 0 <= color["r"] <= 255 and 0 <= color["g"] <= 255 and 0 <= color["b"] <= 255 - and (0 <= color["a"] <= 1 or color["a"] is None) + 0 <= dict_color["r"] <= 255 and 0 <= dict_color["g"] <= 255 and 0 <= dict_color["b"] <= 255 + and (dict_color["a"] is None or 0 <= dict_color["a"] <= 1) ) - elif len(color) == 3: - return 0 <= color["r"] <= 255 and 0 <= color["g"] <= 255 and 0 <= color["b"] <= 255 + elif len(dict_color) == 3 and all(isinstance(dict_color.get(ch), int) for ch in ("r", "g", "b")): + return 0 <= dict_color["r"] <= 255 and 0 <= dict_color["g"] <= 255 and 0 <= dict_color["b"] <= 255 else: return False @@ -874,24 +821,36 @@ def is_valid_hsla(cls, color: AnyHsla, allow_alpha: bool = True) -> bool: return True elif isinstance(color, (list, tuple)): - if allow_alpha and cls.has_alpha(color): + array_color = cast(list[Any] | tuple[Any, ...], color) + + if (allow_alpha \ + and len(array_color) == 4 + and all(isinstance(val, int) for val in array_color[:3]) + and isinstance(array_color[3], (float, type(None))) + ): return ( - 0 <= color[0] <= 360 and 0 <= color[1] <= 100 and 0 <= color[2] <= 100 - and (0 <= color[3] <= 1 or color[3] is None) + 0 <= array_color[0] <= 360 and 0 <= array_color[1] <= 100 and 0 <= array_color[2] <= 100 + and (array_color[3] is None or 0 <= array_color[3] <= 1) ) - elif len(color) == 3: - return 0 <= color[0] <= 360 and 0 <= color[1] <= 100 and 0 <= color[2] <= 100 + elif len(array_color) == 3 and all(isinstance(val, int) for val in array_color): + return 0 <= array_color[0] <= 360 and 0 <= array_color[1] <= 100 and 0 <= array_color[2] <= 100 else: return False elif isinstance(color, dict): - if allow_alpha and cls.has_alpha(color): + dict_color = cast(dict[str, Any], color) + + if (allow_alpha \ + and len(dict_color) == 4 + and all(isinstance(dict_color.get(ch), int) for ch in ("h", "s", "l")) + and isinstance(dict_color.get("a", "no alpha"), (float, type(None))) + ): return ( - 0 <= color["h"] <= 360 and 0 <= color["s"] <= 100 and 0 <= color["l"] <= 100 - and (0 <= color["a"] <= 1 or color["a"] is None) + 0 <= dict_color["h"] <= 360 and 0 <= dict_color["s"] <= 100 and 0 <= dict_color["l"] <= 100 + and (dict_color["a"] is None or 0 <= dict_color["a"] <= 1) ) - elif len(color) == 3: - return 0 <= color["h"] <= 360 and 0 <= color["s"] <= 100 and 0 <= color["l"] <= 100 + elif len(dict_color) == 3 and all(isinstance(dict_color.get(ch), int) for ch in ("h", "s", "l")): + return 0 <= dict_color["h"] <= 360 and 0 <= dict_color["s"] <= 100 and 0 <= dict_color["l"] <= 100 else: return False @@ -972,9 +931,9 @@ def has_alpha(cls, color: Rgba | Hsla | Hexa) -> bool: if parsed_hsla := cls.str_to_hsla(color, only_first=True): return cast(hsla, parsed_hsla).has_alpha() - elif isinstance(color, (list, tuple)) and len(color) == 4 and color[3] is not None: + elif isinstance(color, (list, tuple)) and len(color) == 4: return True - elif isinstance(color, dict) and len(color) == 4 and color["a"] is not None: + elif isinstance(color, dict) and len(color) == 4: return True return False @@ -987,11 +946,11 @@ def to_rgba(cls, color: Rgba | Hsla | Hexa) -> rgba: if isinstance(color, (hsla, hexa)): return color.to_rgba() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color).to_rgba() + return cls._parse_hsla(cast(Hsla, color)).to_rgba() elif cls.is_valid_hexa(color): return hexa(cast(str | int, color)).to_rgba() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color) + return cls._parse_rgba(cast(Rgba, color)) raise ValueError(f"Could not convert color {color!r} to RGBA.") @classmethod @@ -1002,11 +961,11 @@ def to_hsla(cls, color: Rgba | Hsla | Hexa) -> hsla: if isinstance(color, (rgba, hexa)): return color.to_hsla() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color).to_hsla() + return cls._parse_rgba(cast(Rgba, color)).to_hsla() elif cls.is_valid_hexa(color): return hexa(cast(str | int, color)).to_hsla() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color) + return cls._parse_hsla(cast(Hsla, color)) raise ValueError(f"Could not convert color {color!r} to HSLA.") @classmethod @@ -1017,9 +976,9 @@ def to_hexa(cls, color: Rgba | Hsla | Hexa) -> hexa: if isinstance(color, (rgba, hsla)): return color.to_hexa() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color).to_hexa() + return cls._parse_rgba(cast(Rgba, color)).to_hexa() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color).to_hexa() + return cls._parse_hsla(cast(Hsla, color)).to_hexa() elif cls.is_valid_hexa(color): return color if isinstance(color, hexa) else hexa(cast(str | int, color)) raise ValueError(f"Could not convert color {color!r} to HEXA") @@ -1220,8 +1179,8 @@ def text_color_for_on_bg(cls, text_bg_color: Rgba | Hexa) -> rgba | hexa | int: - `text_bg_color` -⠀the background color (can be in RGBA or HEXA format)""" was_hexa, was_int = cls.is_valid_hexa(text_bg_color), isinstance(text_bg_color, int) - text_bg_color = cls.to_rgba(text_bg_color) - brightness = 0.2126 * text_bg_color[0] + 0.7152 * text_bg_color[1] + 0.0722 * text_bg_color[2] + text_bg_rgba = cls.to_rgba(text_bg_color) + brightness = 0.2126 * text_bg_rgba[0] + 0.7152 * text_bg_rgba[1] + 0.0722 * text_bg_rgba[2] return ( (0xFFFFFF if was_int else hexa("", 255, 255, 255)) if was_hexa \ @@ -1238,18 +1197,17 @@ def adjust_lightness(cls, color: Rgba | Hexa, lightness_change: float) -> rgba | - `color` -⠀the color to adjust (can be in RGBA or HEXA format) - `lightness_change` -⠀the amount to change the lightness by, in range `-1.0` (darken by 100%) and `1.0` (lighten by 100%)""" - was_hexa = cls.is_valid_hexa(color) - if not (-1.0 <= lightness_change <= 1.0): raise ValueError( f"The 'lightness_change' parameter must be in range [-1.0, 1.0] inclusive, got {lightness_change!r}" ) - hsla_color: hsla = cls.to_hsla(color) + was_hexa = cls.is_valid_hexa(color) + hsla_color = cls.to_hsla(color) h, s, l, a = ( int(hsla_color[0]), int(hsla_color[1]), int(hsla_color[2]), \ - hsla_color[3] if cls.has_alpha(hsla_color) else None + hsla_color[3] if hsla_color.has_alpha() else None ) l = int(max(0, min(100, l + lightness_change * 100))) @@ -1265,18 +1223,17 @@ def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float) -> rgba - `color` -⠀the color to adjust (can be in RGBA or HEXA format) - `saturation_change` -⠀the amount to change the saturation by, in range `-1.0` (saturate by 100%) and `1.0` (desaturate by 100%)""" - was_hexa = cls.is_valid_hexa(color) - if not (-1.0 <= saturation_change <= 1.0): raise ValueError( f"The 'saturation_change' parameter must be in range [-1.0, 1.0] inclusive, got {saturation_change!r}" ) - hsla_color: hsla = cls.to_hsla(color) + was_hexa = cls.is_valid_hexa(color) + hsla_color = cls.to_hsla(color) h, s, l, a = ( int(hsla_color[0]), int(hsla_color[1]), int(hsla_color[2]), \ - hsla_color[3] if cls.has_alpha(hsla_color) else None + hsla_color[3] if hsla_color.has_alpha() else None ) s = int(max(0, min(100, s + saturation_change * 100))) @@ -1286,38 +1243,50 @@ def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float) -> rgba ) @classmethod - def _parse_rgba(cls, color: AnyRgba) -> rgba: + def _parse_rgba(cls, color: Rgba) -> rgba: """Internal method to parse a color to an RGBA object.""" if isinstance(color, rgba): return color + elif isinstance(color, (list, tuple)): - if len(color) == 4: - return rgba(color[0], color[1], color[2], color[3], _validate=False) - elif len(color) == 3: - return rgba(color[0], color[1], color[2], None, _validate=False) + array_color = cast(list[Any] | tuple[Any, ...], color) + if len(array_color) == 4: + return rgba(int(array_color[0]), int(array_color[1]), int(array_color[2]), float(array_color[3]), _validate=False) + elif len(array_color) == 3: + return rgba(int(array_color[0]), int(array_color[1]), int(array_color[2]), None, _validate=False) + raise ValueError(f"Could not parse RGBA color: {color!r}") + elif isinstance(color, dict): - return rgba(color["r"], color["g"], color["b"], color.get("a"), _validate=False) - elif isinstance(color, str): + dict_color = cast(dict[str, Any], color) + return rgba(int(dict_color["r"]), int(dict_color["g"]), int(dict_color["b"]), dict_color.get("a"), _validate=False) + + else: if parsed := cls.str_to_rgba(color, only_first=True): return cast(rgba, parsed) - raise ValueError(f"Could not parse RGBA color: {color!r}") + raise ValueError(f"Could not parse RGBA color: {color!r}") @classmethod - def _parse_hsla(cls, color: AnyHsla) -> hsla: + def _parse_hsla(cls, color: Hsla) -> hsla: """Internal method to parse a color to an HSLA object.""" if isinstance(color, hsla): return color + elif isinstance(color, (list, tuple)): + array_color = cast(list[Any] | tuple[Any, ...], color) if len(color) == 4: - return hsla(color[0], color[1], color[2], color[3], _validate=False) + return hsla(int(array_color[0]), int(array_color[1]), int(array_color[2]), float(array_color[3]), _validate=False) elif len(color) == 3: - return hsla(color[0], color[1], color[2], None, _validate=False) + return hsla(int(array_color[0]), int(array_color[1]), int(array_color[2]), None, _validate=False) + raise ValueError(f"Could not parse HSLA color: {color!r}") + elif isinstance(color, dict): - return hsla(color["h"], color["s"], color["l"], color.get("a"), _validate=False) - elif isinstance(color, str): + dict_color = cast(dict[str, Any], color) + return hsla(int(dict_color["h"]), int(dict_color["s"]), int(dict_color["l"]), dict_color.get("a"), _validate=False) + + else: if parsed := cls.str_to_hsla(color, only_first=True): return cast(hsla, parsed) - raise ValueError(f"Could not parse HSLA color: {color!r}") + raise ValueError(f"Could not parse HSLA color: {color!r}") @staticmethod def _linearize_srgb(c: float) -> float: diff --git a/src/xulbux/console.py b/src/xulbux/console.py index 44f898c..a7c324d 100644 --- a/src/xulbux/console.py +++ b/src/xulbux/console.py @@ -1,5 +1,5 @@ """ -This module provides the `Console`, `ProgressBar`, and `Spinner` classes +This module provides the `Console`, `ProgressBar`, and `Throbber` classes which offer methods for logging and other actions within the console. """ @@ -7,7 +7,7 @@ from .base.decorators import mypyc_attr from .base.consts import COLOR, CHARS, ANSI -from .format_codes import _PATTERNS as _FC_PATTERNS, FormatCodes +from .format_codes import _PATTERNS as _FC_PATTERNS, FormatCodes # type: ignore[private-access] from .string import String from .color import Color, hexa from .regex import LazyRegex @@ -15,6 +15,7 @@ from typing import Generator, Callable, Optional, Literal, TypeVar, TextIO, Any, overload, cast from prompt_toolkit.key_binding import KeyPressEvent, KeyBindings from prompt_toolkit.validation import ValidationError, Validator +from prompt_toolkit.document import Document from prompt_toolkit.styles import Style from prompt_toolkit.keys import Keys from contextlib import contextmanager @@ -115,7 +116,7 @@ def __len__(self): """The number of arguments stored in the `ParsedArgs` object.""" return len(vars(self)) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: """Checks if an argument with the given alias exists in the `ParsedArgs` object.""" return key in vars(self) @@ -126,9 +127,9 @@ def __bool__(self) -> bool: def __getattr__(self, name: str) -> ParsedArgData: raise AttributeError(f"'{type(self).__name__}' object has no attribute {name}") - def __getitem__(self, key): + def __getitem__(self, key: str | int) -> ParsedArgData: if isinstance(key, int): - return list(self.__iter__())[key] + return list(self.values())[key] return getattr(self, key) def __iter__(self) -> Generator[tuple[str, ParsedArgData], None, None]: @@ -387,7 +388,7 @@ def log( information about formatting codes, see `format_codes` module documentation.""" has_title_bg: bool = False if title_bg_color is not None and (Color.is_valid_rgba(title_bg_color) or Color.is_valid_hexa(title_bg_color)): - title_bg_color, has_title_bg = Color.to_hexa(cast(Rgba | Hexa, title_bg_color)), True + title_bg_color, has_title_bg = Color.to_hexa(title_bg_color), True if tab_size < 0: raise ValueError("The 'tab_size' parameter must be a non-negative integer.") if title_px < 0: @@ -412,11 +413,9 @@ def log( String.split_count(line, cls.w - (title_len + len(tab) + 2 * len(mx))) \ for line in str(clean_prompt).splitlines() ) - for item in ([""] if lst == [] else (lst if isinstance(lst, list) else [lst])) + for item in ([""] if lst == [] else lst) ] - prompt = f"\n{mx}{' ' * title_len}{mx}{tab}".join( - cls._add_back_removed_parts(prompt_lst, cast(tuple[tuple[int, str], ...], removals)) - ) + prompt = f"\n{mx}{' ' * title_len}{mx}{tab}".join(cls._add_back_removed_parts(prompt_lst, removals)) if title == "": FormatCodes.print( @@ -709,7 +708,7 @@ def log_box_bordered( if not all(len(char) == 1 for char in _border_chars): raise ValueError("The '_border_chars' parameter must only contain single-character strings.") - if border_style is not None and Color.is_valid(border_style): + if Color.is_valid(border_style): border_style = Color.to_hexa(border_style) borders = { @@ -915,7 +914,7 @@ def input( kb.add(Keys.Any)(helper.handle_any) custom_style = Style.from_dict({"bottom-toolbar": "noreverse"}) - session: _pt.PromptSession = _pt.PromptSession( + session: _pt.PromptSession[str] = _pt.PromptSession( message=_pt.formatted_text.ANSI(FormatCodes.to_ansi(str(prompt), default_color=default_color)), validator=_ConsoleInputValidator( get_text=helper.get_text, @@ -954,7 +953,7 @@ def input( def _add_back_removed_parts(cls, split_string: list[str], removals: tuple[tuple[int, str], ...]) -> list[str]: """Adds back the removed parts into the split string parts at their original positions.""" cumulative_pos = [0] - for length in (len(s) for s in split_string): + for length in (len(part) for part in split_string): cumulative_pos.append(cumulative_pos[-1] + length) result, offset_adjusts = split_string.copy(), [0] * len(split_string) @@ -995,9 +994,12 @@ def _prepare_log_box( ) -> tuple[list[str], list[str], int]: """Prepares the log box content and returns it along with the max line length.""" if has_rules: - lines = [] + lines: list[str] = [] + for val in values: - val_str, result_parts, current_pos = str(val), [], 0 + result_parts: list[str] = [] + val_str, current_pos = str(val), 0 + for match in _PATTERNS.hr.finditer(val_str): start, end = match.span() should_split_before = start > 0 and val_str[start - 1] != "\n" @@ -1110,25 +1112,19 @@ def _parse_arg_config(self, alias: str, config: ArgParseConfig) -> Optional[set[ return config # SET OF FLAGS WITH SPECIFIED DEFAULT VALUE - elif isinstance(config, dict): - if not config.get("flags"): + else: + if not config["flags"]: raise ValueError( f"No flags provided under alias '{alias}'.\n" "The 'flags'-key set must contain at least one flag to search for." ) self.parsed_args[alias] = ParsedArgData( exists=False, - values=[default] if (default := config.get("default")) is not None else [], + values=[config["default"]], is_pos=False, ) return config["flags"] - else: - raise TypeError( - f"Invalid configuration type under alias '{alias}'.\n" - "Must be a set, dict, literal 'before' or literal 'after'." - ) - def find_flag_positions(self) -> None: """Find positions of first and last flags for positional argument collection.""" i = 0 @@ -1273,7 +1269,7 @@ def __init__(self, box_bg_color: str | Rgba | Hexa) -> None: self.box_bg_color = box_bg_color def __call__(self, m: _rx.Match[str]) -> str: - return f"{cast(str, m.group(0))}[bg:{self.box_bg_color}]" + return f"{m.group(0)}[bg:{self.box_bg_color}]" class _ConsoleInputHelper: @@ -1443,7 +1439,7 @@ def __init__( self.min_len = min_len self.validator = validator - def validate(self, document) -> None: + def validate(self, document: Document) -> None: text_to_validate = self.get_text() if self.mask_char else document.text if self.min_len and len(text_to_validate) < self.min_len: raise ValidationError(message="", cursor_position=len(document.text)) @@ -1546,13 +1542,13 @@ def set_bar_format( The bar format (also limited) can additionally be formatted with special formatting codes. For more detailed information about formatting codes, see the `format_codes` module documentation.""" if bar_format is not None: - if not any(_PATTERNS.bar.search(s) for s in bar_format): + if not any(_PATTERNS.bar.search(part) for part in bar_format): raise ValueError("The 'bar_format' parameter value must contain the '{bar}' or '{b}' placeholder.") self.bar_format = bar_format if limited_bar_format is not None: - if not any(_PATTERNS.bar.search(s) for s in limited_bar_format): + if not any(_PATTERNS.bar.search(part) for part in limited_bar_format): raise ValueError("The 'limited_bar_format' parameter value must contain the '{bar}' or '{b}' placeholder.") self.limited_bar_format = limited_bar_format @@ -1569,7 +1565,7 @@ def set_chars(self, chars: tuple[str, ...]) -> None: empty sections. If None, uses default Unicode block characters.""" if len(chars) < 2: raise ValueError("The 'chars' parameter must contain at least two characters (full and empty).") - elif not all(isinstance(c, str) and len(c) == 1 for c in chars): + elif not all(len(char) == 1 for char in chars): raise ValueError("All elements of 'chars' must be single-character strings.") self.chars = chars @@ -1676,10 +1672,10 @@ def _get_formatted_info_and_bar_width( percentage: float, label: Optional[str] = None, ) -> tuple[str, int]: - fmt_parts = [] + fmt_parts: list[str] = [] - for s in bar_format: - fmt_part = _PATTERNS.label.sub(label or "", s) + for part in bar_format: + fmt_part = _PATTERNS.label.sub(label or "", part) fmt_part = _PATTERNS.current.sub(_ProgressBarCurrentReplacer(current), fmt_part) fmt_part = _PATTERNS.total.sub(_ProgressBarTotalReplacer(total), fmt_part) fmt_part = _PATTERNS.percentage.sub(_ProgressBarPercentageReplacer(percentage), fmt_part) @@ -1696,7 +1692,7 @@ def _get_formatted_info_and_bar_width( def _create_bar(self, current: int, total: int, bar_width: int) -> str: progress = current / total if total > 0 else 0 - bar = [] + bar: list[str] = [] for i in range(bar_width): pos_progress = (i + 1) / bar_width @@ -1825,32 +1821,32 @@ def __call__(self, match: _rx.Match[str]) -> str: return f"{self.percentage:.{match.group(1) if match.group(1) else '1'}f}" -class Spinner: - """A console spinner for indeterminate processes with customizable appearance. +class Throbber: + """A console throbber for indeterminate processes with customizable appearance. This class intercepts stdout to allow printing while the animation is active.\n --------------------------------------------------------------------------------------------- - `label` -⠀the current label text - - `spinner_format` -⠀the format string used to render the spinner, containing placeholders: + - `throbber_format` -⠀the format string used to render the throbber, containing placeholders: * `{label}` `{l}` * `{animation}` `{a}` - `frames` -⠀a tuple of strings representing the animation frames - `interval` -⠀the time in seconds between each animation frame --------------------------------------------------------------------------------------------- - The `spinner_format` can additionally be formatted with special formatting codes. For more + The `throbber_format` can additionally be formatted with special formatting codes. For more detailed information about formatting codes, see the `format_codes` module documentation.""" def __init__( self, label: Optional[str] = None, - spinner_format: list[str] | tuple[str, ...] = ["{l}", "[b]({a}) "], + throbber_format: list[str] | tuple[str, ...] = ["{l}", "[b]({a}) "], sep: str = " ", frames: tuple[str, ...] = ("· ", "·· ", "···", " ··", " ·", " ·", " ··", "···", "·· ", "· "), interval: float = 0.2, ): - self.spinner_format: list[str] | tuple[str, ...] - """The format strings used to render the spinner (joined by `sep`).""" + self.throbber_format: list[str] | tuple[str, ...] + """The format strings used to render the throbber (joined by `sep`).""" self.sep: str - """The separator string used to join multiple spinner-format strings.""" + """The separator string used to join multiple throbber-format strings.""" self.frames: tuple[str, ...] """A tuple of strings representing the animation frames.""" self.interval: float @@ -1858,10 +1854,10 @@ def __init__( self.label: Optional[str] """The current label text.""" self.active: bool = False - """Whether the spinner is currently active (intercepting stdout) or not.""" + """Whether the throbber is currently active (intercepting stdout) or not.""" self.update_label(label) - self.set_format(spinner_format, sep) + self.set_format(throbber_format, sep) self.set_frames(frames) self.set_interval(interval) @@ -1873,23 +1869,23 @@ def __init__( self._stop_event: Optional[_threading.Event] = None self._animation_thread: Optional[_threading.Thread] = None - def set_format(self, spinner_format: list[str] | tuple[str, ...], sep: Optional[str] = None) -> None: - """Set the format string used to render the spinner.\n + def set_format(self, throbber_format: list[str] | tuple[str, ...], sep: Optional[str] = None) -> None: + """Set the format string used to render the throbber.\n --------------------------------------------------------------------------------------------- - - `spinner_format` -⠀the format strings used to render the spinner, containing placeholders: + - `throbber_format` -⠀the format strings used to render the throbber, containing placeholders: * `{label}` `{l}` * `{animation}` `{a}` - `sep` -⠀the separator string used to join multiple format strings""" - if not any(_PATTERNS.animation.search(fmt) for fmt in spinner_format): + if not any(_PATTERNS.animation.search(fmt) for fmt in throbber_format): raise ValueError( - "At least one format string in 'spinner_format' must contain the '{animation}' or '{a}' placeholder." + "At least one format string in 'throbber_format' must contain the '{animation}' or '{a}' placeholder." ) - self.spinner_format = spinner_format + self.throbber_format = throbber_format self.sep = sep or self.sep def set_frames(self, frames: tuple[str, ...]) -> None: - """Set the frames used for the spinner animation.\n + """Set the frames used for the throbber animation.\n --------------------------------------------------------------------- - `frames` -⠀a tuple of strings representing the animation frames""" if len(frames) < 2: @@ -1907,9 +1903,9 @@ def set_interval(self, interval: int | float) -> None: self.interval = interval def start(self, label: Optional[str] = None) -> None: - """Start the spinner animation and intercept stdout.\n + """Start the throbber animation and intercept stdout.\n ---------------------------------------------------------- - - `label` -⠀the label to display alongside the spinner""" + - `label` -⠀the label to display alongside the throbber""" if self.active: return @@ -1920,7 +1916,7 @@ def start(self, label: Optional[str] = None) -> None: self._animation_thread.start() def stop(self) -> None: - """Stop and hide the spinner and restore normal console output.""" + """Stop and hide the throbber and restore normal console output.""" if self.active: if self._stop_event: self._stop_event.set() @@ -1931,11 +1927,11 @@ def stop(self) -> None: self._animation_thread = None self._frame_index = 0 - self._clear_spinner_line() + self._clear_throbber_line() self._stop_intercepting() def update_label(self, label: Optional[str]) -> None: - """Update the spinner's label text.\n + """Update the throbber's label text.\n -------------------------------------- - `new_label` -⠀the new label text""" self.label = label @@ -1944,14 +1940,14 @@ def update_label(self, label: Optional[str]) -> None: def context(self, label: Optional[str] = None) -> Generator[Callable[[str], None], None, None]: """Context manager for automatic cleanup. Returns a function to update the label.\n ---------------------------------------------------------------------------------------------- - - `label` -⠀the label to display alongside the spinner + - `label` -⠀the label to display alongside the throbber ----------------------------------------------------------------------------------------------- The returned callable accepts a single parameter: - `new_label` -⠀the new label text\n #### Example usage: ```python - with Spinner().context("Starting...") as update_label: + with Throbber().context("Starting...") as update_label: time.sleep(2) update_label("Processing...") time.sleep(3) @@ -1979,10 +1975,8 @@ def _animation_loop(self) -> None: frame = FormatCodes.to_ansi(f"{self.frames[self._frame_index % len(self.frames)]}[*]") formatted = FormatCodes.to_ansi(self.sep.join( - s for s in ( \ - _PATTERNS.animation.sub(frame, _PATTERNS.label.sub(self.label or "", s)) - for s in self.spinner_format - ) if s + fmt_part for part in self.throbber_format if \ + (fmt_part := _PATTERNS.animation.sub(frame, _PATTERNS.label.sub(self.label or "", part))) )) self._current_animation_str = formatted @@ -2018,14 +2012,14 @@ def _emergency_cleanup(self) -> None: except Exception: pass - def _clear_spinner_line(self) -> None: + def _clear_throbber_line(self) -> None: if self._last_line_len > 0 and self._original_stdout: self._original_stdout.write(f"{ANSI.CHAR}[2K\r") self._original_stdout.flush() def _flush_buffer(self) -> None: if self._buffer and self._original_stdout: - self._clear_spinner_line() + self._clear_throbber_line() for content in self._buffer: self._original_stdout.write(content) self._original_stdout.flush() @@ -2041,28 +2035,28 @@ def _redraw_display(self) -> None: class _InterceptedOutput: """Custom StringIO that captures output and stores it in the progress bar buffer.""" - def __init__(self, progress_bar: ProgressBar | Spinner): - self.progress_bar = progress_bar + def __init__(self, status_indicator: ProgressBar | Throbber): + self.status_indicator = status_indicator self.string_io = StringIO() def write(self, content: str) -> int: self.string_io.write(content) try: if content and content != "\r": - self.progress_bar._buffer.append(content) + cast(ProgressBar | Throbber, self.status_indicator)._buffer.append(content) # type: ignore[protected-access] return len(content) except Exception: - self.progress_bar._emergency_cleanup() + self.status_indicator._emergency_cleanup() # type: ignore[protected-access] raise def flush(self) -> None: self.string_io.flush() try: - if self.progress_bar.active and self.progress_bar._buffer: - self.progress_bar._flush_buffer() - self.progress_bar._redraw_display() + if self.status_indicator.active and self.status_indicator._buffer: # type: ignore[protected-access] + self.status_indicator._flush_buffer() # type: ignore[protected-access] + self.status_indicator._redraw_display() # type: ignore[protected-access] except Exception: - self.progress_bar._emergency_cleanup() + self.status_indicator._emergency_cleanup() # type: ignore[protected-access] raise def __getattr__(self, name: str) -> Any: diff --git a/src/xulbux/data.py b/src/xulbux/data.py index 06d1c81..683129b 100644 --- a/src/xulbux/data.py +++ b/src/xulbux/data.py @@ -71,13 +71,20 @@ def chars_count(cls, data: DataStructure) -> int: chars_count = 0 if isinstance(data, dict): - for k, v in data.items(): - chars_count += len(str(k)) + (cls.chars_count(v) if isinstance(v, DataStructureTypes) else len(str(v))) - - elif isinstance(data, IndexIterableTypes): + for key, val in data.items(): + chars_count += len(str(key)) + ( + cls.chars_count(cast(DataStructure, val)) \ + if isinstance(val, DataStructureTypes) + else len(str(val)) + ) + else: for item in data: - chars_count += cls.chars_count(item) if isinstance(item, DataStructureTypes) else len(str(item)) - + chars_count += ( + cls.chars_count(cast(DataStructure, item)) \ + if isinstance(item, DataStructureTypes) + else len(str(item)) + ) + return chars_count @classmethod @@ -86,12 +93,18 @@ def strip(cls, data: DataStructure) -> DataStructure: ------------------------------------------------------------------------------- - `data` -⠀the data structure to strip the items from""" if isinstance(data, dict): - return {k.strip(): cls.strip(v) if isinstance(v, DataStructureTypes) else v.strip() for k, v in data.items()} - - if isinstance(data, IndexIterableTypes): - return type(data)(cls.strip(item) if isinstance(item, DataStructureTypes) else item.strip() for item in data) + return {key.strip(): ( + cls.strip(cast(DataStructure, val)) \ + if isinstance(val, DataStructureTypes) + else val.strip() + ) for key, val in data.items()} - raise TypeError(f"Unsupported data structure type: {type(data)}") + else: + return type(data)(( + cls.strip(cast(DataStructure, item)) \ + if isinstance(item, DataStructureTypes) + else item.strip() + ) for item in data) @classmethod def remove_empty_items(cls, data: DataStructure, spaces_are_empty: bool = False) -> DataStructure: @@ -101,34 +114,30 @@ def remove_empty_items(cls, data: DataStructure, spaces_are_empty: bool = False) - `spaces_are_empty` -⠀if true, it will count items with only spaces as empty""" if isinstance(data, dict): return { - k: (v if not isinstance(v, DataStructureTypes) else cls.remove_empty_items(v, spaces_are_empty)) - for k, v in data.items() if not String.is_empty(v, spaces_are_empty) + key: (val if not isinstance(val, DataStructureTypes) else cls.remove_empty_items(cast(DataStructure, val), spaces_are_empty)) + for key, val in data.items() if not String.is_empty(val, spaces_are_empty) } - if isinstance(data, IndexIterableTypes): + else: return type(data)( - item for item in - ( - (item if not isinstance(item, DataStructureTypes) else cls.remove_empty_items(item, spaces_are_empty)) \ + item for item in ( + (item if not isinstance(item, DataStructureTypes) else cls.remove_empty_items(cast(DataStructure, item), spaces_are_empty)) \ for item in data if not (isinstance(item, (str, type(None))) and String.is_empty(item, spaces_are_empty)) - ) - if item not in ([], (), {}, set(), frozenset()) + ) if item not in ([], (), {}, set(), frozenset()) ) - raise TypeError(f"Unsupported data structure type: {type(data)}") - @classmethod def remove_duplicates(cls, data: DataStructure) -> DataStructure: """Removes all duplicates from the data structure.\n ----------------------------------------------------------- - `data` -⠀the data structure to remove duplicates from""" if isinstance(data, dict): - return {k: cls.remove_duplicates(v) if isinstance(v, DataStructureTypes) else v for k, v in data.items()} + return {key: cls.remove_duplicates(cast(DataStructure, val)) if isinstance(val, DataStructureTypes) else val for key, val in data.items()} - if isinstance(data, (list, tuple)): + elif isinstance(data, (list, tuple)): result: list[Any] = [] for item in data: - processed_item = cls.remove_duplicates(item) if isinstance(item, DataStructureTypes) else item + processed_item = cls.remove_duplicates(cast(DataStructure, item)) if isinstance(item, DataStructureTypes) else item is_duplicate: bool = False for existing_item in result: @@ -141,15 +150,13 @@ def remove_duplicates(cls, data: DataStructure) -> DataStructure: return type(data)(result) - if isinstance(data, (set, frozenset)): - processed_elements = set() + else: + processed_elements: set[Any] = set() for item in data: - processed_item = cls.remove_duplicates(item) if isinstance(item, DataStructureTypes) else item + processed_item = cls.remove_duplicates(cast(DataStructure, item)) if isinstance(item, DataStructureTypes) else item processed_elements.add(processed_item) return type(data)(processed_elements) - raise TypeError(f"Unsupported data structure type: {type(data)}") - @classmethod def remove_comments( cls, @@ -309,19 +316,21 @@ def get_value_by_path_id(cls, data: DataStructure, path_id: str, get_key: bool = for i, path_idx in enumerate(path): if isinstance(current_data, dict): - keys = list(current_data.keys()) + dict_data = cast(dict[Any, Any], current_data) + keys: list[str] = list(dict_data.keys()) if i == len(path) - 1 and get_key: return keys[path_idx] - parent = current_data - current_data = current_data[keys[path_idx]] + parent = dict_data + current_data = dict_data[keys[path_idx]] elif isinstance(current_data, IndexIterableTypes): + idx_iterable_data = cast(IndexIterable, current_data) if i == len(path) - 1 and get_key: if parent is None or not isinstance(parent, dict): raise ValueError(f"Cannot get key from a non-dict parent at path '{path[:i + 1]}'") - return next(key for key, value in parent.items() if value is current_data) - parent = current_data - current_data = list(current_data)[path_idx] # CONVERT TO LIST FOR INDEXING + return next(key for key, value in parent.items() if value is idx_iterable_data) + parent = idx_iterable_data + current_data = list(idx_iterable_data)[path_idx] # CONVERT TO LIST FOR INDEXING else: raise TypeError(f"Unsupported type '{type(current_data)}' at path '{path[:i + 1]}'") @@ -474,24 +483,26 @@ def _compare_nested( return False if isinstance(data1, dict) and isinstance(data2, dict): - if set(data1.keys()) != set(data2.keys()): + dict_data1, dict_data2 = cast(dict[Any, Any], data1), cast(dict[Any, Any], data2) + if set(dict_data1.keys()) != set(dict_data2.keys()): return False return all(cls._compare_nested( \ - data1=data1[key], - data2=data2[key], + data1=dict_data1[key], + data2=dict_data2[key], ignore_paths=ignore_paths, current_path=current_path + [key], - ) for key in data1) + ) for key in dict_data1) - elif isinstance(data1, (list, tuple)): - if len(data1) != len(data2): + elif isinstance(data1, (list, tuple)) and isinstance(data2, (list, tuple)): + array_data1, array_data2 = cast(IndexIterable, data1), cast(IndexIterable, data2) + if len(array_data1) != len(array_data2): return False return all(cls._compare_nested( \ data1=item1, data2=item2, ignore_paths=ignore_paths, current_path=current_path + [str(i)], - ) for i, (item1, item2) in enumerate(zip(data1, data2))) + ) for i, (item1, item2) in enumerate(zip(array_data1, array_data2))) elif isinstance(data1, (set, frozenset)): return data1 == data2 @@ -519,23 +530,27 @@ def _set_nested_val(cls, data: DataStructure, id_path: list[int], value: Any) -> if len(id_path) == 1: if isinstance(current_data, dict): - keys, data_dict = list(current_data.keys()), dict(current_data) - data_dict[keys[id_path[0]]] = value - return data_dict + dict_data = cast(dict[Any, Any], current_data) + keys, dict_data = list(dict_data.keys()), dict(dict_data) + dict_data[keys[id_path[0]]] = value + return dict_data elif isinstance(current_data, IndexIterableTypes): - was_t, data_list = type(current_data), list(current_data) - data_list[id_path[0]] = value - return was_t(data_list) + idx_iterable_data = cast(IndexIterable, current_data) + was_t, idx_iterable_data = type(idx_iterable_data), list(idx_iterable_data) + idx_iterable_data[id_path[0]] = value + return was_t(idx_iterable_data) else: if isinstance(current_data, dict): - keys, data_dict = list(current_data.keys()), dict(current_data) - data_dict[keys[id_path[0]]] = cls._set_nested_val(data_dict[keys[id_path[0]]], id_path[1:], value) - return data_dict + dict_data = cast(dict[Any, Any], current_data) + keys, dict_data = list(dict_data.keys()), dict(dict_data) + dict_data[keys[id_path[0]]] = cls._set_nested_val(dict_data[keys[id_path[0]]], id_path[1:], value) + return dict_data elif isinstance(current_data, IndexIterableTypes): - was_t, data_list = type(current_data), list(current_data) - data_list[id_path[0]] = cls._set_nested_val(data_list[id_path[0]], id_path[1:], value) - return was_t(data_list) + idx_iterable_data = cast(IndexIterable, current_data) + was_t, idx_iterable_data = type(idx_iterable_data), list(idx_iterable_data) + idx_iterable_data[id_path[0]] = cls._set_nested_val(idx_iterable_data[id_path[0]], id_path[1:], value) + return was_t(idx_iterable_data) return current_data @@ -549,7 +564,7 @@ def __init__(self, data: DataStructure, comment_start: str, comment_end: str, co self.comment_end = comment_end self.comment_sep = comment_sep - self.pattern = _re.compile(Regex._clean( \ + self.pattern = _re.compile(Regex._clean( # type: ignore[protected-access] rf"""^( (?:(?!{_re.escape(comment_start)}).)* ) @@ -564,16 +579,18 @@ def __call__(self) -> DataStructure: def remove_nested_comments(self, item: Any) -> Any: if isinstance(item, dict): + dict_item = cast(dict[Any, Any], item) return { key: val for key, val in ( \ - (self.remove_nested_comments(k), self.remove_nested_comments(v)) for k, v in item.items() + (self.remove_nested_comments(k), self.remove_nested_comments(v)) for k, v in dict_item.items() ) if key is not None } if isinstance(item, IndexIterableTypes): - processed = (v for v in map(self.remove_nested_comments, item) if v is not None) - return type(item)(processed) + idx_iterable_item = cast(IndexIterable, item) + processed = (val for val in map(self.remove_nested_comments, idx_iterable_item) if val is not None) + return type(idx_iterable_item)(processed) if isinstance(item, str): if self.pattern: @@ -689,8 +706,8 @@ def __init__( raise TypeError(f"Expected 'syntax_highlighting' to be a dict or bool. Got: {type(syntax_highlighting)}") self.syntax_hl.update({ - k: (f"[{v}]", "[_]") if k in self.syntax_hl and v not in {"", None} else ("", "") - for k, v in syntax_highlighting.items() + key: (f"[{val}]", "[_]") if key in self.syntax_hl and val not in {"", None} else ("", "") + for key, val in syntax_highlighting.items() }) sep = f"{self.syntax_hl['punctuation'][0]}{sep}{self.syntax_hl['punctuation'][1]}" @@ -699,10 +716,10 @@ def __init__( punct_map: dict[str, str | tuple[str, str]] = {"(": ("/(", "("), **{c: c for c in "'\":)[]{}"}} self.punct: dict[str, str] = { - k: ((f"{self.syntax_hl['punctuation'][0]}{v[0]}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else v[1]) - if isinstance(v, (list, tuple)) else - (f"{self.syntax_hl['punctuation'][0]}{v}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else v)) - for k, v in punct_map.items() + key: ((f"{self.syntax_hl['punctuation'][0]}{val[0]}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else val[1]) + if isinstance(val, (list, tuple)) else + (f"{self.syntax_hl['punctuation'][0]}{val}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else val)) + for key, val in punct_map.items() } def __call__(self) -> str: @@ -713,19 +730,19 @@ def __call__(self) -> str: def format_value(self, value: Any, current_indent: Optional[int] = None) -> str: if current_indent is not None and isinstance(value, dict): - return self.format_dict(value, current_indent + self.indent) + return self.format_dict(cast(dict[Any, Any], value), current_indent + self.indent) elif current_indent is not None and hasattr(value, "__dict__"): return self.format_dict(value.__dict__, current_indent + self.indent) elif current_indent is not None and isinstance(value, IndexIterableTypes): - return self.format_sequence(value, current_indent + self.indent) + return self.format_sequence(cast(IndexIterable, value), current_indent + self.indent) elif current_indent is not None and isinstance(value, (bytes, bytearray)): obj_dict = self.cls.serialize_bytes(value) return ( self.format_dict(obj_dict, current_indent + self.indent) if self.as_json else ( - f"{self.syntax_hl['type'][0]}{(k := next(iter(obj_dict)))}{self.syntax_hl['type'][1]}" - + self.format_sequence((obj_dict[k], obj_dict["encoding"]), current_indent + self.indent) - if self.do_syntax_hl else (k := next(iter(obj_dict))) - + self.format_sequence((obj_dict[k], obj_dict["encoding"]), current_indent + self.indent) + f"{self.syntax_hl['type'][0]}{(key := next(iter(obj_dict)))}{self.syntax_hl['type'][1]}" + + self.format_sequence((obj_dict[key], obj_dict["encoding"]), current_indent + self.indent) + if self.do_syntax_hl else (key := next(iter(obj_dict))) + + self.format_sequence((obj_dict[key], obj_dict["encoding"]), current_indent + self.indent) ) ) elif isinstance(value, bool): @@ -770,20 +787,20 @@ def should_expand(self, seq: IndexIterable) -> bool: or (complex_items == 1 and len(seq) > 1) \ or self.cls.chars_count(seq) + (len(seq) * len(self.sep)) > self.max_width - def format_dict(self, d: dict, current_indent: int) -> str: - if self.compactness == 2 or not d or not self.should_expand(list(d.values())): + def format_dict(self, data_dict: dict[Any, Any], current_indent: int) -> str: + if self.compactness == 2 or not data_dict or not self.should_expand(list(data_dict.values())): return self.punct["{"] + self.sep.join( - f"{self.format_value(k)}{self.punct[':']} {self.format_value(v, current_indent)}" for k, v in d.items() + f"{self.format_value(key)}{self.punct[':']} {self.format_value(val, current_indent)}" for key, val in data_dict.items() ) + self.punct["}"] - items = [] - for k, val in d.items(): + items: list[str] = [] + for key, val in data_dict.items(): formatted_value = self.format_value(val, current_indent) - items.append(f"{' ' * (current_indent + self.indent)}{self.format_value(k)}{self.punct[':']} {formatted_value}") + items.append(f"{' ' * (current_indent + self.indent)}{self.format_value(key)}{self.punct[':']} {formatted_value}") return self.punct["{"] + "\n" + f"{self.sep}\n".join(items) + f"\n{' ' * current_indent}" + self.punct["}"] - def format_sequence(self, seq, current_indent: int) -> str: + def format_sequence(self, seq: IndexIterable, current_indent: int) -> str: if self.as_json: seq = list(seq) diff --git a/src/xulbux/file_sys.py b/src/xulbux/file_sys.py index f5ef031..1ed0edb 100644 --- a/src/xulbux/file_sys.py +++ b/src/xulbux/file_sys.py @@ -89,12 +89,8 @@ def extend_path( if search_in is not None: if isinstance(search_in, (str, Path)): search_dirs.extend([Path(search_in)]) - elif isinstance(search_in, list): - search_dirs.extend([Path(path) for path in search_in]) else: - raise TypeError( - f"The 'search_in' parameter must be a string, Path, or a list of strings/Paths, got {type(search_in)}" - ) + search_dirs.extend([Path(path) for path in search_in]) return _ExtendPathHelper( cls, diff --git a/src/xulbux/format_codes.py b/src/xulbux/format_codes.py index 62ea6c4..5564110 100644 --- a/src/xulbux/format_codes.py +++ b/src/xulbux/format_codes.py @@ -407,7 +407,7 @@ def _config_console(cls) -> None: kernel32.SetConsoleMode(h, mode.value | 0x0004) except Exception: pass - _CONSOLE_ANSI_CONFIGURED = True + _CONSOLE_ANSI_CONFIGURED = True # type: ignore[assignment] @staticmethod def _validate_default_color(default_color: Optional[Rgba | Hexa]) -> tuple[bool, Optional[rgba]]: @@ -417,14 +417,14 @@ def _validate_default_color(default_color: Optional[Rgba | Hexa]) -> tuple[bool, if Color.is_valid_hexa(default_color, False): return True, hexa(cast(str | int, default_color)).to_rgba() elif Color.is_valid_rgba(default_color, False): - return True, Color._parse_rgba(default_color) + return True, Color._parse_rgba(cast(Rgba, default_color)) # type: ignore[protected-access] raise TypeError("The 'default_color' parameter must be either a valid RGBA or HEXA color, or None.") @staticmethod def _formats_to_keys(formats: str) -> list[str]: """Internal method to convert a string of multiple format keys to a list of individual, stripped format keys.""" - return [k.strip() for k in formats.split("|") if k.strip()] + return [key.strip() for key in formats.split("|") if key.strip()] @classmethod def _get_replacement(cls, format_key: str, default_color: Optional[rgba], brightness_steps: int = 20) -> str: @@ -438,7 +438,8 @@ def _get_replacement(cls, format_key: str, default_color: Optional[rgba], bright if (isinstance(map_key, tuple) and format_key in map_key) or format_key == map_key: return _ANSI_SEQ_1.format( next(( - v for k, v in ANSI.CODES_MAP.items() if format_key == k or (isinstance(k, tuple) and format_key in k) + val for key, val in ANSI.CODES_MAP.items() \ + if format_key == key or (isinstance(key, tuple) and format_key in key) ), None) ) rgb_match = _PATTERNS.rgb.match(format_key) @@ -468,9 +469,7 @@ def _get_default_ansi( _modifiers: tuple[str, str] = (_DEFAULT_COLOR_MODS["lighten"], _DEFAULT_COLOR_MODS["darken"]), ) -> Optional[str]: """Internal method to get the `default_color` and lighter/darker versions of it as ANSI code.""" - if not isinstance(default_color, rgba): - return None - _default_color: tuple[int, int, int] = tuple(default_color)[:3] + _default_color: tuple[int, int, int] = (default_color[0], default_color[1], default_color[2]) if brightness_steps is None or (format_key and _PATTERNS.bg_opt_default.search(format_key)): return (ANSI.SEQ_BG_COLOR if format_key and _PATTERNS.bg_default.search(format_key) else ANSI.SEQ_COLOR).format( *_default_color @@ -536,7 +535,10 @@ def __call__(self, match: _rx.Match[str]) -> str: else: _formats = _PATTERNS.star_reset_inside.sub(r"\1_\2", formats) - if all((self.cls._get_replacement(k, self.default_color) != k) for k in self.cls._formats_to_keys(_formats)): + if all( + self.cls._get_replacement(format_key, self.default_color) != format_key # type: ignore[protected-access] + for format_key in self.cls._formats_to_keys(_formats) # type: ignore[protected-access] + ): # ESCAPE THE FORMATTING CODE escaped = f"[{self.escape_char}{formats}]" if auto_reset_txt: @@ -635,11 +637,12 @@ def process_formats_and_auto_reset(self) -> None: def convert_to_ansi(self) -> None: """Convert format keys to ANSI codes and generate resets if needed.""" - self.format_keys = self.cls._formats_to_keys(self.formats) - self.ansi_formats = [ - r if (r := self.cls._get_replacement(k, self.default_color, self.brightness_steps)) != k else f"[{k}]" - for k in self.format_keys - ] + self.format_keys = self.cls._formats_to_keys(self.formats) # type: ignore[protected-access] + self.ansi_formats = [( + ansi_code \ + if (ansi_code := self.cls._get_replacement(format_key, self.default_color, self.brightness_steps)) != format_key # type: ignore[protected-access] + else f"[{format_key}]" + ) for format_key in self.format_keys] # GENERATE RESET CODES IF AUTO-RESET IS ACTIVE if self.auto_reset_txt and not self.auto_reset_escaped: @@ -652,40 +655,40 @@ def gen_reset_codes(self) -> None: default_color_resets = ("_bg", "default") if self.use_default else ("_bg", "_c") reset_keys: list[str] = [] - for k in self.format_keys: - k_lower = k.lower() + for format_key in self.format_keys: + k_lower = format_key.lower() k_set = set(k_lower.split(":")) # BACKGROUND COLOR FORMAT if _PREFIX["BG"] & k_set and len(k_set) <= 3: if k_set & _PREFIX["BR"]: # BRIGHT BACKGROUND COLOR - RESET BOTH BG AND COLOR - for i in range(len(k)): - if self.is_valid_color(k[i:]): + for i in range(len(format_key)): + if self.is_valid_color(format_key[i:]): reset_keys.extend(default_color_resets) break else: # REGULAR BACKGROUND COLOR - RESET ONLY BG - for i in range(len(k)): - if self.is_valid_color(k[i:]): + for i in range(len(format_key)): + if self.is_valid_color(format_key[i:]): reset_keys.append("_bg") break # TEXT COLOR FORMAT - elif self.is_valid_color(k) or any( - k_lower.startswith(pref_colon := f"{prefix}:") and self.is_valid_color(k[len(pref_colon):]) \ + elif self.is_valid_color(format_key) or any( + k_lower.startswith(pref_colon := f"{prefix}:") and self.is_valid_color(format_key[len(pref_colon):]) \ for prefix in _PREFIX["BR"] ): reset_keys.append(default_color_resets[1]) # TEXT STYLE FORMAT else: - reset_keys.append(f"_{k}") + reset_keys.append(f"_{format_key}") # CONVERT RESET KEYS TO ANSI CODES self.ansi_resets = [ - r for k in reset_keys if ( \ - r := self.cls._get_replacement(k, self.default_color, self.brightness_steps) + ansi_code for reset_key in reset_keys if ( \ + ansi_code := self.cls._get_replacement(reset_key, self.default_color, self.brightness_steps) # type: ignore[protected-access] ).startswith(f"{ANSI.CHAR}{ANSI.START}") ] diff --git a/src/xulbux/json.py b/src/xulbux/json.py index c60bca7..bc2d55b 100644 --- a/src/xulbux/json.py +++ b/src/xulbux/json.py @@ -3,6 +3,7 @@ create and update JSON files, with support for comments inside the JSON data. """ +from .base.types import DataStructure from .file_sys import FileSys from .data import Data from .file import File @@ -23,7 +24,7 @@ def read( comment_start: str = ">>", comment_end: str = "<<", return_original: bool = False, - ) -> dict | tuple[dict, dict]: + ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any]]: """Read JSON files, ignoring comments.\n ------------------------------------------------------------------------------------ - `json_file` -⠀the path (relative or absolute) to the JSON file to read @@ -58,7 +59,7 @@ def read( def create( cls, json_file: Path | str, - data: dict, + data: dict[str, Any], indent: int = 2, compactness: Literal[0, 1, 2] = 1, force: bool = False, @@ -141,17 +142,24 @@ def update( If you don't know that the first list item is `"apples"`, you can use the items list index inside the value-path, so `healthy->fruits->0`.\n ⇾ If the given value-path doesn't exist, it will be created.""" - processed_data, data = cls.read( - json_file=json_file, - comment_start=comment_start, - comment_end=comment_end, - return_original=True, + processed_data, data = cast( + tuple[dict[str, Any], dict[str, Any]], + cls.read( + json_file=json_file, + comment_start=comment_start, + comment_end=comment_end, + return_original=True, + ), ) update: dict[str, Any] = {} for val_path, new_val in update_values.items(): try: - if (path_id := Data.get_path_id(data=processed_data, value_paths=val_path, path_sep=path_sep)) is not None: + if (path_id := Data.get_path_id( + data=cast(DataStructure, processed_data), + value_paths=val_path, + path_sep=path_sep, + )) is not None: update[cast(str, path_id)] = new_val else: data = cls._create_nested_path(data, val_path.split(path_sep), new_val) @@ -164,7 +172,7 @@ def update( cls.create(json_file=json_file, data=dict(data), force=True) @staticmethod - def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dict: + def _create_nested_path(data_obj: dict[str, Any], path_keys: list[str], value: Any) -> dict[str, Any]: """Internal method that creates nested dictionaries/lists based on the given path keys and sets the specified value at the end of the path.""" last_idx, current = len(path_keys) - 1, data_obj @@ -175,11 +183,11 @@ def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dic current[key] = value elif isinstance(current, list) and key.isdigit(): idx = int(key) - while len(current) <= idx: - current.append(None) + while len(cast(list[Any], current)) <= idx: + cast(list[Any], current).append(None) current[idx] = value else: - raise TypeError(f"Cannot set key '{key}' on {type(current)}") + raise TypeError(f"Cannot set key '{key}' on {type(cast(Any, current))}") else: next_key = path_keys[i + 1] @@ -189,12 +197,12 @@ def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dic current = current[key] elif isinstance(current, list) and key.isdigit(): idx = int(key) - while len(current) <= idx: - current.append(None) + while len(cast(list[Any], current)) <= idx: + cast(list[Any], current).append(None) if current[idx] is None: current[idx] = [] if next_key.isdigit() else {} - current = current[idx] + current = cast(list[Any], current)[idx] else: - raise TypeError(f"Cannot navigate through {type(current)}") + raise TypeError(f"Cannot navigate through {type(cast(Any, current))}") return data_obj diff --git a/src/xulbux/py.typed b/src/xulbux/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/xulbux/regex.py b/src/xulbux/regex.py index a9ef866..1daa129 100644 --- a/src/xulbux/regex.py +++ b/src/xulbux/regex.py @@ -239,7 +239,7 @@ class LazyRegex: def __init__(self, **patterns: str): self._patterns = patterns - def __getattr__(self, name: str) -> _rx.Pattern: + def __getattr__(self, name: str) -> _rx.Pattern[str]: if name in self._patterns: setattr(self, name, compiled := _rx.compile(self._patterns[name])) return compiled diff --git a/src/xulbux/system.py b/src/xulbux/system.py index 500dffd..b118942 100644 --- a/src/xulbux/system.py +++ b/src/xulbux/system.py @@ -97,8 +97,7 @@ def architecture(cls) -> str: def cpu_count(cls) -> int: """The number of CPU cores available.""" try: - count = _multiprocessing.cpu_count() - return count if count is not None else 1 + return _multiprocessing.cpu_count() except (NotImplementedError, AttributeError): return 1 @@ -148,7 +147,7 @@ def check_libs( return _SystemCheckLibsHelper(lib_names, install_missing, missing_libs_msgs, confirm_install)() @classmethod - def elevate(cls, win_title: Optional[str] = None, args: Optional[list] = None) -> bool: + def elevate(cls, win_title: Optional[str] = None, args: Optional[list[str]] = None) -> bool: """Attempts to start a new process with elevated privileges.\n --------------------------------------------------------------------------------- - `win_title` -⠀the window title of the elevated process (only on Windows) @@ -285,7 +284,7 @@ def __call__(self) -> Optional[list[str]]: def find_missing_libs(self) -> list[str]: """Find which libraries are missing.""" - missing = [] + missing: list[str] = [] for lib in self.lib_names: try: __import__(lib) diff --git a/tests/test_console.py b/tests/test_console.py index bd9c705..46b3775 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -1,5 +1,5 @@ from xulbux.console import ParsedArgData, ParsedArgs -from xulbux.console import Spinner, ProgressBar +from xulbux.console import Throbber, ProgressBar from xulbux.console import Console from xulbux import console @@ -1029,128 +1029,128 @@ def test_progressbar_redraw_progress_bar(): mock_stdout.flush.assert_called_once() -################################################## Spinner TESTS ################################################## +################################################## Throbber TESTS ################################################## -def test_spinner_init_defaults(): - spinner = Spinner() - assert spinner.label is None - assert spinner.interval == 0.2 - assert spinner.active is False - assert spinner.sep == " " - assert len(spinner.frames) > 0 +def test_throbber_init_defaults(): + throbber = Throbber() + assert throbber.label is None + assert throbber.interval == 0.2 + assert throbber.active is False + assert throbber.sep == " " + assert len(throbber.frames) > 0 -def test_spinner_init_custom(): - spinner = Spinner(label="Loading", interval=0.5, sep="-") - assert spinner.label == "Loading" - assert spinner.interval == 0.5 - assert spinner.sep == "-" +def test_throbber_init_custom(): + throbber = Throbber(label="Loading", interval=0.5, sep="-") + assert throbber.label == "Loading" + assert throbber.interval == 0.5 + assert throbber.sep == "-" -def test_spinner_set_format_valid(): - spinner = Spinner() - spinner.set_format(["{l}", "{a}"]) - assert spinner.spinner_format == ["{l}", "{a}"] +def test_throbber_set_format_valid(): + throbber = Throbber() + throbber.set_format(["{l}", "{a}"]) + assert throbber.throbber_format == ["{l}", "{a}"] -def test_spinner_set_format_invalid(): - spinner = Spinner() +def test_throbber_set_format_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_format(["{l}"]) # MISSING {a} + throbber.set_format(["{l}"]) # MISSING {a} -def test_spinner_set_frames_valid(): - spinner = Spinner() - spinner.set_frames(("a", "b")) - assert spinner.frames == ("a", "b") +def test_throbber_set_frames_valid(): + throbber = Throbber() + throbber.set_frames(("a", "b")) + assert throbber.frames == ("a", "b") -def test_spinner_set_frames_invalid(): - spinner = Spinner() +def test_throbber_set_frames_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_frames(("a", )) # LESS THAN 2 FRAMES + throbber.set_frames(("a", )) # LESS THAN 2 FRAMES -def test_spinner_set_interval_valid(): - spinner = Spinner() - spinner.set_interval(1.0) - assert spinner.interval == 1.0 +def test_throbber_set_interval_valid(): + throbber = Throbber() + throbber.set_interval(1.0) + assert throbber.interval == 1.0 -def test_spinner_set_interval_invalid(): - spinner = Spinner() +def test_throbber_set_interval_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_interval(0) + throbber.set_interval(0) with pytest.raises(ValueError): - spinner.set_interval(-1) + throbber.set_interval(-1) @patch("xulbux.console._threading.Thread") @patch("xulbux.console._threading.Event") @patch("sys.stdout", new_callable=MagicMock) -def test_spinner_start(mock_stdout, mock_event, mock_thread): +def test_throbber_start(mock_stdout, mock_event, mock_thread): mock_thread.return_value.start.return_value = None - spinner = Spinner() - spinner.start("Test") + throbber = Throbber() + throbber.start("Test") - assert spinner.active is True - assert spinner.label == "Test" + assert throbber.active is True + assert throbber.label == "Test" mock_event.assert_called_once() mock_thread.assert_called_once() # TEST CALLING START AGAIN DOESN'T DO ANYTHING - spinner.start("Test2") + throbber.start("Test2") assert mock_event.call_count == 1 @patch("xulbux.console._threading.Thread") @patch("xulbux.console._threading.Event") -def test_spinner_stop(mock_event, mock_thread): - spinner = Spinner() +def test_throbber_stop(mock_event, mock_thread): + throbber = Throbber() # MANUALLY SET ACTIVE TO SIMULATE RUNNING - spinner.active = True + throbber.active = True mock_stop_event = MagicMock() mock_stop_event.set.return_value = None - spinner._stop_event = mock_stop_event + throbber._stop_event = mock_stop_event mock_animation_thread = MagicMock() mock_animation_thread.join.return_value = None - spinner._animation_thread = mock_animation_thread + throbber._animation_thread = mock_animation_thread - spinner.stop() + throbber.stop() - assert spinner.active is False + assert throbber.active is False mock_stop_event.set.assert_called_once() mock_animation_thread.join.assert_called_once() -def test_spinner_update_label(): - spinner = Spinner() - spinner.update_label("New Label") - assert spinner.label == "New Label" +def test_throbber_update_label(): + throbber = Throbber() + throbber.update_label("New Label") + assert throbber.label == "New Label" -def test_spinner_context_manager(): - spinner = Spinner() +def test_throbber_context_manager(): + throbber = Throbber() # TEST CONTEXT MANAGER BEHAVIOR BY CHECKING ACTUAL EFFECTS - with spinner.context("Test") as update: - assert spinner.active is True - assert spinner.label == "Test" + with throbber.context("Test") as update: + assert throbber.active is True + assert throbber.label == "Test" update("New Label") - assert spinner.label == "New Label" + assert throbber.label == "New Label" - # AFTER CONTEXT EXITS, SPINNER SHOULD BE STOPPED - assert spinner.active is False + # AFTER CONTEXT EXITS, THROBBER SHOULD BE STOPPED + assert throbber.active is False -def test_spinner_context_manager_exception(): - spinner = Spinner() +def test_throbber_context_manager_exception(): + throbber = Throbber() # TEST THAT CLEANUP HAPPENS EVEN WITH EXCEPTIONS with pytest.raises(ValueError): - with spinner.context("Test"): + with throbber.context("Test"): raise ValueError("Oops") - # AFTER EXCEPTION, SPINNER SHOULD STILL BE CLEANED UP - assert spinner.active is False + # AFTER EXCEPTION, THROBBER SHOULD STILL BE CLEANED UP + assert throbber.active is False