diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..b0d4dbf
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,15 @@
+[run]
+omit =
+ */config.py
+ */config-*.py
+ capybara/cpuinfo.py
+ capybara/utils/system_info.py
+ capybara/vision/ipcam/app.py
+
+[report]
+omit =
+ */config.py
+ */config-*.py
+ capybara/cpuinfo.py
+ capybara/utils/system_info.py
+ capybara/vision/ipcam/app.py
diff --git a/.github/scripts/coverage_gate.py b/.github/scripts/coverage_gate.py
new file mode 100644
index 0000000..3c2d32e
--- /dev/null
+++ b/.github/scripts/coverage_gate.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+from xml.etree import ElementTree
+
+
+def _as_float(value: str | None, default: float | None = None) -> float | None:
+ if value is None:
+ return default
+ text = value.strip()
+ if not text:
+ return default
+ try:
+ return float(text)
+ except ValueError:
+ return default
+
+
+def _as_bool(value: str | None) -> bool:
+ if value is None:
+ return False
+ return str(value).strip().lower() in {"1", "true", "yes", "on", "y"}
+
+
+def _percent(value: float | None) -> str:
+ if value is None or value < 0:
+ return "N/A"
+ pct = value * 100.0
+ if pct.is_integer():
+ return f"{int(pct)}%"
+ return f"{pct:.2f}%"
+
+
+def _load_coverage_root(path: Path) -> ElementTree.Element | None:
+ if not path.exists():
+ return None
+ try:
+ return ElementTree.parse(path).getroot()
+ except ElementTree.ParseError:
+ return None
+
+
+def evaluate_coverage(
+ coverage_path: Path,
+ min_line: float | None,
+ min_branch: float | None,
+) -> tuple[bool, list[str], float | None, float | None]:
+ """Return a tuple describing coverage gate result."""
+ root = _load_coverage_root(coverage_path)
+ if root is None:
+ return False, [f"coverage XML '{coverage_path}' 無法讀取."], None, None
+
+ line_rate = _as_float(root.attrib.get("line-rate"))
+ branch_rate = _as_float(root.attrib.get("branch-rate"))
+
+ messages: list[str] = []
+ passed = True
+
+ if min_line is not None:
+ if line_rate is None:
+ passed = False
+ messages.append("行覆蓋率資料不存在.")
+ elif line_rate + 1e-9 < min_line:
+ passed = False
+ messages.append(
+ f"行覆蓋率 {_percent(line_rate)} 低於門檻 {_percent(min_line)}."
+ )
+ else:
+ messages.append(
+ f"行覆蓋率 {_percent(line_rate)} >= 門檻 {_percent(min_line)}."
+ )
+
+ if min_branch is not None:
+ if branch_rate is None:
+ passed = False
+ messages.append("分支覆蓋率資料不存在.")
+ elif branch_rate + 1e-9 < min_branch:
+ passed = False
+ messages.append(
+ f"分支覆蓋率 {_percent(branch_rate)} 低於門檻 {_percent(min_branch)}."
+ )
+ else:
+ messages.append(
+ f"分支覆蓋率 {_percent(branch_rate)} >= 門檻 {_percent(min_branch)}."
+ )
+
+ if min_line is None and min_branch is None:
+ messages.append("未設定覆蓋率門檻, 跳過檢查.")
+
+ return passed, messages, line_rate, branch_rate
+
+
+def parse_args(argv: list[str]) -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Coverage gate checker")
+ parser.add_argument(
+ "--file",
+ required=True,
+ help="Path to coverage XML report generated by coverage.py",
+ )
+ parser.add_argument(
+ "--min-line",
+ dest="min_line",
+ default=None,
+ help="Minimum line coverage required (0.0 - 1.0)",
+ )
+ parser.add_argument(
+ "--min-branch",
+ dest="min_branch",
+ default=None,
+ help="Minimum branch coverage required (0.0 - 1.0)",
+ )
+ parser.add_argument(
+ "--enforce",
+ dest="enforce",
+ default="0",
+ help="Set to 1/true to enforce the gate (exit with non-zero on failure)",
+ )
+ return parser.parse_args(argv)
+
+
+def main(argv: list[str] | None = None) -> int:
+ args = parse_args(argv or sys.argv[1:])
+
+ coverage_path = Path(args.file)
+ min_line = _as_float(args.min_line)
+ min_branch = _as_float(args.min_branch)
+ enforce_gate = _as_bool(args.enforce)
+
+ if not coverage_path.exists():
+ message = f"找不到覆蓋率報告: {coverage_path}"
+ if enforce_gate:
+ print(f"::error::{message}")
+ print("::error::Coverage Gate Result: FAIL (N/A)")
+ return 2
+ print(f"::warning::{message}")
+ print("::warning::Coverage Gate Result: SKIP (N/A)")
+ return 0
+
+ passed, messages, line_rate, branch_rate = evaluate_coverage(
+ coverage_path, min_line, min_branch
+ )
+ prefix = "::notice::" if passed else "::error::"
+ for line in messages:
+ print(f"{prefix}{line}")
+
+ summary_parts: list[str] = []
+ if line_rate is not None:
+ summary_parts.append(f"行 {_percent(line_rate)}")
+ if branch_rate is not None:
+ summary_parts.append(f"分支 {_percent(branch_rate)}")
+ summary_text = ", ".join(summary_parts) if summary_parts else "N/A"
+ summary_prefix = (
+ "::notice::"
+ if passed
+ else ("::warning::" if not enforce_gate else "::error::")
+ )
+ result_text = "PASS" if passed else "FAIL"
+ print(
+ f"{summary_prefix}Coverage Gate Result: {result_text} ({summary_text})"
+ )
+
+ if passed or not enforce_gate:
+ if not passed:
+ print("::warning::覆蓋率未達門檻, 但目前未強制執行 (enforce=0).")
+ return 0
+
+ return 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.github/scripts/generate_pytest_summary.py b/.github/scripts/generate_pytest_summary.py
new file mode 100644
index 0000000..1a377da
--- /dev/null
+++ b/.github/scripts/generate_pytest_summary.py
@@ -0,0 +1,548 @@
+import os
+import re
+import textwrap
+from contextlib import suppress
+from dataclasses import dataclass
+from pathlib import Path
+from xml.etree import ElementTree
+
+MARKER = ""
+ARTIFACT_PLACEHOLDER = ""
+DEFAULT_LOG_TAIL_LINES = 200
+DEFAULT_LOG_CHAR_LIMIT = 6000
+SNIPPET_CHAR_LIMIT = 4000
+TOP_SLOW_N = 20
+
+
+@dataclass
+class CoverageSummary:
+ line_rate: float | None = None
+ branch_rate: float | None = None
+ line_pct: str | None = None
+ branch_pct: str | None = None
+ short_text: str = "覆蓋率資料不可用"
+
+
+REPORT_DIR = Path(".ci-reports/pytest")
+RAW_REPORT_DIR = Path(os.getenv("PYTEST_REPORT_DIR", ".ci-reports/pytest/raw"))
+LEGACY_REPORT_DIR = Path(".pytest-reports")
+COVERAGE_REPORT_FILE = "coverage_report.txt"
+
+
+def _first_existing(paths: list[Path]) -> Path | None:
+ for candidate in paths:
+ if candidate.exists():
+ return candidate
+ return paths[0] if paths else None
+
+
+def _as_float(value: str | float | None) -> float | None:
+ if value is None:
+ return None
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ return None
+
+
+def _threshold_value(value: str | None) -> float | None:
+ if value is None:
+ return None
+ try:
+ return float(str(value))
+ except (TypeError, ValueError):
+ return None
+
+
+def fmt_outcome(value: str) -> str:
+ mapping = {
+ "success": "✅ 通過",
+ "failure": "❌ 失敗",
+ "cancelled": "⚪ 取消",
+ "skipped": "⚪ 未執行",
+ }
+ value = (value or "").strip()
+ return mapping.get(value, f"⚪ {value or '未知'}")
+
+
+def fmt_percent(raw: str | float | int | None) -> str | None:
+ if raw is None:
+ return None
+ if isinstance(raw, int | float):
+ value = float(raw)
+ else:
+ try:
+ value = float(raw)
+ except (TypeError, ValueError):
+ return None
+ pct = value * 100.0
+ if pct.is_integer():
+ return f"{int(pct)}%"
+ return f"{pct:.2f}%"
+
+
+def cleanse_snippet(snippet: str, limit: int = SNIPPET_CHAR_LIMIT) -> str:
+ snippet = textwrap.dedent(snippet).strip()
+ if not snippet:
+ return "(無訊息)"
+ snippet = "\n".join(line.rstrip() for line in snippet.splitlines())
+ if len(snippet) > limit:
+ snippet = f"{snippet[:limit]}\n... (截斷)"
+ return snippet
+
+
+def load_log_tail(
+ path: Path,
+ *,
+ limit: int = DEFAULT_LOG_CHAR_LIMIT,
+ tail_lines: int = DEFAULT_LOG_TAIL_LINES,
+) -> str | None:
+ if not path.exists():
+ return None
+ try:
+ raw = path.read_text(encoding="utf-8", errors="ignore").strip()
+ except OSError:
+ return None
+ if not raw:
+ return None
+ lines = raw.splitlines()
+ tail = "\n".join(lines[-tail_lines:])
+ if len(tail) > limit:
+ tail = tail[-limit:]
+ return tail
+
+
+def _read_text_if_exists(path: Path) -> str | None:
+ if not path.exists():
+ return None
+ try:
+ return path.read_text(encoding="utf-8", errors="ignore")
+ except OSError:
+ return None
+
+
+def append_coverage_report(
+ lines: list[str], report_path: Path, max_lines: int = 80
+) -> None:
+ raw = _read_text_if_exists(report_path)
+ if raw is None:
+ return
+ content = [line.rstrip() for line in raw.splitlines() if line.strip()]
+ if not content:
+ return
+ lines.append("")
+ lines.append("### 覆蓋率缺口")
+ lines.append("")
+ lines.append("<details><summary>未覆蓋的檔案列表</summary>")
+ lines.append("")
+ lines.append("```")
+ for line in content[:max_lines]:
+ lines.append(line)
+ if len(content) > max_lines:
+ lines.append("... (截斷)")
+ lines.append("```")
+ lines.append("</details>")
+
+
+def _iter_testcases(root: ElementTree.Element) -> list[ElementTree.Element]:
+ return list(root.findall(".//testcase"))
+
+
+def _parse_root(xml_path: Path) -> ElementTree.Element | None:
+ if not xml_path.exists():
+ return None
+ try:
+ return ElementTree.parse(xml_path).getroot()
+ except ElementTree.ParseError:
+ return None
+
+
+def append_pytest_summary(lines: list[str], junit_path: Path) -> None:
+ root = _parse_root(junit_path)
+ if root is None:
+ return
+
+ attrs = root.attrib
+ tests = attrs.get("tests")
+ failures = attrs.get("failures") or attrs.get("failed")
+ errors = attrs.get("errors")
+ skipped = attrs.get("skipped") or attrs.get("disabled")
+ time_spent = attrs.get("time")
+
+ def _safe_int(x: str | int | float | None, default: int | str = 0) -> int:
+ if x is None:
+ return int(default)
+ try:
+ return int(float(x))
+ except (TypeError, ValueError):
+ try:
+ return int(x)
+ except (TypeError, ValueError):
+ return int(default)
+
+ if tests is None:
+ tests = 0
+ failures = 0
+ errors = 0
+ skipped = 0
+ total_time = 0.0
+ for suite in root.findall(".//testsuite"):
+ a = suite.attrib
+ tests += _safe_int(a.get("tests"), 0)
+ failures += _safe_int(a.get("failures") or a.get("failed"), 0)
+ errors += _safe_int(a.get("errors"), 0)
+ skipped += _safe_int(a.get("skipped") or a.get("disabled"), 0)
+ with suppress(Exception):
+ total_time += float(a.get("time") or 0.0)
+ time_spent = f"{total_time:.3f}"
+ else:
+ tests = _safe_int(tests, 0)
+ failures = _safe_int(failures, 0)
+ errors = _safe_int(errors, 0)
+ skipped = _safe_int(skipped, 0)
+
+ passed = max(tests - failures - errors - skipped, 0)
+
+ lines.append(
+ f"- 測試統計: 共 {tests}; ✅ 通過 {passed}; ❌ 失敗 {failures}; ⚠️ 錯誤 {errors}; ⏭️ 跳過 {skipped}; 耗時 {time_spent or 'N/A'} 秒"
+ )
+
+ failures_block: list[str] = []
+ for case in _iter_testcases(root):
+ failure = case.find("failure")
+ error = case.find("error")
+ # an empty <failure/> element is falsy, so compare with None instead of using "or"
+ target = failure if failure is not None else error
+ if target is None:
+ continue
+ status = "failure" if failure is not None else "error"
+ classname = case.attrib.get("classname") or ""
+ name = case.attrib.get("name") or ""
+ elapsed = case.attrib.get("time") or ""
+ header = f"- `{classname}.{name}` ({status}"
+ if elapsed:
+ header += f", {elapsed} 秒"
+ header += ")"
+ snippet = target.attrib.get("message", "") or ""
+ text_body = target.text or ""
+ combined = "\n".join(part for part in (snippet, text_body) if part)
+ failures_block.append(header)
+ formatted = cleanse_snippet(combined)
+ indented = "\n".join(f" {line}" for line in formatted.splitlines())
+ failures_block.append(" ```")
+ failures_block.append(indented or " (無訊息)")
+ failures_block.append(" ```")
+
+ if failures_block:
+ lines.append("")
+ lines.append("#### 失敗與錯誤測試")
+ lines.append("")
+ lines.extend(failures_block)
+
+
+def append_top_slowest(
+ lines: list[str], junit_path: Path, top_n: int = TOP_SLOW_N
+) -> None:
+ root = _parse_root(junit_path)
+ if root is None:
+ return
+ records: list[tuple[float, str]] = []
+ for case in _iter_testcases(root):
+ try:
+ t = float(case.attrib.get("time") or 0.0)
+ except Exception:
+ t = 0.0
+ if t <= 0.0:
+ continue
+ classname = case.attrib.get("classname") or ""
+ name = case.attrib.get("name") or ""
+ fqname = f"{classname}.{name}" if classname else name
+ records.append((t, fqname))
+ if not records:
+ return
+ records.sort(key=lambda x: x[0], reverse=True)
+ lines.append("")
+ lines.append(f"#### 最慢測試 Top {min(top_n, len(records))}")
+ lines.append("")
+ lines.append("| 測試 | 耗時 (秒) |")
+ lines.append("| --- | ---: |")
+ for t, fqname in records[:top_n]:
+ lines.append(f"| `{fqname}` | {t:.3f} |")
+
+
+def append_warnings_summary(
+ lines: list[str], candidate_logs: list[Path]
+) -> None:
+ text: str | None = None
+ for p in candidate_logs:
+ text = _read_text_if_exists(p)
+ if text:
+ break
+ if not text:
+ return
+
+ start = None
+ end = None
+ all_lines = text.splitlines()
+ for i, line in enumerate(all_lines):
+ if re.search(r"=+\s*warnings summary\s*=+", line, re.IGNORECASE):
+ start = i
+ break
+ if start is None:
+ return
+ for j in range(start + 1, len(all_lines)):
+ if re.match(r"=+\s", all_lines[j]):
+ end = j
+ break
+ block = (
+ "\n".join(all_lines[start:end]).strip()
+ if end is not None
+ else "\n".join(all_lines[start:]).strip()
+ )
+ block = cleanse_snippet(block, limit=SNIPPET_CHAR_LIMIT)
+
+ if block:
+ lines.append("")
+ lines.append("#### Warnings 摘要")
+ lines.append("")
+ lines.append("```")
+ lines.append(block)
+ lines.append("```")
+
+
+def append_session_meta(lines: list[str], candidate_logs: list[Path]) -> None:
+ text: str | None = None
+ for p in candidate_logs:
+ text = _read_text_if_exists(p)
+ if text:
+ break
+ if not text:
+ return
+ platform_line = re.search(r"^platform .+$", text, re.MULTILINE)
+ rootdir_line = re.search(r"^rootdir: .+$", text, re.MULTILINE)
+ plugins_line = re.search(r"^plugins: .+$", text, re.MULTILINE)
+
+ entries = []
+ if platform_line:
+ entries.append(f"- {platform_line.group(0)}")
+ if rootdir_line:
+ entries.append(f"- {rootdir_line.group(0)}")
+ if plugins_line:
+ entries.append(f"- {plugins_line.group(0)}")
+
+ if entries:
+ lines.append("")
+ lines.append("### 測試環境")
+ lines.append("")
+ lines.extend(entries)
+
+
+def append_coverage_summary(
+ lines: list[str], coverage_path: Path
+) -> CoverageSummary:
+ summary = CoverageSummary()
+
+ lines.append("")
+ lines.append("### 覆蓋率")
+ lines.append("")
+
+ root = _parse_root(coverage_path)
+ if root is None:
+ lines.append("- 未產生覆蓋率報告 (可能因測試失敗或覆蓋率未啟用)。")
+ summary.short_text = "覆蓋率資料不可用"
+ return summary
+
+ line_rate = _as_float(root.attrib.get("line-rate"))
+ branch_rate = _as_float(root.attrib.get("branch-rate"))
+ overall_line = fmt_percent(line_rate)
+ overall_branch = fmt_percent(branch_rate)
+ summary.line_rate = line_rate
+ summary.branch_rate = branch_rate
+ summary.line_pct = overall_line
+ summary.branch_pct = overall_branch
+
+ lines_valid = root.attrib.get("lines-valid")
+ lines_covered = root.attrib.get("lines-covered")
+ branches_valid = root.attrib.get("branches-valid")
+ branches_covered = root.attrib.get("branches-covered")
+ overall_parts: list[str] = []
+ if overall_line:
+ detail = overall_line
+ if lines_covered and lines_valid:
+ detail += f" ({lines_covered}/{lines_valid})"
+ overall_parts.append(f"行 {detail}")
+ if overall_branch:
+ detail = overall_branch
+ if branches_covered and branches_valid:
+ detail += f" ({branches_covered}/{branches_valid})"
+ overall_parts.append(f"分支 {detail}")
+ if overall_parts:
+ lines.append(f"- 總覆蓋率: {', '.join(overall_parts)}")
+ else:
+ lines.append("- 總覆蓋率資料不可用。")
+
+ records: list[tuple[float, str, str, str]] = []
+ for cls in root.findall(".//class"):
+ filename = cls.attrib.get("filename") or cls.attrib.get("name")
+ if not filename:
+ continue
+ line_rate = cls.attrib.get("line-rate")
+ branch_rate = cls.attrib.get("branch-rate")
+ try:
+ line_value = float(line_rate) if line_rate is not None else 1.0
+ except ValueError:
+ line_value = 1.0
+ records.append(
+ (
+ line_value,
+ filename,
+ fmt_percent(line_rate) or "N/A",
+ fmt_percent(branch_rate) or "N/A",
+ )
+ )
+
+ if records:
+ records.sort(key=lambda item: item[0])
+ lines.append("")
+ lines.append("#### 覆蓋率最低的檔案 (前 10 名)")
+ lines.append("")
+ lines.append("| 檔案 | 行覆蓋 | 分支覆蓋 |")
+ lines.append("| --- | --- | --- |")
+ for _, filename, line_pct, branch_pct in records[:10]:
+ lines.append(f"| `{filename}` | {line_pct} | {branch_pct} |")
+
+ summary_parts: list[str] = []
+ if overall_line:
+ summary_parts.append(f"行 {overall_line}")
+ if overall_branch:
+ summary_parts.append(f"分支 {overall_branch}")
+ summary.short_text = (
+ " / ".join(summary_parts) if summary_parts else "覆蓋率資料不可用"
+ )
+
+ return summary
+
+
+def build_coverage_kpi(
+ info: CoverageSummary,
+ min_line: float | None,
+ min_branch: float | None,
+) -> str:
+ line_pct = info.line_pct
+ branch_pct = info.branch_pct
+ line_rate = info.line_rate
+ branch_rate = info.branch_rate
+
+ actual_parts: list[str] = []
+ if line_pct:
+ actual_parts.append(f"行 {line_pct}")
+ if branch_pct:
+ actual_parts.append(f"分支 {branch_pct}")
+ actual_text = (
+ " / ".join(actual_parts) if actual_parts else "覆蓋率資料不可用"
+ )
+
+ threshold_parts: list[str] = []
+ if min_line is not None:
+ min_line_pct = fmt_percent(min_line) or f"{min_line * 100:.2f}%"
+ threshold_parts.append(f"行 {min_line_pct}")
+ if min_branch is not None:
+ min_branch_pct = fmt_percent(min_branch) or f"{min_branch * 100:.2f}%"
+ threshold_parts.append(f"分支 {min_branch_pct}")
+ threshold_text = ""
+ if threshold_parts:
+ threshold_text = f"(門檻 {' / '.join(threshold_parts)})"
+
+ thresholds_provided = bool(threshold_parts)
+ status_ok = True
+ if min_line is not None and (
+ line_rate is None or line_rate + 1e-9 < min_line
+ ):
+ status_ok = False
+ if min_branch is not None and (
+ branch_rate is None or branch_rate + 1e-9 < min_branch
+ ):
+ status_ok = False
+
+ if info.short_text == "覆蓋率資料不可用":
+ status_icon = "⚪"
+ elif thresholds_provided:
+ status_icon = "✅" if status_ok else "❌"
+ else:
+ status_icon = "✅"
+
+ return f"{actual_text}{threshold_text} → {status_icon}"
+
+
+def main() -> None:
+ report_dir = REPORT_DIR
+ report_dir.mkdir(parents=True, exist_ok=True)
+ summary_path = report_dir / "summary.md"
+
+ outcome = os.getenv("PYTEST_OUTCOME", "")
+ exit_code = os.getenv("PYTEST_EXIT_CODE", "")
+
+ lines: list[str] = [MARKER, "### Pytest 結果"]
+ status_line = f"- 狀態: {fmt_outcome(outcome)}"
+ if exit_code:
+ status_line += f" (exit={exit_code})"
+ lines.append(status_line)
+
+ junit_candidates = [
+ RAW_REPORT_DIR / "pytest.xml",
+ report_dir / "pytest.xml",
+ LEGACY_REPORT_DIR / "pytest.xml",
+ ]
+ coverage_candidates = [
+ RAW_REPORT_DIR / "coverage.xml",
+ report_dir / "coverage.xml",
+ LEGACY_REPORT_DIR / "coverage.xml",
+ ]
+ junit_path = _first_existing(junit_candidates) or junit_candidates[0]
+ coverage_xml = (
+ _first_existing(coverage_candidates) or coverage_candidates[0]
+ )
+ log_candidates = [
+ report_dir / "pytest.log",
+ report_dir / "run.log",
+ RAW_REPORT_DIR / "pytest.log",
+ LEGACY_REPORT_DIR / "pytest.log",
+ ]
+
+ append_session_meta(lines, log_candidates)
+ append_pytest_summary(lines, junit_path)
+ append_top_slowest(lines, junit_path, top_n=TOP_SLOW_N)
+ append_warnings_summary(lines, log_candidates)
+ coverage_info = append_coverage_summary(lines, coverage_xml)
+ append_coverage_report(lines, report_dir / COVERAGE_REPORT_FILE)
+
+ min_line = _threshold_value(os.getenv("COVERAGE_MIN_LINE"))
+ min_branch = _threshold_value(os.getenv("COVERAGE_MIN_BRANCH"))
+ coverage_kpi = build_coverage_kpi(coverage_info, min_line, min_branch)
+
+ lines.append("")
+ lines.append(f"- 產物: {ARTIFACT_PLACEHOLDER}")
+
+ if outcome.strip() == "failure":
+ tail = None
+ for p in log_candidates:
+ tail = load_log_tail(p)
+ if tail:
+ break
+ if tail:
+ lines.append("")
+ lines.append("")
+ lines.append("<details><summary>Pytest 輸出 (最後 200 行)</summary>")
+ lines.append("")
+ lines.append("```")
+ lines.append(tail)
+ lines.append("```")
+ lines.append("</details>")
+
+ summary_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+ coverage_txt_path = report_dir / "coverage.txt"
+ coverage_txt_path.write_text(f"{coverage_kpi}\n", encoding="utf-8")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..131b08c
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,428 @@
+name: Capybara CI
+
+on:
+ pull_request:
+ branches: [main]
+ types: [opened, synchronize, reopened, ready_for_review]
+ push:
+ branches: [main]
+ workflow_dispatch:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ qa:
+ name: QA (ruff + pyright + pytest)
+ runs-on: [self-hosted, unicorn]
+ timeout-minutes: 45
+ permissions:
+ actions: read
+ contents: read
+ packages: read
+ pull-requests: write
+ checks: write
+ env:
+ PYTHON_VERSION: "3.10"
+ REPORT_DIR: .ci-reports
+ PYTEST_REPORT_DIR: .ci-reports/pytest/raw
+ RUFF_REPORT_DIR: .ci-reports/ruff
+ PYRIGHT_REPORT_DIR: .ci-reports/pyright
+ COVERAGE_MIN_LINE: "0.99"
+ COVERAGE_MIN_BRANCH: "0.00"
+ COVERAGE_ENFORCE: "1"
+ ARTIFACT_RETENTION_DAYS: 14
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.PYTHON_VERSION }}
+ cache: "pip"
+ cache-dependency-path: |
+ pyproject.toml
+ **/requirements*.txt
+
+ - name: Install dependencies
+ run: |
+ PYTORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
+ python -m pip install --upgrade pip wheel
+ python -m pip install torch --index-url "$PYTORCH_INDEX_URL"
+ python -m pip install -e .
+ python -m pip install pytest pytest-cov pytest-html coverage ruff pyright onnx onnxslim
+
+ - name: Prepare report directories
+ run: |
+ mkdir -p "$REPORT_DIR"
+ mkdir -p "$PYTEST_REPORT_DIR"
+ mkdir -p "$RUFF_REPORT_DIR"
+ mkdir -p "$PYRIGHT_REPORT_DIR"
+
+ - name: Ruff - Lint
+ id: ruff_lint
+ continue-on-error: true
+ run: |
+ set -o pipefail
+ set +e
+ python -m ruff version
+ python -m ruff check capybara tests | tee "$RUFF_REPORT_DIR/lint.log"
+ status=${PIPESTATUS[0]}
+ set -e
+ echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+ exit $status
+
+ - name: Ruff - Format Check
+ id: ruff_format
+ continue-on-error: true
+ run: |
+ set -o pipefail
+ set +e
+ python -m ruff format --check capybara tests | tee "$RUFF_REPORT_DIR/format.log"
+ status=${PIPESTATUS[0]}
+ set -e
+ echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+ exit $status
+
+ - name: Pyright
+ id: pyright
+ continue-on-error: true
+ run: |
+ set -o pipefail
+ set +e
+ python -m pyright --version
+ python -m pyright | tee "$PYRIGHT_REPORT_DIR/pyright.log"
+ status=${PIPESTATUS[0]}
+ set -e
+ echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+ exit $status
+
+ - name: Pytest (XML + HTML + Coverage)
+ id: pytest
+ continue-on-error: true
+ env:
+ PYTEST_ADDOPTS: "-vv -ra"
+ run: |
+ set -o pipefail
+ raw_dir="$PYTEST_REPORT_DIR"
+ final_dir="$REPORT_DIR/pytest"
+ mkdir -p "$final_dir" "$raw_dir"
+ set +e
+ python -m pytest \
+ --junitxml="$raw_dir/pytest.xml" \
+ --cov=capybara \
+ --cov-report=term-missing:skip-covered \
+ --cov-report=xml:"$raw_dir/coverage.xml" \
+ --cov-report=html:"$raw_dir/htmlcov" \
+ --html="$raw_dir/pytest.html" \
+ --self-contained-html \
+ --durations=25 \
+ tests 2>&1 | tee "$raw_dir/pytest.log"
+ status=${PIPESTATUS[0]}
+ set -e
+ if [[ -f .coverage ]]; then
+ python -m coverage report -m --skip-empty --skip-covered > "$raw_dir/coverage_report.txt" || true
+ else
+ : > "$raw_dir/coverage_report.txt"
+ fi
+ for artifact in pytest.log pytest.xml coverage.xml pytest.html coverage_report.txt; do
+ if [[ -f "$raw_dir/$artifact" ]]; then
+ cp "$raw_dir/$artifact" "$final_dir/$artifact"
+ fi
+ done
+ if [[ -f "$raw_dir/pytest.log" ]]; then
+ cp "$raw_dir/pytest.log" "$final_dir/run.log"
+ fi
+ if [[ -d "$raw_dir/htmlcov" ]]; then
+ rm -rf "$final_dir/htmlcov"
+ cp -R "$raw_dir/htmlcov" "$final_dir/htmlcov"
+ fi
+ echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+ exit $status
+
+ - name: Generate Pytest summary
+ if: always()
+ env:
+ PYTEST_OUTCOME: ${{ steps.pytest.outcome }}
+ PYTEST_EXIT_CODE: ${{ steps.pytest.outputs.exit_code }}
+ run: |
+ python .github/scripts/generate_pytest_summary.py
+ summary_path="$REPORT_DIR/pytest/summary.md"
+ if [[ ! -s "$summary_path" ]]; then
+ echo "::error::Pytest summary was not generated"
+ exit 1
+ fi
+ {
+ echo "PYTEST_SUMMARY_BODY<<EOF"
+ cat "$summary_path"
+ echo "EOF"
+ } >> "$GITHUB_ENV"
+
+ - name: Coverage Gate (optional)
+ id: coverage_gate
+ if: always() && hashFiles('.ci-reports/pytest/raw/coverage.xml') != ''
+ continue-on-error: true
+ run: |
+ set -o pipefail
+ set +e
+ python .github/scripts/coverage_gate.py \
+ --file "${{ github.workspace }}/${{ env.PYTEST_REPORT_DIR }}/coverage.xml" \
+ --min-line "$COVERAGE_MIN_LINE" \
+ --min-branch "$COVERAGE_MIN_BRANCH" \
+ --enforce "$COVERAGE_ENFORCE" | tee "$REPORT_DIR/pytest/coverage_gate.log"
+ status=${PIPESTATUS[0]}
+ set -e
+ echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+ exit $status
+
+ - name: Upload CI report artifact
+ if: always()
+ uses: actions/upload-artifact@v4
+ with:
+ name: ci-report-${{ github.run_id }}
+ if-no-files-found: warn
+ retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }}
+ compression-level: 6
+ path: .ci-reports/
+
+ - name: Publish CI summary (artifact + comment)
+ if: always()
+ uses: actions/github-script@v7
+ with:
+ script: |
+ const fs = require('fs');
+ const path = require('path');
+ const marker = '';
+ const baseDir = path.join(process.cwd(), '.ci-reports');
+ const pytestDir = path.join(baseDir, 'pytest');
+ const ruffDir = path.join(baseDir, 'ruff');
+ const artifactName = `ci-report-${context.runId}`;
+ fs.mkdirSync(baseDir, { recursive: true });
+
+ const fmtOutcome = (value) => {
+ const mapping = {
+ success: '✅ 通過',
+ failure: '❌ 失敗',
+ cancelled: '⚪ 取消',
+ skipped: '⚪ 未執行',
+ };
+ return mapping[value] || `⚪ ${value || '未知'}`;
+ };
+
+ const readIfExists = (file) => {
+ try {
+ return fs.readFileSync(file, 'utf8').trim();
+ } catch {
+ return null;
+ }
+ };
+
+ const readTail = (file) => {
+ try {
+ const raw = fs.readFileSync(file, 'utf8').trim();
+ if (!raw) return null;
+ const lines = raw.split(/\r?\n/);
+ const tail = lines.slice(-200).join('\n');
+ return tail.length > 6000 ? tail.slice(-6000) : tail;
+ } catch {
+ return null;
+ }
+ };
+
+ const artifactPlaceholder = '';
+ const coverageSummary = readIfExists(path.join(pytestDir, 'coverage.txt')) || 'N/A';
+ const pytestSummaryRaw = readIfExists(path.join(pytestDir, 'summary.md'));
+ const pytestSummary = (() => {
+ if (!pytestSummaryRaw) return null;
+ const marker = '';
+ if (pytestSummaryRaw.includes(marker)) {
+ return pytestSummaryRaw.split(marker).pop().trim();
+ }
+ return pytestSummaryRaw;
+ })();
+ const lintTail = readTail(path.join(ruffDir, 'lint.log'));
+ const formatTail = readTail(path.join(ruffDir, 'format.log'));
+
+ const ruffLintOutcome = '${{ steps.ruff_lint.outcome }}';
+ const ruffFormatOutcome = '${{ steps.ruff_format.outcome }}';
+ const pytestOutcome = '${{ steps.pytest.outcome }}';
+ const pytestExitCode = '${{ steps.pytest.outputs.exit_code || '' }}';
+ const coverageGateOutcome = '${{ steps.coverage_gate.outcome || '' }}';
+ const coverageGateExit = '${{ steps.coverage_gate.outputs.exit_code || '' }}';
+ const coverageGateEnforce = '${{ env.COVERAGE_ENFORCE }}';
+
+ const coverageGateStatus = (() => {
+ if (!coverageGateOutcome) {
+ return '⚪ 未執行';
+ }
+ const suffix = coverageGateExit && coverageGateExit !== '0' && coverageGateEnforce === '0'
+ ? '(未強制)'
+ : '';
+ return `${fmtOutcome(coverageGateOutcome)}${suffix}`;
+ })();
+
+ const pytestStatus = (() => {
+ let status = fmtOutcome(pytestOutcome);
+ if (pytestExitCode) {
+ status += ` (exit=${pytestExitCode})`;
+ }
+ return status;
+ })();
+
+ const kpiTable = [
+ '| 項目 | 狀態 |',
+ '| --- | --- |',
+ `| Ruff Lint | ${fmtOutcome(ruffLintOutcome)} |`,
+ `| Ruff Format | ${fmtOutcome(ruffFormatOutcome)} |`,
+ `| Pytest | ${pytestStatus} |`,
+ `| Coverage | ${coverageSummary} |`,
+ `| Coverage Gate | ${coverageGateStatus} |`,
+ ].join('\n');
+
+ let artifactLink = null;
+ try {
+ const { owner, repo } = context.repo;
+ const { data } = await github.rest.actions.listWorkflowRunArtifacts({
+ owner,
+ repo,
+ run_id: context.runId,
+ per_page: 100,
+ });
+ const target = (data.artifacts || []).find((item) => item.name === artifactName);
+ if (target) {
+ artifactLink = `[${target.name}](${target.archive_download_url})`;
+ }
+ } catch (error) {
+ core.warning(`無法取得 artifacts: ${error.message}`);
+ }
+
+ const quickLinks = [];
+ const runSummaryUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${context.runId}`;
+ quickLinks.push(`- [Workflow 執行](${runSummaryUrl})`);
+ if (artifactLink) {
+ quickLinks.push(`- 測試 HTML:${artifactLink}(內含 \`pytest/pytest.html\`)`);
+ quickLinks.push(`- JUnit XML:${artifactLink}(內含 \`pytest/pytest.xml\`)`);
+ quickLinks.push(`- Coverage HTML:${artifactLink}(內含 \`pytest/htmlcov/index.html\`)`);
+ } else {
+ quickLinks.push('- 測試 HTML:不可用');
+ quickLinks.push('- JUnit XML:不可用');
+ quickLinks.push('- Coverage HTML:不可用');
+ }
+
+ const lines = [
+ marker,
+ '## CI 概覽',
+ '',
+ kpiTable,
+ '',
+ '### 快速連結',
+ '',
+ ...quickLinks,
+ '',
+ ];
+
+ if (pytestSummary) {
+ const artifactText = artifactLink || '不可用';
+ const processedSummary = pytestSummary.replace(new RegExp(artifactPlaceholder, 'g'), artifactText);
+ lines.push(processedSummary);
+ lines.push('');
+ } else {
+ lines.push('_(Pytest 摘要不可用)_', '');
+ }
+
+ if (ruffLintOutcome === 'failure' && lintTail) {
+ lines.push('### Ruff Lint(最後 200 行)', '', '<details><summary>展開</summary>', '', '```');
+ lines.push(lintTail);
+ lines.push('```', '</details>', '');
+ }
+
+ if (ruffFormatOutcome === 'failure' && formatTail) {
+ lines.push('### Ruff Format(最後 200 行)', '', '<details><summary>展開</summary>', '', '```');
+ lines.push(formatTail);
+ lines.push('```', '</details>', '');
+ }
+
+ const body = lines.join('\n');
+ const indexPath = path.join(baseDir, 'index.md');
+ fs.writeFileSync(indexPath, `${body}\n`, { encoding: 'utf8' });
+
+ if (context.eventName !== 'pull_request') {
+ return;
+ }
+
+ const { owner, repo } = context.repo;
+ const issue_number = context.payload.pull_request.number;
+ const { data: comments } = await github.rest.issues.listComments({
+ owner,
+ repo,
+ issue_number,
+ per_page: 100,
+ });
+
+ const existing = comments.find((comment) => comment.body && comment.body.includes(marker));
+ if (existing) {
+ await github.rest.issues.updateComment({
+ owner,
+ repo,
+ comment_id: existing.id,
+ body,
+ });
+ } else {
+ await github.rest.issues.createComment({
+ owner,
+ repo,
+ issue_number,
+ body,
+ });
+ }
+
+ - name: Append CI summary to Job Summary
+ if: always()
+ run: |
+ if [[ -f "$REPORT_DIR/index.md" ]]; then
+ cat "$REPORT_DIR/index.md" >> "$GITHUB_STEP_SUMMARY"
+ else
+ echo "## CI 概覽" >> "$GITHUB_STEP_SUMMARY"
+ echo "" >> "$GITHUB_STEP_SUMMARY"
+ echo "_(summary 未產生)_" >> "$GITHUB_STEP_SUMMARY"
+ fi
+
+ - name: Finalize CI status
+ if: always()
+ run: |
+ failed=0
+ ruff_lint_exit="${{ steps.ruff_lint.outputs.exit_code || '0' }}"
+ ruff_format_exit="${{ steps.ruff_format.outputs.exit_code || '0' }}"
+ pyright_exit="${{ steps.pyright.outputs.exit_code || '0' }}"
+ pytest_exit="${{ steps.pytest.outputs.exit_code || '0' }}"
+ coverage_gate_exit="${{ steps.coverage_gate.outputs.exit_code || '0' }}"
+
+ if [[ "$ruff_lint_exit" -ne 0 ]]; then
+ echo "::error::Ruff Lint failed"
+ failed=1
+ fi
+ if [[ "$ruff_format_exit" -ne 0 ]]; then
+ echo "::error::Ruff Format failed"
+ failed=1
+ fi
+ if [[ "$pyright_exit" -ne 0 ]]; then
+ echo "::error::Pyright failed"
+ failed=1
+ fi
+ if [[ "$pytest_exit" -ne 0 ]]; then
+ echo "::error::Pytest failed"
+ failed=1
+ fi
+ if [[ "$coverage_gate_exit" -ne 0 ]]; then
+ if [[ "$COVERAGE_ENFORCE" != "0" ]]; then
+ echo "::error::Coverage gate failed"
+ failed=1
+ else
+ echo "::warning::Coverage gate reported failure but enforcement is disabled (COVERAGE_ENFORCE=$COVERAGE_ENFORCE)."
+ fi
+ fi
+
+ if [[ $failed -ne 0 ]]; then
+ exit 1
+ fi
diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/deprecated/cpu-ci.yml
similarity index 100%
rename from .github/workflows/cpu-ci.yml
rename to .github/workflows/deprecated/cpu-ci.yml
diff --git a/.github/workflows/cuda-ci.yml b/.github/workflows/deprecated/cuda-ci.yml
similarity index 100%
rename from .github/workflows/cuda-ci.yml
rename to .github/workflows/deprecated/cuda-ci.yml
diff --git a/.gitignore b/.gitignore
index 79f632a..b651d68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,8 @@ temp_image.jpg
.DS_Store
.python-version
tmp.jpg
+tmp*
+.venv_min
+.vscode
+.ruff_cache
+.benchmarks
\ No newline at end of file
diff --git a/README.md b/README.md
index ceb36d1..fc7ef05 100644
--- a/README.md
+++ b/README.md
@@ -1,170 +1,472 @@
-[**English**](./README.md) | [中文](./README_tw.md)
+**[English](./README.md)** | [Chinese](./README_tw.md)
# Capybara
-**An Integrated Python Package for Image Processing and Deep Learning.**
-
-
-
+
+
-## Introduction
-

-This project is an image processing and deep learning toolkit, mainly consisting of the following parts:
+---
-- **Vision**: Provides functionalities related to computer vision, such as image and video processing.
-- **Structures**: Modules for handling structured data, such as BoundingBox and Polygon.
-- **ONNXEngine**: Provides ONNX inference functionalities, supporting ONNX format models.
-- **Utils**: Contains utility functions that do not belong to other modules.
-- **Tests**: Includes test code for various functions to verify their correctness.
+## Introduction
-## Technical Documentation
+Capybara is designed with three goals:
-For more detailed information on installation and usage, please refer to the [**Capybara Documents**](https://docsaid.org/en/docs/capybara).
+1. **Lightweight default install**: `pip install capybara-docsaid` installs only the core `utils/structures/vision` modules, without forcing heavy inference dependencies.
+2. **Inference backends as opt-in extras**: install ONNX Runtime / OpenVINO / TorchScript only when you need them via extras.
+3. **Lower risk**: enforce quality gates with ruff/pyright/pytest and target **90%** line coverage for the core codebase.
-The document provides a detailed explanation of this project and answers to frequently asked questions.
+What you get:
-## Prerequisites
+- **Image tools** (`capybara.vision`): I/O, color conversion, resize/rotate/pad/crop, and video frame extraction.
+- **Geometry structures** (`capybara.structures`): `Box/Boxes`, `Polygon/Polygons`, `Keypoints`, plus helper functions like IoU.
+- **Inference wrappers (optional)**: `capybara.onnxengine` / `capybara.openvinoengine` / `capybara.torchengine`.
+- **Feature extras (optional)**: `visualization` (drawing tools), `ipcam` (simple web demo), `system` (system info tools).
+- **Utilities** (`capybara.utils`): `PowerDict`, `Timer`, `make_batch`, `download_from_google`, and other common helpers.
-Before the installation of Capybara, ensure that your system meets the following requirements:
+## Quick Start
-### Python Version
+### Install and verify
-3.10+
+```bash
+pip install capybara-docsaid
+python -c "import capybara; print(capybara.__version__)"
+```
-### Dependency Packages
+## Documentation
-Please install the necessary system packages according to your operating system:
+To learn more about installation and usage, see [**Capybara Documents**](https://docsaid.org/docs/capybara).
-#### Ubuntu
+The documentation includes detailed guides and common FAQs for this project.
+
+## Installation
+
+### Core install (lightweight)
```bash
-sudo apt install libturbojpeg exiftool ffmpeg libheif-dev poppler-utils
+pip install capybara-docsaid
```
-##### GPU Dependencies
+### Enable inference backends (optional)
+
+```bash
+# ONNX Runtime (CPU)
+pip install "capybara-docsaid[onnxruntime]"
+
+# ONNX Runtime (GPU)
+pip install "capybara-docsaid[onnxruntime-gpu]"
+
+# OpenVINO runtime
+pip install "capybara-docsaid[openvino]"
-To use ONNX Runtime with GPU acceleration, ensure that you install a compatible version, which can be found on the official ONNX Runtime CUDA Execution Provider requirements page.
+# TorchScript runtime
+pip install "capybara-docsaid[torchscript]"
+
+# Install everything
+pip install "capybara-docsaid[all]"
+```
-Here's an example to install cuda-12.8:
+### Feature extras (optional)
```bash
-wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
-sudo dpkg -i cuda-keyring_1.1-1_all.deb
-sudo apt-get update
-sudo apt-get -y install cuda-toolkit-12-8
-# Post installation, add cuda path to .bashrc or .zshrc
-export shellrc="~/.zshrc"
-echo 'export PATH=/usr/local/cuda-12.8/bin${PATH:+:${PATH}}' >> $shellrc
-echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}' >> $shellrc
+# Visualization (matplotlib/pillow)
+pip install "capybara-docsaid[visualization]"
+
+# IPCam app (flask)
+pip install "capybara-docsaid[ipcam]"
+
+# System info (psutil)
+pip install "capybara-docsaid[system]"
```
-For more details, please see [Nvidia CUDA](https://developer.nvidia.com/cuda-toolkit).
+### Combine multiple extras
-#### MacOS
+If you want OpenVINO inference and the IPCam features, install:
```bash
-brew install jpeg-turbo exiftool ffmpeg libheif poppler
+# OpenVINO + IPCam
+pip install "capybara-docsaid[openvino,ipcam]"
```
-## Installation
+### Install from Git
+
+```bash
+pip install git+https://github.com/DocsaidLab/Capybara.git
+```
+
+## System Dependencies (Install as needed)
-### PyPI
+Some features require OS-level codecs / image I/O / PDF tools (install as needed):
+
+- `PyTurboJPEG` (faster JPEG I/O): requires the TurboJPEG library.
+- `pillow-heif` (HEIC/HEIF support): requires libheif.
+- `pdf2image` (PDF to images): requires Poppler.
+- Video frame extraction: installing `ffmpeg` is recommended (more stable OpenCV video decoding).
+
+### Ubuntu
```bash
-pip install capybara-docsaid
+sudo apt install ffmpeg libturbojpeg libheif-dev poppler-utils
```
-### Git
+### macOS
```bash
-pip install git+https://github.com/DocsaidLab/Capybara.git
+brew install jpeg-turbo ffmpeg libheif poppler
+```
+
+### GPU Notes (ONNX Runtime CUDA)
+
+If you're using `onnxruntime-gpu`, install the compatible CUDA/cuDNN version for your ORT version:
+
+- See [**the ONNX Runtime documentation**](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements)
+
+## Usage
+
+### Image data conventions
+
+- Capybara images are represented as `numpy.ndarray` and follow OpenCV conventions by default: **BGR** channel order, with shape typically `(H, W, 3)`.
+- If you prefer working in RGB, use `imread(..., color_base="RGB")` or convert with `imcvtcolor(img, "BGR2RGB")`, as in the sketch below.
+
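+A minimal sketch of the two conventions (assuming the `color_base` keyword described above):
+
+```python
+from capybara import imread
+
+img_bgr = imread("your_image.jpg")                    # default: BGR, shape (H, W, 3)
+img_rgb = imread("your_image.jpg", color_base="RGB")  # decode straight to RGB
+```
+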
+### Image I/O
+
+```python
+from capybara import imread, imwrite
+
+img = imread("your_image.jpg")
+if img is None:
+ raise RuntimeError("Failed to read image.")
+
+imwrite(img, "out.jpg")
+```
+
+Notes:
+
+- `imread` returns `None` when it fails to decode an image (if the path doesn't exist, it raises `FileExistsError`).
+- `imread` also supports `.heic` (requires `pillow-heif` + OS-level libheif).
+
+### Resize / pad
+
+With `imresize`, you can pass `None` in `size` to keep the aspect ratio and have the other dimension inferred automatically.
+
+```python
+import numpy as np
+from capybara import BORDER, imresize, pad
+
+img = np.zeros((480, 640, 3), dtype=np.uint8)
+img = imresize(img, (320, None)) # (height, width)
+img = pad(img, pad_size=(8, 8), pad_mode=BORDER.REPLICATE)
+```
+
+### Color conversion
+
+```python
+import numpy as np
+from capybara import imcvtcolor
+
+img = np.zeros((240, 320, 3), dtype=np.uint8) # BGR
+gray = imcvtcolor(img, "BGR2GRAY") # grayscale
+rgb = imcvtcolor(img, "BGR2RGB") # RGB
+```
+
+### Rotation / perspective correction
+
+```python
+import numpy as np
+from capybara import Polygon, imrotate, imwarp_quadrangle
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+rot = imrotate(img, angle=15, expand=True) # Angle definition matches OpenCV: positive values rotate counterclockwise
+
+poly = Polygon([[10, 10], [200, 20], [190, 120], [20, 110]])
+patch = imwarp_quadrangle(img, poly) # 4-point perspective warp
+```
+
+### Cropping (Box / Boxes)
+
+```python
+import numpy as np
+from capybara import Box, Boxes, imcropbox, imcropboxes
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+crop1 = imcropbox(img, Box([10, 20, 110, 120]), use_pad=True)
+crop_list = imcropboxes(
+ img,
+ Boxes([[0, 0, 10, 10], [100, 100, 400, 300]]),
+ use_pad=True,
+)
+```
+
+### Binarization + morphology
+
+Morphology operators live in `capybara.vision.morphology` (not in the top-level `capybara` namespace).
+
+```python
+import numpy as np
+from capybara import imbinarize
+from capybara.vision.morphology import imopen
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+mask = imbinarize(img) # OTSU + binary
+mask = imopen(mask, ksize=3) # Opening to remove small noise
+```
+
+### Boxes / IoU
+
+```python
+import numpy as np
+from capybara import Box, Boxes, pairwise_iou
+
+boxes_a = Boxes([[10, 10, 20, 20], [30, 30, 60, 60]])
+boxes_b = Boxes(np.array([[12, 12, 18, 18]], dtype=np.float32))
+print(pairwise_iou(boxes_a, boxes_b))
+
+box = Box([0.1, 0.2, 0.9, 0.8], is_normalized=True).convert("XYWH")
+print(box.numpy())
+```
+
+### Polygons / IoU
+
+```python
+from capybara import Polygon, polygon_iou
+
+p1 = Polygon([[0, 0], [10, 0], [10, 10], [0, 10]])
+p2 = Polygon([[5, 5], [15, 5], [15, 15], [5, 15]])
+print(polygon_iou(p1, p2))
+```
+
+### Base64 (image / ndarray)
+
+```python
+import numpy as np
+from capybara import img_to_b64str, npy_to_b64str
+from capybara.vision.improc import b64str_to_img, b64str_to_npy
+
+img = np.zeros((32, 32, 3), dtype=np.uint8)
+b64_img = img_to_b64str(img) # JPEG bytes -> base64 string
+if b64_img is None:
+ raise RuntimeError("Failed to encode image into base64.")
+img2 = b64str_to_img(b64_img) # base64 string -> numpy image
+
+vec = np.arange(8, dtype=np.float32)
+b64_vec = npy_to_b64str(vec)
+vec2 = b64str_to_npy(b64_vec, dtype="float32")
+```
+
+### PDF to images
+
+```python
+from capybara.vision.improc import pdf2imgs
+
+pages = pdf2imgs("file.pdf") # list[np.ndarray], each page is BGR image
+if pages is None:
+ raise RuntimeError("Failed to decode PDF.")
+print(len(pages))
+```
+
+### Visualization (optional)
+
+Install first: `pip install "capybara-docsaid[visualization]"`.
+
+```python
+import numpy as np
+from capybara import Box
+from capybara.vision.visualization.draw import draw_box
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+img = draw_box(img, Box([10, 20, 100, 120]))
+```
+
+### IPCam (optional)
+
+`IpcamCapture` itself does not depend on Flask; you only need the `ipcam` extra to use `WebDemo`.
+
+```python
+from capybara.vision.ipcam.camera import IpcamCapture
+
+cap = IpcamCapture(url=0, color_base="BGR") # or provide an RTSP/HTTP URL
+frame = next(cap)
+```
+
+Web demo (install first: `pip install "capybara-docsaid[ipcam]"`):
+
+```python
+from capybara.vision.ipcam.app import WebDemo
+
+WebDemo("rtsp://").run(port=5001)
+```
+
+### System info (optional)
+
+Install first: `pip install "capybara-docsaid[system]"`.
+
+```python
+from capybara.utils.system_info import get_system_info
+
+print(get_system_info())
```
-## Docker for Deployment
+### Video frame extraction
-We provide a Docker script for convenient deployment, ensuring a consistent environment. Below are the steps to build the image with Capybara installed.
+```python
+from capybara import video2frames_v2
-1. Clone this repository:
+frames = video2frames_v2("demo.mp4", frame_per_sec=2, max_size=1280)
+print(len(frames))
+```
- ```bash
- git clone https://github.com/DocsaidLab/Capybara.git
- ```
+## Inference Backends
-2. Enter the project directory and run the build script:
+Inference backends are optional; install the corresponding extras before importing the relevant engine modules.
- ```bash
- cd Capybara
- bash docker/build.bash
- ```
+### Runtime / backend matrix
- This will build an image using the [**Dockerfile**](docker/Dockerfile) in the project. The image is based on `nvidia/cuda:12.8.1-cudnn-runtime-ubuntu24.04` by default, providing the CUDA environment required for ONNXRuntime inference.
+Note: the TorchScript runtime is named `Runtime.pt` in code (corresponding extra: `torchscript`). A short backend-selection sketch follows the table.
-3. After the build is complete, mount the working directory and run the program:
+| Runtime (`capybara.runtime.Runtime`) | Backend name | Provider / device |
+| ------------------------------------ | --------------- | ----------------------------------------------------------------------------------------------------------- |
+| `onnx` | `cpu` | `["CPUExecutionProvider"]` |
+| `onnx` | `cuda` | `["CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `onnx` | `tensorrt` | `["TensorrtExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `onnx` | `tensorrt_rtx` | `["NvTensorRTRTXExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `openvino` | `cpu` | `device="CPU"` |
+| `openvino` | `gpu` | `device="GPU"` |
+| `openvino` | `npu` | `device="NPU"` |
+| `pt` | `cpu` | `torch.device("cpu")` |
+| `pt` | `cuda` | `torch.device("cuda")` |
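+
+For example, to pin one of the backends listed above explicitly (a sketch only; a non-CPU backend is assumed to work only when the matching execution provider is available in your environment):
+
+```python
+from capybara.onnxengine import ONNXEngine
+
+# "tensorrt" is taken from the backend column above; it requires a
+# TensorRT-enabled ONNX Runtime build plus a compatible GPU.
+engine = ONNXEngine("model.onnx", backend="tensorrt")
+```
+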
- ```bash
- docker run --gpus all -it --rm capybara_docsaid:latest bash
- ```
+### Runtime registry (auto backend selection)
-**PS: If you want to compile cuda or cudnn for developing, please change the base image to `nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04`.**
+```python
+from capybara.runtime import Runtime
-### gosu Permissions Issues
+print(Runtime.onnx.auto_backend_name()) # Priority: cuda -> tensorrt_rtx -> tensorrt -> cpu
+print(Runtime.openvino.auto_backend_name()) # Priority: gpu -> npu -> cpu
+print(Runtime.pt.auto_backend_name()) # Priority: cuda -> cpu
+```
-If you encounter issues with file ownership as root when running scripts inside the container, causing permission problems, you can use `gosu` to switch users in the Dockerfile. Specify `USER_ID` and `GROUP_ID` when starting the container to avoid frequent permission adjustments in collaborative development.
+### ONNX Runtime (`capybara.onnxengine`)
-For details, refer to the technical documentation: [**Integrating gosu Configuration**](https://docsaid.org/en/docs/capybara/advance/#integrating-gosu-configuration)
+```python
+import numpy as np
+from capybara.onnxengine import EngineConfig, ONNXEngine
-1. Install `gosu`:
+engine = ONNXEngine(
+ "model.onnx",
+ backend="cpu",
+ config=EngineConfig(enable_io_binding=False),
+)
+outputs = engine.run({"input": np.ones((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+print(engine.summary())
+```
+
+### OpenVINO (`capybara.openvinoengine`)
- ```dockerfile
- RUN apt-get update && apt-get install -y gosu
- ```
+```python
+import numpy as np
+from capybara.openvinoengine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine
+
+engine = OpenVINOEngine(
+ "model.xml",
+ device=OpenVINODevice.cpu,
+ config=OpenVINOConfig(num_requests=2),
+)
+outputs = engine.run({"input": np.ones((1, 3), dtype=np.float32)})
+print(outputs.keys())
+```
-2. Use `gosu` in the container start command to switch to a non-root user for file read/write operations.
+### TorchScript (`capybara.torchengine`)
+
+```python
+import numpy as np
+from capybara.torchengine import TorchEngine
+
+engine = TorchEngine("model.pt", device="cpu")
+outputs = engine.run({"image": np.zeros((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+```
+
+### Benchmark (depends on hardware)
+
+All engines provide `benchmark(...)` for quick throughput/latency measurements.
+
+```python
+import numpy as np
+from capybara.onnxengine import ONNXEngine
+
+engine = ONNXEngine("model.onnx", backend="cpu")
+dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)
+print(engine.benchmark({"input": dummy}, repeat=50, warmup=5))
+```
- ```dockerfile
- # Create the entrypoint script
- RUN printf '#!/bin/bash\n\
- if [ ! -z "$USER_ID" ] && [ ! -z "$GROUP_ID" ]; then\n\
- groupadd -g "$GROUP_ID" -o usergroup\n\
- useradd --shell /bin/bash -u "$USER_ID" -g "$GROUP_ID" -o -c "" -m user\n\
- export HOME=/home/user\n\
- chown -R "$USER_ID":"$GROUP_ID" /home/user\n\
- chown -R "$USER_ID":"$GROUP_ID" /code\n\
- fi\n\
- \n\
- # Check for parameters\n\
- if [ $# -gt 0 ]; then\n\
- exec gosu ${USER_ID:-0}:${GROUP_ID:-0} python "$@"\n\
- else\n\
- exec gosu ${USER_ID:-0}:${GROUP_ID:-0} bash\n\
- fi' > "$ENTRYPOINT_SCRIPT"
+### Advanced: Custom options (optional)
- RUN chmod +x "$ENTRYPOINT_SCRIPT"
+`EngineConfig` / `OpenVINOConfig` / `TorchEngineConfig` are passed through to the underlying runtime as-is.
- ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
- ```
+```python
+from capybara.onnxengine import EngineConfig, ONNXEngine
-For more advanced configuration, refer to [**NVIDIA Container Toolkit**](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and the official [**docker**](https://docs.docker.com/) documentation.
+engine = ONNXEngine(
+ "model.onnx",
+ backend="cuda",
+ config=EngineConfig(
+ provider_options={
+ "CUDAExecutionProvider": {
+ "enable_cuda_graph": True,
+ },
+ },
+ ),
+)
+```
+
+## Quality Gates (Contributors)
+
+Before merging, this project requires:
+
+```bash
+ruff check .
+ruff format --check .
+pyright
+python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=term
+```
+
+Notes:
+
+- Coverage gate is **90% line coverage** (rules defined in `.coveragerc`); a local sketch follows these notes.
+- Heavy / environment-dependent modules are excluded from the default coverage gate to keep CI reproducible and maintainable.
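+
+To reproduce the gate locally, you can run the bundled checker against a locally generated `coverage.xml` (a sketch mirroring the CI step; in CI the thresholds come from `COVERAGE_MIN_LINE` / `COVERAGE_MIN_BRANCH`):
+
+```bash
+python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=xml:coverage.xml
+python .github/scripts/coverage_gate.py --file coverage.xml --min-line 0.90 --enforce 1
+```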
+
+## Docker (optional)
+
+```bash
+git clone https://github.com/DocsaidLab/Capybara.git
+cd Capybara
+bash docker/build.bash
+```
+
+Run:
+
+```bash
+docker run --rm -it capybara_docsaid bash
+```
-## Testing
+If you need GPU access inside the container, use the NVIDIA container runtime (e.g. `--gpus all`).
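+
+For example (assuming the NVIDIA Container Toolkit is installed on the host):
+
+```bash
+docker run --rm -it --gpus all capybara_docsaid bash
+```
+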
-This project uses `pytest` for unit testing, and users can run the tests themselves to verify the correctness of the functionalities. To install and run the tests, use the following commands:
+## Testing (local)
```bash
-pip install pytest
-python -m pytest -vv tests
+python -m pytest -vv
```
-Once completed, you can check if all modules are functioning properly. If any issues arise, first check the environment settings and package versions.
+## License
-If the problem persists, please report it in the Issue section.
+Apache-2.0, see `LICENSE`.
## Citation
@@ -174,7 +476,7 @@ If the problem persists, please report it in the Issue section.
title = {Capybara: An Integrated Python Package for Image Processing and Deep Learning.},
year = {2025},
publisher = {GitHub},
- howpublished = {\url{https://github.com/DocsaidLab/Capybara}},
+ howpublished = {\\url{https://github.com/DocsaidLab/Capybara}},
note = {* equal contribution}
}
```
diff --git a/README_tw.md b/README_tw.md
index c5cf4cd..9e56622 100644
--- a/README_tw.md
+++ b/README_tw.md
@@ -6,21 +6,38 @@
-
-
+
+
+
+
+---
+
## 介紹
-
+Capybara 的設計目標聚焦三個方向:
-本專案是一個影像處理與深度學習的工具箱,主要包括以下幾個部分:
+1. **預設安裝輕量化**:`pip install capybara-docsaid` 僅安裝核心 utils/structures/vision,不強迫安裝重型推論依賴。
+2. **推論後端改為 opt-in extras**:需要 ONNX Runtime / OpenVINO / TorchScript 時再用 extras 安裝。
+3. **降低風險**:導入 ruff/pyright/pytest 品質門檻,並以核心程式碼 **90%** 行覆蓋率為維護目標。
-- **Vision**:提供與電腦視覺相關的功能,例如圖像和影片處理。
-- **Structures**:用於處理結構化數據的模組,例如 BoundingBox 和 Polygon。
-- **ONNXEngine**:提供 ONNX 推理功能,支援 ONNX 格式的模型。
-- **Utils**:放置無法歸類到其他模組的工具函式。
-- **Tests**:包含各類功能的測試程式碼,用於驗證函式的正確性。
+你會得到:
+
+- **影像工具**(`capybara.vision`):讀寫、轉色、縮放/旋轉/補邊/裁切,以及影片抽幀工具。
+- **幾何結構**(`capybara.structures`):`Box/Boxes`、`Polygon/Polygons`、`Keypoints`,以及 IoU 等輔助函數。
+- **推論封裝(可選)**:`capybara.onnxengine` / `capybara.openvinoengine` / `capybara.torchengine`。
+- **功能 extras(可選)**:`visualization`(繪圖工具)、`ipcam`(簡易 Web demo)、`system`(系統資訊工具)。
+- **小工具**(`capybara.utils`):`PowerDict`、`Timer`、`make_batch`、`download_from_google` 等常用 helper。
+
+## 快速開始
+
+### 安裝與驗證
+
+```bash
+pip install capybara-docsaid
+python -c "import capybara; print(capybara.__version__)"
+```
## 技術文件
@@ -30,185 +47,436 @@
## 安裝
-在開始安裝 Capybara 之前,請先確保系統符合以下需求:
+### 核心安裝(輕量)
-### Python 版本
+```bash
+pip install capybara-docsaid
+```
-- 需要 Python 3.10 或以上版本。
+### 啟用推論後端(可選)
-### 依賴套件
+```bash
+# ONNXRuntime(CPU)
+pip install "capybara-docsaid[onnxruntime]"
-請依照作業系統,安裝下列必要的系統套件:
+# ONNXRuntime(GPU)
+pip install "capybara-docsaid[onnxruntime-gpu]"
-- **Ubuntu**
+# OpenVINO runtime
+pip install "capybara-docsaid[openvino]"
- ```bash
- sudo apt install libturbojpeg exiftool ffmpeg libheif-dev
- ```
+# TorchScript runtime
+pip install "capybara-docsaid[torchscript]"
-- **MacOS**
+# 全部一起裝
+pip install "capybara-docsaid[all]"
+```
- ```bash
- brew install jpeg-turbo exiftool ffmpeg
- ```
+### 選用功能 extras(可選)
- - **特別注意**:經過測試,在 macOS 上使用 libheif 時,存在一些已知問題,主要包括:
+```bash
+# 視覺化(matplotlib/pillow)
+pip install "capybara-docsaid[visualization]"
- 1. **生成的 HEIC 檔案無法打開**:在 macOS 上,libheif 生成的 HEIC 檔案可能無法被某些程式打開。這可能與圖像尺寸有關,特別是當圖像的寬度或高度為奇數時,可能會導致相容性問題。
+# IPCam app(flask)
+pip install "capybara-docsaid[ipcam]"
- 2. **編譯錯誤**:在 macOS 上編譯 libheif 時,可能會遇到與 ffmpeg 解碼器相關的未定義符號錯誤。這可能是由於編譯選項或相依性設定不正確所致。
+# 系統資訊(psutil)
+pip install "capybara-docsaid[system]"
+```
- 3. **範例程式無法執行**:在 macOS Sonoma 上,libheif 的範例程式可能無法正常運行,出現動態鏈接錯誤,提示找不到 `libheif.1.dylib`,這可能與動態庫的路徑設定有關。
+### 挑選多個功能
- 由於問題不少,因此我們目前只在 Ubuntu 才會運行 libheif,至於 macOS 的部分則留給未來的版本。
+假設你想使用 openvino 推論,並搭配 ipcam 相關的功能,可以這樣安裝:
-### pdf2image 依賴套件
+```bash
+# 選用 OpenVINO 和 IPCam
+pip install "capybara-docsaid[openvino,ipcam]"
+```
-pdf2image 是用於將 PDF 文件轉換成影像的 Python 模組,請確保系統已安裝下列工具:
+### 從 Git 安裝
-- MacOS:需要安裝 poppler
+```bash
+pip install git+https://github.com/DocsaidLab/Capybara.git
+```
+
+## 系統相依套件(依功能需求安裝)
- ```bash
- brew install poppler
- ```
+有些功能需要 OS 層級的 codec / image IO / PDF 工具(依功能需求安裝):
-- Linux:大多數發行版已內建 `pdftoppm` 與 `pdftocairo`。如未安裝,請執行:
+- `PyTurboJPEG`(JPEG 讀寫加速):需要 TurboJPEG library。
+- `pillow-heif`(HEIC/HEIF 支援):需要 libheif。
+- `pdf2image`(PDF 轉圖):需要 Poppler。
+- 影片抽幀:建議安裝 `ffmpeg`(讓 OpenCV 影片讀取更穩定)。
- ```bash
- sudo apt install poppler-utils
- ```
+### Ubuntu
-### ONNXRuntime GPU 依賴
+```bash
+sudo apt install ffmpeg libturbojpeg libheif-dev poppler-utils
+```
-若需使用 ONNXRuntime 進行 GPU 加速推理,請確保已安裝相容版本的 CUDA,如下示範:
+### macOS
```bash
-sudo apt install cuda-12-4
-# 假設要加入至 .bashrc
-echo 'export PATH=/usr/local/cuda-12.4/bin${PATH:+:${PATH}}' >> ~/.bashrc
-echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}' >> ~/.bashrc
+brew install jpeg-turbo ffmpeg libheif poppler
+```
+
+### GPU 注意事項(ONNXRuntime CUDA)
+
+若使用 `onnxruntime-gpu`,請依 ORT 的版本安裝相容的 CUDA/cuDNN:
+
+- 請參考 [**onnxruntime 官方網站**](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements)
+
+## 使用方式
+
+### 影像資料格式約定
+
+- Capybara 的影像以 `numpy.ndarray` 表示,預設遵循 OpenCV 慣例:**BGR**、shape 通常為 `(H, W, 3)`。
+- 若你希望以 RGB 工作,可用 `imread(..., color_base="RGB")` 或 `imcvtcolor(img, "BGR2RGB")` 轉換。
+
+### 影像 I/O
+
+```python
+from capybara import imread, imwrite
+
+img = imread("your_image.jpg")
+if img is None:
+ raise RuntimeError("Failed to read image.")
+
+imwrite(img, "out.jpg")
+```
+
+補充:
+
+- `imread` 讀不到圖時會回傳 `None`(路徑不存在則會直接丟 `FileExistsError`)。
+- `imread` 也支援 `.heic`(需 `pillow-heif` + OS 層級 libheif)。
+
+### Resize / pad
+
+`imresize` 支援在 `size` 中用 `None` 表示「維持長寬比自動推算另一邊」。
+
+```python
+import numpy as np
+from capybara import BORDER, imresize, pad
+
+img = np.zeros((480, 640, 3), dtype=np.uint8)
+img = imresize(img, (320, None)) # (height, width)
+img = pad(img, pad_size=(8, 8), pad_mode=BORDER.REPLICATE)
+```
+
+### 轉色(Color Conversion)
+
+```python
+import numpy as np
+from capybara import imcvtcolor
+
+img = np.zeros((240, 320, 3), dtype=np.uint8) # BGR
+gray = imcvtcolor(img, "BGR2GRAY") # grayscale
+rgb = imcvtcolor(img, "BGR2RGB") # RGB
+```
+
+### 旋轉 / 透視校正
+
+```python
+import numpy as np
+from capybara import Polygon, imrotate, imwarp_quadrangle
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+rot = imrotate(img, angle=15, expand=True) # 角度定義與 OpenCV 相同:正值為逆時針
+
+poly = Polygon([[10, 10], [200, 20], [190, 120], [20, 110]])
+patch = imwarp_quadrangle(img, poly) # 4 點透視校正
+```
+
+### 裁切(Box / Boxes)
+
+```python
+import numpy as np
+from capybara import Box, Boxes, imcropbox, imcropboxes
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+crop1 = imcropbox(img, Box([10, 20, 110, 120]), use_pad=True)
+crop_list = imcropboxes(
+ img,
+ Boxes([[0, 0, 10, 10], [100, 100, 400, 300]]),
+ use_pad=True,
+)
+```
+
+### 二值化 + 形態學(Morphology)
+
+形態學操作位於 `capybara.vision.morphology`(不在頂層 `capybara` namespace)。
+
+```python
+import numpy as np
+from capybara import imbinarize
+from capybara.vision.morphology import imopen
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+mask = imbinarize(img) # OTSU + binary
+mask = imopen(mask, ksize=3) # 開運算去除雜點
```
-### 透過 PyPI 安裝
+### Boxes / IoU
-1. 透過 PyPI 安裝套件:
+```python
+import numpy as np
+from capybara import Box, Boxes, pairwise_iou
- ```bash
- pip install capybara-docsaid
- ```
+boxes_a = Boxes([[10, 10, 20, 20], [30, 30, 60, 60]])
+boxes_b = Boxes(np.array([[12, 12, 18, 18]], dtype=np.float32))
+print(pairwise_iou(boxes_a, boxes_b))
-2. 驗證安裝:
+box = Box([0.1, 0.2, 0.9, 0.8], is_normalized=True).convert("XYWH")
+print(box.numpy())
+```
- ```bash
- python -c "import capybara; print(capybara.__version__)"
- ```
+### Polygons / IoU
-3. 若顯示版本號,則安裝成功。
+```python
+from capybara import Polygon, polygon_iou
-### 透過 git clone 安裝
+p1 = Polygon([[0, 0], [10, 0], [10, 10], [0, 10]])
+p2 = Polygon([[5, 5], [15, 5], [15, 15], [5, 15]])
+print(polygon_iou(p1, p2))
+```
-1. 下載本專案:
+### Base64 (image / ndarray)
- ```bash
- git clone https://github.com/DocsaidLab/Capybara.git
- ```
+```python
+import numpy as np
+from capybara import img_to_b64str, npy_to_b64str
+from capybara.vision.improc import b64str_to_img, b64str_to_npy
-2. 安裝 wheel 套件:
+img = np.zeros((32, 32, 3), dtype=np.uint8)
+b64_img = img_to_b64str(img) # JPEG bytes -> base64 string
+if b64_img is None:
+ raise RuntimeError("Failed to encode image into base64.")
+img2 = b64str_to_img(b64_img) # base64 string -> numpy image
- ```bash
- pip install wheel
- ```
+vec = np.arange(8, dtype=np.float32)
+b64_vec = npy_to_b64str(vec)
+vec2 = b64str_to_npy(b64_vec, dtype="float32")
+```
-3. 建構 wheel 檔案:
+### PDF to Images
- ```bash
- cd Capybara
- python setup.py bdist_wheel
- ```
+```python
+from capybara.vision.improc import pdf2imgs
-4. 安裝建置完成的 wheel 檔:
+pages = pdf2imgs("file.pdf")  # list[np.ndarray]; each page is a BGR image
+if pages is None:
+ raise RuntimeError("Failed to decode PDF.")
+print(len(pages))
+```
- ```bash
- pip install dist/capybara_docsaid-*-py3-none-any.whl
- ```
+### Visualization (optional)
-### 透過 docker 安裝(建議)
+Requires the extra to be installed first: `pip install "capybara-docsaid[visualization]"`.
-若想在部署或協同開發時避免環境衝突,建議使用 Docker,以下為簡要示範流程:
+```python
+import numpy as np
+from capybara import Box
+from capybara.vision.visualization.draw import draw_box
-1. 下載本專案:
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+img = draw_box(img, Box([10, 20, 100, 120]))
+```
- ```bash
- git clone https://github.com/DocsaidLab/Capybara.git
- ```
+### IPCam (optional)
-2. 進入專案資料夾,執行建置腳本:
+`IpcamCapture` itself does not depend on Flask; the `ipcam` extra is only required for `WebDemo`.
- ```bash
- cd Capybara
- bash docker/build.bash
- ```
+```python
+from capybara.vision.ipcam.camera import IpcamCapture
- 這會使用專案中的 [**Dockerfile**](https://github.com/DocsaidLab/Capybara/blob/main/docker/Dockerfile) 來建立映像檔;映像檔預設以 `nvcr.io/nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04` 為基底,提供 ONNXRuntime 推理所需的 CUDA 環境。
+cap = IpcamCapture(url=0, color_base="BGR")  # or pass an RTSP/HTTP URL
+frame = next(cap)
+```
-3. 建置完成後,使用指令掛載工作目錄並執行程式:
+Web demo (requires installing the extra first: `pip install "capybara-docsaid[ipcam]"`):
- ```bash
- docker run -v ${PWD}:/code -it capybara_infer_image your_scripts.py
- ```
+```python
+from capybara.vision.ipcam.app import WebDemo
- 若需 GPU 加速,可於執行時加入 `--gpus all`。
+WebDemo("rtsp://").run(port=5001)
+```
+
+### System Info (optional)
-#### gosu 權限問題
+Requires the extra to be installed first: `pip install "capybara-docsaid[system]"`.
-若在容器內執行腳本時,遇到輸出檔案歸屬為 root,導致檔案權限不便的情況,可在 Dockerfile 中加入 `gosu` 進行使用者切換,並在容器啟動時指定 `USER_ID` 與 `GROUP_ID`。
-這樣可避免在多位開發者協作時,需要頻繁調整檔案權限的問題。
+```python
+from capybara.utils.system_info import get_system_info
+
+print(get_system_info())
+```
-具體作法可參考技術文件:[**Integrating gosu Configuration**](https://docsaid.org/docs/capybara/advance/#integrating-gosu-configuration)
+### Video Frame Extraction
+
+```python
+from capybara import video2frames_v2
+
+frames = video2frames_v2("demo.mp4", frame_per_sec=2, max_size=1280)
+print(len(frames))
+```
-1. 安裝 `gosu`:
+## Inference Backends
- ```dockerfile
- RUN apt-get update && apt-get install -y gosu
- ```
+Inference backends are optional features; install the corresponding extras before importing the engine modules.
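+
+For example, using extras named elsewhere in this README (`openvino`, `torchscript`); the authoritative extras list lives in the package metadata:
+
+```bash
+pip install "capybara-docsaid[openvino,torchscript]"
+```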
-2. 在容器啟動指令中使用 `gosu` 切換至容器內的非 root 帳號,以利檔案的讀寫。
+### Runtime / Backend Matrix
- ```dockerfile
- # Create the entrypoint script
- RUN printf '#!/bin/bash\n\
- if [ ! -z "$USER_ID" ] && [ ! -z "$GROUP_ID" ]; then\n\
- groupadd -g "$GROUP_ID" -o usergroup\n\
- useradd --shell /bin/bash -u "$USER_ID" -g "$GROUP_ID" -o -c "" -m user\n\
- export HOME=/home/user\n\
- chown -R "$USER_ID":"$GROUP_ID" /home/user\n\
- chown -R "$USER_ID":"$GROUP_ID" /code\n\
- fi\n\
- \n\
- # Check for parameters\n\
- if [ $# -gt 0 ]; then\n\
- exec gosu ${USER_ID:-0}:${GROUP_ID:-0} python "$@"\n\
- else\n\
- exec gosu ${USER_ID:-0}:${GROUP_ID:-0} bash\n\
- fi' > "$ENTRYPOINT_SCRIPT"
+Note: the TorchScript runtime is named `Runtime.pt` in code (its install extra is `torchscript`).
- RUN chmod +x "$ENTRYPOINT_SCRIPT"
+| Runtime (`capybara.runtime.Runtime`) | Backend name | Provider / device |
+| ------------------------------------ | -------------- | ----------------------------------------------------------------------------------------------------------- |
+| `onnx` | `cpu` | `["CPUExecutionProvider"]` |
+| `onnx` | `cuda` | `["CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `onnx` | `tensorrt` | `["TensorrtExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `onnx` | `tensorrt_rtx` | `["NvTensorRTRTXExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` |
+| `openvino` | `cpu` | `device="CPU"` |
+| `openvino` | `gpu` | `device="GPU"` |
+| `openvino` | `npu` | `device="NPU"` |
+| `pt` | `cpu` | `torch.device("cpu")` |
+| `pt` | `cuda` | `torch.device("cuda")` |
- ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
- ```
+### Runtime registry (automatic backend selection)
-更多進階配置請參考 [**NVIDIA Container Toolkit**](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 及 [**docker**](https://docs.docker.com/) 官方文件。
+```python
+from capybara.runtime import Runtime
+
+print(Runtime.onnx.auto_backend_name())  # priority: cuda -> tensorrt_rtx -> tensorrt -> cpu
+print(Runtime.openvino.auto_backend_name())  # priority: gpu -> npu -> cpu
+print(Runtime.pt.auto_backend_name())  # priority: cuda -> cpu
+```
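+
+A minimal sketch that feeds the registry's choice into an engine (assumes a local `model.onnx`; as the next section shows, `ONNXEngine` accepts the backend name as a string):
+
+```python
+from capybara.onnxengine import ONNXEngine
+from capybara.runtime import Runtime
+
+backend = Runtime.onnx.auto_backend_name()  # e.g. "cuda" on a GPU machine, otherwise "cpu"
+engine = ONNXEngine("model.onnx", backend=backend)
+```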
+
+### ONNX Runtime (`capybara.onnxengine`)
+
+```python
+import numpy as np
+from capybara.onnxengine import EngineConfig, ONNXEngine
+
+engine = ONNXEngine(
+ "model.onnx",
+ backend="cpu",
+ config=EngineConfig(enable_io_binding=False),
+)
+outputs = engine.run({"input": np.ones((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+print(engine.summary())
+```
+
+### OpenVINO (`capybara.openvinoengine`)
+
+```python
+import numpy as np
+from capybara.openvinoengine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine
+
+engine = OpenVINOEngine(
+ "model.xml",
+ device=OpenVINODevice.cpu,
+ config=OpenVINOConfig(num_requests=2),
+)
+outputs = engine.run({"input": np.ones((1, 3), dtype=np.float32)})
+print(outputs.keys())
+```
+
+### TorchScript (`capybara.torchengine`)
+
+```python
+import numpy as np
+from capybara.torchengine import TorchEngine
+
+engine = TorchEngine("model.pt", device="cpu")
+outputs = engine.run({"image": np.zeros((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+```
-## 測試
+### Benchmark (results vary by hardware)
-本專案使用 `pytest` 進行單元測試,用戶可自行運行測試以驗證功能的正確性。
-安裝並執行測試的方式如下:
+All engines provide `benchmark(...)` for quick throughput/latency measurements.
+
+```python
+import numpy as np
+from capybara.onnxengine import ONNXEngine
+
+engine = ONNXEngine("model.onnx", backend="cpu")
+dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)
+print(engine.benchmark({"input": dummy}, repeat=50, warmup=5))
+```
+
+### Advanced: Custom Parameters (optional)
+
+`EngineConfig` / `OpenVINOConfig` / `TorchEngineConfig` are passed through unchanged to the underlying runtime.
+
+```python
+from capybara.onnxengine import EngineConfig, ONNXEngine
+
+engine = ONNXEngine(
+ "model.onnx",
+ backend="cuda",
+ config=EngineConfig(
+ provider_options={
+ "CUDAExecutionProvider": {
+ "enable_cuda_graph": True,
+ },
+ },
+ ),
+)
+```
+
+## Quality Gates (for developers)
+
+The following checks are enforced before merging:
+
+```bash
+ruff check .
+ruff format --check .
+pyright
+python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=term
+```
+
+Notes:
+
+- The coverage gate requires **90% coverage** (the rules are defined in `.coveragerc`).
+- Heavy or environment-dependent modules are excluded from the default coverage gate to keep CI reproducible and maintainable (see the sketch below).
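+
+A minimal local sketch, assuming `pytest-cov` is installed (`--cov-fail-under` is pytest-cov's built-in threshold flag; CI may enforce the gate through its own tooling):
+
+```bash
+python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=term --cov-fail-under=90
+```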
+
+## Docker (optional)
+
+```bash
+git clone https://github.com/DocsaidLab/Capybara.git
+cd Capybara
+bash docker/build.bash
+```
+
+Run:
+
+```bash
+docker run --rm -it capybara_docsaid bash
+```
+
+If you need GPU access inside the container, use the NVIDIA container runtime (e.g. `--gpus all`); an illustrative invocation follows.
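+
+The image name below matches the build step above; the `/code` mount is an assumption about your local layout:
+
+```bash
+docker run --rm --gpus all -v ${PWD}:/code -it capybara_docsaid bash
+```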
+
+## Testing (local)
```bash
-pip install pytest
-python -m pytest -vv tests
+python -m pytest -vv
```
-完成後即可確認各模組運作是否正常。若遇到功能異常,請先檢查環境設定與套件版本。
+## License
-若仍無法解決,可至 Issue 區回報。
+Apache-2.0; see `LICENSE`.
+
+## Citation
+
+```bibtex
+@misc{lin2025capybara,
+ author = {Kun-Hsiang Lin*, Ze Yuan*},
+ title = {Capybara: An Integrated Python Package for Image Processing and Deep Learning.},
+ year = {2025},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/DocsaidLab/Capybara}},
+ note = {* equal contribution}
+}
+```
diff --git a/capybara/__init__.py b/capybara/__init__.py
index a304b45..f28e399 100644
--- a/capybara/__init__.py
+++ b/capybara/__init__.py
@@ -1,8 +1,100 @@
-from .enums import *
-from .mixins import *
-from .onnxengine import *
-from .structures import *
-from .utils import *
-from .vision import *
+from __future__ import annotations
+
+from .enums import BORDER, COLORSTR, FORMATSTR, IMGTYP, INTER, MORPH, ROTATE
+from .mixins import (
+ DataclassCopyMixin,
+ DataclassToJsonMixin,
+ EnumCheckMixin,
+ dict_to_jsonable,
+)
+from .structures.boxes import Box, Boxes, BoxMode
+from .structures.functionals import (
+ jaccard_index,
+ pairwise_ioa,
+ pairwise_iou,
+ polygon_iou,
+)
+from .structures.keypoints import Keypoints, KeypointsList
+from .structures.polygons import (
+ JOIN_STYLE,
+ Polygon,
+ Polygons,
+ order_points_clockwise,
+)
+from .utils.custom_path import get_curdir
+from .utils.powerdict import PowerDict
+from .utils.utils import colorstr, make_batch
+from .vision.functionals import (
+ gaussianblur,
+ imbinarize,
+ imcropbox,
+ imcropboxes,
+ imcvtcolor,
+ imresize_and_pad_if_need,
+ meanblur,
+ medianblur,
+ pad,
+)
+from .vision.geometric import (
+ imresize,
+ imrotate,
+ imrotate90,
+ imwarp_quadrangle,
+ imwarp_quadrangles,
+)
+from .vision.improc import img_to_b64str, imread, imwrite, npy_to_b64str
+from .vision.videotools.video2frames import video2frames
+from .vision.videotools.video2frames_v2 import video2frames_v2
+
+__all__ = [
+ "BORDER",
+ "COLORSTR",
+ "FORMATSTR",
+ "IMGTYP",
+ "INTER",
+ "JOIN_STYLE",
+ "MORPH",
+ "ROTATE",
+ "Box",
+ "BoxMode",
+ "Boxes",
+ "DataclassCopyMixin",
+ "DataclassToJsonMixin",
+ "EnumCheckMixin",
+ "Keypoints",
+ "KeypointsList",
+ "Polygon",
+ "Polygons",
+ "PowerDict",
+ "colorstr",
+ "dict_to_jsonable",
+ "gaussianblur",
+ "get_curdir",
+ "imbinarize",
+ "imcropbox",
+ "imcropboxes",
+ "imcvtcolor",
+ "img_to_b64str",
+ "imread",
+ "imresize",
+ "imresize_and_pad_if_need",
+ "imrotate",
+ "imrotate90",
+ "imwarp_quadrangle",
+ "imwarp_quadrangles",
+ "imwrite",
+ "jaccard_index",
+ "make_batch",
+ "meanblur",
+ "medianblur",
+ "npy_to_b64str",
+ "order_points_clockwise",
+ "pad",
+ "pairwise_ioa",
+ "pairwise_iou",
+ "polygon_iou",
+ "video2frames",
+ "video2frames_v2",
+]
__version__ = "0.12.0"
diff --git a/capybara/cpuinfo.py b/capybara/cpuinfo.py
index 3c08fad..4d56b8f 100644
--- a/capybara/cpuinfo.py
+++ b/capybara/cpuinfo.py
@@ -21,7 +21,7 @@
Usage:
>>> from cpuinfo import cpuinfo
- >>> info = cpuinfo() # len(info) equals to num of cpus.
+ >>> info = cpuinfo() # len(info) equals to num of cpus.
>>> print(list(info[0].keys()))
>>> {
'processor',
@@ -54,7 +54,12 @@
}
"""
-__all__ = ['cpuinfo']
+# ruff: noqa: N802, N815
+#
+# This module is vendored from numexpr (see link above). It keeps upstream
+# helper names that intentionally use non-PEP8 casing.
+
+__all__ = ["cpuinfo"]
import inspect
import os
@@ -64,7 +69,9 @@
import sys
import warnings
-is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
+is_cpu_amd_intel = (
+ False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE
+)
def getoutput(cmd, successful_status=(0,), stacklevel=1):
@@ -72,9 +79,9 @@ def getoutput(cmd, successful_status=(0,), stacklevel=1):
p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
output, _ = p.communicate()
status = p.returncode
- except EnvironmentError as e:
+ except OSError as e:
warnings.warn(str(e), UserWarning, stacklevel=stacklevel)
- return False, ''
+ return False, ""
if os.WIFEXITED(status) and os.WEXITSTATUS(status) in successful_status:
return True, output
return False, output
@@ -83,38 +90,42 @@ def getoutput(cmd, successful_status=(0,), stacklevel=1):
def command_info(successful_status=(0,), stacklevel=1, **kw):
info = {}
for key in kw:
- ok, output = getoutput(kw[key], successful_status=successful_status,
- stacklevel=stacklevel + 1)
+ ok, output = getoutput(
+ kw[key],
+ successful_status=successful_status,
+ stacklevel=stacklevel + 1,
+ )
if ok:
info[key] = output.strip()
return info
def command_by_line(cmd, successful_status=(0,), stacklevel=1):
- ok, output = getoutput(cmd, successful_status=successful_status,
- stacklevel=stacklevel + 1)
+ ok, output = getoutput(
+ cmd, successful_status=successful_status, stacklevel=stacklevel + 1
+ )
if not ok:
return
# XXX: check
- output = output.decode('ascii')
+ output = output.decode("ascii")
for line in output.splitlines():
yield line.strip()
-def key_value_from_command(cmd, sep, successful_status=(0,),
- stacklevel=1):
+def key_value_from_command(cmd, sep, successful_status=(0,), stacklevel=1):
d = {}
- for line in command_by_line(cmd, successful_status=successful_status,
- stacklevel=stacklevel + 1):
- l = [s.strip() for s in line.split(sep, 1)]
- if len(l) == 2:
- d[l[0]] = l[1]
+ for line in command_by_line(
+ cmd, successful_status=successful_status, stacklevel=stacklevel + 1
+ ):
+ parts = [s.strip() for s in line.split(sep, 1)]
+ if len(parts) == 2:
+ d[parts[0]] = parts[1]
return d
-class CPUInfoBase(object):
+class CPUInfoBase:
"""Holds CPU information and provides methods for requiring
the availability of various CPU features.
"""
@@ -122,13 +133,13 @@ class CPUInfoBase(object):
def _try_call(self, func):
try:
return func()
- except:
- pass
+ except Exception:
+ return None
def __getattr__(self, name):
- if not name.startswith('_'):
- if hasattr(self, '_' + name):
- attr = getattr(self, '_' + name)
+ if not name.startswith("_"):
+ if hasattr(self, "_" + name):
+ attr = getattr(self, "_" + name)
if inspect.ismethod(attr):
return lambda func=self._try_call, attr=attr: func(attr)
else:
@@ -140,14 +151,14 @@ def _getNCPUs(self):
def __get_nbits(self):
abits = platform.architecture()[0]
- nbits = re.compile(r'(\d+)bit').search(abits).group(1)
+ nbits = re.compile(r"(\d+)bit").search(abits).group(1)
return nbits
def _is_32bit(self):
- return self.__get_nbits() == '32'
+ return self.__get_nbits() == "32"
def _is_64bit(self):
- return self.__get_nbits() == '64'
+ return self.__get_nbits() == "64"
class LinuxCPUInfo(CPUInfoBase):
@@ -157,23 +168,21 @@ def __init__(self):
if self.info is not None:
return
info = [{}]
- ok, output = getoutput(['uname', '-m'])
+ ok, output = getoutput(["uname", "-m"])
if ok:
- info[0]['uname_m'] = output.strip()
+ info[0]["uname_m"] = output.strip()
try:
- fo = open('/proc/cpuinfo')
- except EnvironmentError as e:
- warnings.warn(str(e), UserWarning)
- else:
- for line in fo:
- name_value = [s.strip() for s in line.split(':', 1)]
- if len(name_value) != 2:
- continue
- name, value = name_value
- if not info or name in info[-1]: # next processor
- info.append({})
- info[-1][name] = value
- fo.close()
+ with open("/proc/cpuinfo") as fo:
+ for line in fo:
+ name_value = [s.strip() for s in line.split(":", 1)]
+ if len(name_value) != 2:
+ continue
+ name, value = name_value
+ if not info or name in info[-1]: # next processor
+ info.append({})
+ info[-1][name] = value
+ except OSError as e:
+ warnings.warn(str(e), UserWarning, stacklevel=2)
self.__class__.info = info
def _not_impl(self):
@@ -182,59 +191,62 @@ def _not_impl(self):
# Athlon
def _is_AMD(self):
- return self.info[0]['vendor_id'] == 'AuthenticAMD'
+ return self.info[0]["vendor_id"] == "AuthenticAMD"
def _is_AthlonK6_2(self):
- return self._is_AMD() and self.info[0]['model'] == '2'
+ return self._is_AMD() and self.info[0]["model"] == "2"
def _is_AthlonK6_3(self):
- return self._is_AMD() and self.info[0]['model'] == '3'
+ return self._is_AMD() and self.info[0]["model"] == "3"
def _is_AthlonK6(self):
- return re.match(r'.*?AMD-K6', self.info[0]['model name']) is not None
+ return re.match(r".*?AMD-K6", self.info[0]["model name"]) is not None
def _is_AthlonK7(self):
- return re.match(r'.*?AMD-K7', self.info[0]['model name']) is not None
+ return re.match(r".*?AMD-K7", self.info[0]["model name"]) is not None
def _is_AthlonMP(self):
- return re.match(r'.*?Athlon\(tm\) MP\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Athlon\(tm\) MP\b", self.info[0]["model name"])
+ is not None
+ )
def _is_AMD64(self):
- return self.is_AMD() and self.info[0]['family'] == '15'
+ return self.is_AMD() and self.info[0]["family"] == "15"
def _is_Athlon64(self):
- return re.match(r'.*?Athlon\(tm\) 64\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Athlon\(tm\) 64\b", self.info[0]["model name"])
+ is not None
+ )
def _is_AthlonHX(self):
- return re.match(r'.*?Athlon HX\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Athlon HX\b", self.info[0]["model name"]) is not None
+ )
def _is_Opteron(self):
- return re.match(r'.*?Opteron\b',
- self.info[0]['model name']) is not None
+ return re.match(r".*?Opteron\b", self.info[0]["model name"]) is not None
def _is_Hammer(self):
- return re.match(r'.*?Hammer\b',
- self.info[0]['model name']) is not None
+ return re.match(r".*?Hammer\b", self.info[0]["model name"]) is not None
# Alpha
def _is_Alpha(self):
- return self.info[0]['cpu'] == 'Alpha'
+ return self.info[0]["cpu"] == "Alpha"
def _is_EV4(self):
- return self.is_Alpha() and self.info[0]['cpu model'] == 'EV4'
+ return self.is_Alpha() and self.info[0]["cpu model"] == "EV4"
def _is_EV5(self):
- return self.is_Alpha() and self.info[0]['cpu model'] == 'EV5'
+ return self.is_Alpha() and self.info[0]["cpu model"] == "EV5"
def _is_EV56(self):
- return self.is_Alpha() and self.info[0]['cpu model'] == 'EV56'
+ return self.is_Alpha() and self.info[0]["cpu model"] == "EV56"
def _is_PCA56(self):
- return self.is_Alpha() and self.info[0]['cpu model'] == 'PCA56'
+ return self.is_Alpha() and self.info[0]["cpu model"] == "PCA56"
# Intel
@@ -242,94 +254,107 @@ def _is_PCA56(self):
_is_i386 = _not_impl
def _is_Intel(self):
- return self.info[0]['vendor_id'] == 'GenuineIntel'
+ return self.info[0]["vendor_id"] == "GenuineIntel"
def _is_i486(self):
- return self.info[0]['cpu'] == 'i486'
+ return self.info[0]["cpu"] == "i486"
def _is_i586(self):
- return self.is_Intel() and self.info[0]['cpu family'] == '5'
+ return self.is_Intel() and self.info[0]["cpu family"] == "5"
def _is_i686(self):
- return self.is_Intel() and self.info[0]['cpu family'] == '6'
+ return self.is_Intel() and self.info[0]["cpu family"] == "6"
def _is_Celeron(self):
- return re.match(r'.*?Celeron',
- self.info[0]['model name']) is not None
+ return re.match(r".*?Celeron", self.info[0]["model name"]) is not None
def _is_Pentium(self):
- return re.match(r'.*?Pentium',
- self.info[0]['model name']) is not None
+ return re.match(r".*?Pentium", self.info[0]["model name"]) is not None
def _is_PentiumII(self):
- return re.match(r'.*?Pentium.*?II\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Pentium.*?II\b", self.info[0]["model name"])
+ is not None
+ )
def _is_PentiumPro(self):
- return re.match(r'.*?PentiumPro\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?PentiumPro\b", self.info[0]["model name"]) is not None
+ )
def _is_PentiumMMX(self):
- return re.match(r'.*?Pentium.*?MMX\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Pentium.*?MMX\b", self.info[0]["model name"])
+ is not None
+ )
def _is_PentiumIII(self):
- return re.match(r'.*?Pentium.*?III\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Pentium.*?III\b", self.info[0]["model name"])
+ is not None
+ )
def _is_PentiumIV(self):
- return re.match(r'.*?Pentium.*?(IV|4)\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Pentium.*?(IV|4)\b", self.info[0]["model name"])
+ is not None
+ )
def _is_PentiumM(self):
- return re.match(r'.*?Pentium.*?M\b',
- self.info[0]['model name']) is not None
+ return (
+ re.match(r".*?Pentium.*?M\b", self.info[0]["model name"])
+ is not None
+ )
def _is_Prescott(self):
return self.is_PentiumIV() and self.has_sse3()
def _is_Nocona(self):
- return (self.is_Intel() and
- self.info[0]['cpu family'] in ('6', '15') and
- # two s sse3; three s ssse3 not the same thing, this is fine
- (self.has_sse3() and not self.has_ssse3()) and
- re.match(r'.*?\blm\b', self.info[0]['flags']) is not None)
+ return (
+ self.is_Intel()
+ and self.info[0]["cpu family"] in ("6", "15")
+ and
+ # two s sse3; three s ssse3 not the same thing, this is fine
+ (self.has_sse3() and not self.has_ssse3())
+ and re.match(r".*?\blm\b", self.info[0]["flags"]) is not None
+ )
def _is_Core2(self):
- return (self.is_64bit() and self.is_Intel() and
- re.match(r'.*?Core\(TM\)2\b',
- self.info[0]['model name']) is not None)
+ return (
+ self.is_64bit()
+ and self.is_Intel()
+ and re.match(r".*?Core\(TM\)2\b", self.info[0]["model name"])
+ is not None
+ )
def _is_Itanium(self):
- return re.match(r'.*?Itanium\b',
- self.info[0]['family']) is not None
+ return re.match(r".*?Itanium\b", self.info[0]["family"]) is not None
def _is_XEON(self):
- return re.match(r'.*?XEON\b',
- self.info[0]['model name'], re.IGNORECASE) is not None
+ return (
+ re.match(r".*?XEON\b", self.info[0]["model name"], re.IGNORECASE)
+ is not None
+ )
_is_Xeon = _is_XEON
# Power
def _is_Power(self):
- return re.match(r'.*POWER.*',
- self.info[0]['cpu']) is not None
+ return re.match(r".*POWER.*", self.info[0]["cpu"]) is not None
def _is_Power7(self):
- return re.match(r'.*POWER7.*',
- self.info[0]['cpu']) is not None
+ return re.match(r".*POWER7.*", self.info[0]["cpu"]) is not None
def _is_Power8(self):
- return re.match(r'.*POWER8.*',
- self.info[0]['cpu']) is not None
+ return re.match(r".*POWER8.*", self.info[0]["cpu"]) is not None
def _is_Power9(self):
- return re.match(r'.*POWER9.*',
- self.info[0]['cpu']) is not None
+ return re.match(r".*POWER9.*", self.info[0]["cpu"]) is not None
def _has_Altivec(self):
- return re.match(r'.*altivec\ supported.*',
- self.info[0]['cpu']) is not None
+ return (
+ re.match(r".*altivec\ supported.*", self.info[0]["cpu"]) is not None
+ )
# Varia
@@ -340,31 +365,31 @@ def _getNCPUs(self):
return len(self.info)
def _has_fdiv_bug(self):
- return self.info[0]['fdiv_bug'] == 'yes'
+ return self.info[0]["fdiv_bug"] == "yes"
def _has_f00f_bug(self):
- return self.info[0]['f00f_bug'] == 'yes'
+ return self.info[0]["f00f_bug"] == "yes"
def _has_mmx(self):
- return re.match(r'.*?\bmmx\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\bmmx\b", self.info[0]["flags"]) is not None
def _has_sse(self):
- return re.match(r'.*?\bsse\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\bsse\b", self.info[0]["flags"]) is not None
def _has_sse2(self):
- return re.match(r'.*?\bsse2\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\bsse2\b", self.info[0]["flags"]) is not None
def _has_sse3(self):
- return re.match(r'.*?\bpni\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\bpni\b", self.info[0]["flags"]) is not None
def _has_ssse3(self):
- return re.match(r'.*?\bssse3\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\bssse3\b", self.info[0]["flags"]) is not None
def _has_3dnow(self):
- return re.match(r'.*?\b3dnow\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\b3dnow\b", self.info[0]["flags"]) is not None
def _has_3dnowext(self):
- return re.match(r'.*?\b3dnowext\b', self.info[0]['flags']) is not None
+ return re.match(r".*?\b3dnowext\b", self.info[0]["flags"]) is not None
class IRIXCPUInfo(CPUInfoBase):
@@ -373,21 +398,22 @@ class IRIXCPUInfo(CPUInfoBase):
def __init__(self):
if self.info is not None:
return
- info = key_value_from_command('sysconf', sep=' ',
- successful_status=(0, 1))
+ info = key_value_from_command(
+ "sysconf", sep=" ", successful_status=(0, 1)
+ )
self.__class__.info = info
def _not_impl(self):
pass
def _is_singleCPU(self):
- return self.info.get('NUM_PROCESSORS') == '1'
+ return self.info.get("NUM_PROCESSORS") == "1"
def _getNCPUs(self):
- return int(self.info.get('NUM_PROCESSORS', 1))
+ return int(self.info.get("NUM_PROCESSORS", 1))
def __cputype(self, n):
- return self.info.get('PROCESSORS').split()[0].lower() == 'r%s' % (n)
+ return self.info.get("PROCESSORS").split()[0].lower() == f"r{n}"
def _is_r2000(self):
return self.__cputype(2000)
@@ -432,16 +458,16 @@ def _is_r12000(self):
return self.__cputype(12000)
def _is_rorion(self):
- return self.__cputype('orion')
+ return self.__cputype("orion")
def get_ip(self):
try:
- return self.info.get('MACHINE')
- except:
- pass
+ return self.info.get("MACHINE")
+ except Exception:
+ return None
def __machine(self, n):
- return self.info.get('MACHINE').lower() == 'ip%s' % (n)
+ return self.info.get("MACHINE").lower() == f"ip{n}"
def _is_IP19(self):
return self.__machine(19)
@@ -495,63 +521,81 @@ class DarwinCPUInfo(CPUInfoBase):
def __init__(self):
if self.info is not None:
return
- info = command_info(arch='arch',
- machine='machine')
- info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=')
+ info = command_info(arch="arch", machine="machine")
+ info["sysctl_hw"] = key_value_from_command(["sysctl", "hw"], sep="=")
self.__class__.info = info
- def _not_impl(self): pass
+ def _not_impl(self):
+ pass
def _getNCPUs(self):
- return int(self.info['sysctl_hw'].get('hw.ncpu', 1))
+ return int(self.info["sysctl_hw"].get("hw.ncpu", 1))
def _is_Power_Macintosh(self):
- return self.info['sysctl_hw']['hw.machine'] == 'Power Macintosh'
+ return self.info["sysctl_hw"]["hw.machine"] == "Power Macintosh"
def _is_i386(self):
- return self.info['arch'] == 'i386'
+ return self.info["arch"] == "i386"
def _is_ppc(self):
- return self.info['arch'] == 'ppc'
+ return self.info["arch"] == "ppc"
def __machine(self, n):
- return self.info['machine'] == 'ppc%s' % n
+ return self.info["machine"] == f"ppc{n}"
- def _is_ppc601(self): return self.__machine(601)
+ def _is_ppc601(self):
+ return self.__machine(601)
- def _is_ppc602(self): return self.__machine(602)
+ def _is_ppc602(self):
+ return self.__machine(602)
- def _is_ppc603(self): return self.__machine(603)
+ def _is_ppc603(self):
+ return self.__machine(603)
- def _is_ppc603e(self): return self.__machine('603e')
+ def _is_ppc603e(self):
+ return self.__machine("603e")
- def _is_ppc604(self): return self.__machine(604)
+ def _is_ppc604(self):
+ return self.__machine(604)
- def _is_ppc604e(self): return self.__machine('604e')
+ def _is_ppc604e(self):
+ return self.__machine("604e")
- def _is_ppc620(self): return self.__machine(620)
+ def _is_ppc620(self):
+ return self.__machine(620)
- def _is_ppc630(self): return self.__machine(630)
+ def _is_ppc630(self):
+ return self.__machine(630)
- def _is_ppc740(self): return self.__machine(740)
+ def _is_ppc740(self):
+ return self.__machine(740)
- def _is_ppc7400(self): return self.__machine(7400)
+ def _is_ppc7400(self):
+ return self.__machine(7400)
- def _is_ppc7450(self): return self.__machine(7450)
+ def _is_ppc7450(self):
+ return self.__machine(7450)
- def _is_ppc750(self): return self.__machine(750)
+ def _is_ppc750(self):
+ return self.__machine(750)
- def _is_ppc403(self): return self.__machine(403)
+ def _is_ppc403(self):
+ return self.__machine(403)
- def _is_ppc505(self): return self.__machine(505)
+ def _is_ppc505(self):
+ return self.__machine(505)
- def _is_ppc801(self): return self.__machine(801)
+ def _is_ppc801(self):
+ return self.__machine(801)
- def _is_ppc821(self): return self.__machine(821)
+ def _is_ppc821(self):
+ return self.__machine(821)
- def _is_ppc823(self): return self.__machine(823)
+ def _is_ppc823(self):
+ return self.__machine(823)
- def _is_ppc860(self): return self.__machine(860)
+ def _is_ppc860(self):
+ return self.__machine(860)
class NetBSDCPUInfo(CPUInfoBase):
@@ -561,25 +605,22 @@ def __init__(self):
if self.info is not None:
return
info = {}
- info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=')
- info['arch'] = info['sysctl_hw'].get('hw.machine_arch', 1)
- info['machine'] = info['sysctl_hw'].get('hw.machine', 1)
+ info["sysctl_hw"] = key_value_from_command(["sysctl", "hw"], sep="=")
+ info["arch"] = info["sysctl_hw"].get("hw.machine_arch", 1)
+ info["machine"] = info["sysctl_hw"].get("hw.machine", 1)
self.__class__.info = info
- def _not_impl(self): pass
+ def _not_impl(self):
+ pass
def _getNCPUs(self):
- return int(self.info['sysctl_hw'].get('hw.ncpu', 1))
+ return int(self.info["sysctl_hw"].get("hw.ncpu", 1))
def _is_Intel(self):
- if self.info['sysctl_hw'].get('hw.model', "")[0:5] == 'Intel':
- return True
- return False
+ return self.info["sysctl_hw"].get("hw.model", "")[0:5] == "Intel"
def _is_AMD(self):
- if self.info['sysctl_hw'].get('hw.model', "")[0:3] == 'AMD':
- return True
- return False
+ return self.info["sysctl_hw"].get("hw.model", "")[0:3] == "AMD"
class SunOSCPUInfo(CPUInfoBase):
@@ -588,17 +629,18 @@ class SunOSCPUInfo(CPUInfoBase):
def __init__(self):
if self.info is not None:
return
- info = command_info(arch='arch',
- mach='mach',
- uname_i=['uname', '-i'],
- isainfo_b=['isainfo', '-b'],
- isainfo_n=['isainfo', '-n'],
- )
- info['uname_X'] = key_value_from_command(['uname', '-X'], sep='=')
- for line in command_by_line(['psrinfo', '-v', '0']):
- m = re.match(r'\s*The (?P<p>[\w\d]+) processor operates at', line)
+ info = command_info(
+ arch="arch",
+ mach="mach",
+ uname_i=["uname", "-i"],
+ isainfo_b=["isainfo", "-b"],
+ isainfo_n=["isainfo", "-n"],
+ )
+ info["uname_X"] = key_value_from_command(["uname", "-X"], sep="=")
+ for line in command_by_line(["psrinfo", "-v", "0"]):
+ m = re.match(r"\s*The (?P<p>[\w\d]+) processor operates at", line)
if m:
- info['processor'] = m.group('p')
+ info["processor"] = m.group("p")
break
self.__class__.info = info
@@ -606,73 +648,76 @@ def _not_impl(self):
pass
def _is_i386(self):
- return self.info['isainfo_n'] == 'i386'
+ return self.info["isainfo_n"] == "i386"
def _is_sparc(self):
- return self.info['isainfo_n'] == 'sparc'
+ return self.info["isainfo_n"] == "sparc"
def _is_sparcv9(self):
- return self.info['isainfo_n'] == 'sparcv9'
+ return self.info["isainfo_n"] == "sparcv9"
def _getNCPUs(self):
- return int(self.info['uname_X'].get('NumCPU', 1))
+ return int(self.info["uname_X"].get("NumCPU", 1))
def _is_sun4(self):
- return self.info['arch'] == 'sun4'
+ return self.info["arch"] == "sun4"
def _is_SUNW(self):
- return re.match(r'SUNW', self.info['uname_i']) is not None
+ return re.match(r"SUNW", self.info["uname_i"]) is not None
def _is_sparcstation5(self):
- return re.match(r'.*SPARCstation-5', self.info['uname_i']) is not None
+ return re.match(r".*SPARCstation-5", self.info["uname_i"]) is not None
def _is_ultra1(self):
- return re.match(r'.*Ultra-1', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-1", self.info["uname_i"]) is not None
def _is_ultra250(self):
- return re.match(r'.*Ultra-250', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-250", self.info["uname_i"]) is not None
def _is_ultra2(self):
- return re.match(r'.*Ultra-2', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-2", self.info["uname_i"]) is not None
def _is_ultra30(self):
- return re.match(r'.*Ultra-30', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-30", self.info["uname_i"]) is not None
def _is_ultra4(self):
- return re.match(r'.*Ultra-4', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-4", self.info["uname_i"]) is not None
def _is_ultra5_10(self):
- return re.match(r'.*Ultra-5_10', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-5_10", self.info["uname_i"]) is not None
def _is_ultra5(self):
- return re.match(r'.*Ultra-5', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-5", self.info["uname_i"]) is not None
def _is_ultra60(self):
- return re.match(r'.*Ultra-60', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-60", self.info["uname_i"]) is not None
def _is_ultra80(self):
- return re.match(r'.*Ultra-80', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-80", self.info["uname_i"]) is not None
def _is_ultraenterprice(self):
- return re.match(r'.*Ultra-Enterprise', self.info['uname_i']) is not None
+ return re.match(r".*Ultra-Enterprise", self.info["uname_i"]) is not None
def _is_ultraenterprice10k(self):
- return re.match(r'.*Ultra-Enterprise-10000', self.info['uname_i']) is not None
+ return (
+ re.match(r".*Ultra-Enterprise-10000", self.info["uname_i"])
+ is not None
+ )
def _is_sunfire(self):
- return re.match(r'.*Sun-Fire', self.info['uname_i']) is not None
+ return re.match(r".*Sun-Fire", self.info["uname_i"]) is not None
def _is_ultra(self):
- return re.match(r'.*Ultra', self.info['uname_i']) is not None
+ return re.match(r".*Ultra", self.info["uname_i"]) is not None
def _is_cpusparcv7(self):
- return self.info['processor'] == 'sparcv7'
+ return self.info["processor"] == "sparcv7"
def _is_cpusparcv8(self):
- return self.info['processor'] == 'sparcv8'
+ return self.info["processor"] == "sparcv8"
def _is_cpusparcv9(self):
- return self.info['processor'] == 'sparcv9'
+ return self.info["processor"] == "sparcv9"
class Win32CPUInfo(CPUInfoBase):
@@ -694,8 +739,11 @@ def __init__(self):
try:
# XXX: Bad style to use so long `try:...except:...`. Fix it!
- prgx = re.compile(r"family\s+(?P<FML>\d+)\s+model\s+(?P<MDL>\d+)"
- r"\s+stepping\s+(?P<STP>\d+)", re.IGNORECASE)
+ prgx = re.compile(
+ r"family\s+(?P\d+)\s+model\s+(?P\d+)"
+ r"\s+stepping\s+(?P\d+)",
+ re.IGNORECASE,
+ )
chnd = _winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE, self.pkey)
pnum = 0
while 1:
@@ -721,9 +769,11 @@ def __init__(self):
if srch:
info[-1]["Family"] = int(srch.group("FML"))
info[-1]["Model"] = int(srch.group("MDL"))
- info[-1]["Stepping"] = int(srch.group("STP"))
- except:
- print(sys.exc_value, '(ignoring)')
+ info[-1]["Stepping"] = int(
+ srch.group("STP")
+ )
+ except Exception as exc:
+ warnings.warn(f"{exc} (ignoring)", RuntimeWarning, stacklevel=2)
self.__class__.info = info
def _not_impl(self):
@@ -732,86 +782,116 @@ def _not_impl(self):
# Athlon
def _is_AMD(self):
- return self.info[0]['VendorIdentifier'] == 'AuthenticAMD'
+ return self.info[0]["VendorIdentifier"] == "AuthenticAMD"
def _is_Am486(self):
- return self.is_AMD() and self.info[0]['Family'] == 4
+ return self.is_AMD() and self.info[0]["Family"] == 4
def _is_Am5x86(self):
- return self.is_AMD() and self.info[0]['Family'] == 4
+ return self.is_AMD() and self.info[0]["Family"] == 4
def _is_AMDK5(self):
- return (self.is_AMD() and self.info[0]['Family'] == 5 and
- self.info[0]['Model'] in [0, 1, 2, 3])
+ return (
+ self.is_AMD()
+ and self.info[0]["Family"] == 5
+ and self.info[0]["Model"] in [0, 1, 2, 3]
+ )
def _is_AMDK6(self):
- return (self.is_AMD() and self.info[0]['Family'] == 5 and
- self.info[0]['Model'] in [6, 7])
+ return (
+ self.is_AMD()
+ and self.info[0]["Family"] == 5
+ and self.info[0]["Model"] in [6, 7]
+ )
def _is_AMDK6_2(self):
- return (self.is_AMD() and self.info[0]['Family'] == 5 and
- self.info[0]['Model'] == 8)
+ return (
+ self.is_AMD()
+ and self.info[0]["Family"] == 5
+ and self.info[0]["Model"] == 8
+ )
def _is_AMDK6_3(self):
- return (self.is_AMD() and self.info[0]['Family'] == 5 and
- self.info[0]['Model'] == 9)
+ return (
+ self.is_AMD()
+ and self.info[0]["Family"] == 5
+ and self.info[0]["Model"] == 9
+ )
def _is_AMDK7(self):
- return self.is_AMD() and self.info[0]['Family'] == 6
+ return self.is_AMD() and self.info[0]["Family"] == 6
# To reliably distinguish between the different types of AMD64 chips
# (Athlon64, Operton, Athlon64 X2, Semperon, Turion 64, etc.) would
# require looking at the 'brand' from cpuid
def _is_AMD64(self):
- return self.is_AMD() and self.info[0]['Family'] == 15
+ return self.is_AMD() and self.info[0]["Family"] == 15
# Intel
def _is_Intel(self):
- return self.info[0]['VendorIdentifier'] == 'GenuineIntel'
+ return self.info[0]["VendorIdentifier"] == "GenuineIntel"
def _is_i386(self):
- return self.info[0]['Family'] == 3
+ return self.info[0]["Family"] == 3
def _is_i486(self):
- return self.info[0]['Family'] == 4
+ return self.info[0]["Family"] == 4
def _is_i586(self):
- return self.is_Intel() and self.info[0]['Family'] == 5
+ return self.is_Intel() and self.info[0]["Family"] == 5
def _is_i686(self):
- return self.is_Intel() and self.info[0]['Family'] == 6
+ return self.is_Intel() and self.info[0]["Family"] == 6
def _is_Pentium(self):
- return self.is_Intel() and self.info[0]['Family'] == 5
+ return self.is_Intel() and self.info[0]["Family"] == 5
def _is_PentiumMMX(self):
- return (self.is_Intel() and self.info[0]['Family'] == 5 and
- self.info[0]['Model'] == 4)
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 5
+ and self.info[0]["Model"] == 4
+ )
def _is_PentiumPro(self):
- return (self.is_Intel() and self.info[0]['Family'] == 6 and
- self.info[0]['Model'] == 1)
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 6
+ and self.info[0]["Model"] == 1
+ )
def _is_PentiumII(self):
- return (self.is_Intel() and self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [3, 5, 6])
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [3, 5, 6]
+ )
def _is_PentiumIII(self):
- return (self.is_Intel() and self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [7, 8, 9, 10, 11])
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [7, 8, 9, 10, 11]
+ )
def _is_PentiumIV(self):
- return self.is_Intel() and self.info[0]['Family'] == 15
+ return self.is_Intel() and self.info[0]["Family"] == 15
def _is_PentiumM(self):
- return (self.is_Intel() and self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [9, 13, 14])
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [9, 13, 14]
+ )
def _is_Core2(self):
- return (self.is_Intel() and self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [15, 16, 17])
+ return (
+ self.is_Intel()
+ and self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [15, 16, 17]
+ )
# Varia
@@ -823,23 +903,25 @@ def _getNCPUs(self):
def _has_mmx(self):
if self.is_Intel():
- return ((self.info[0]['Family'] == 5 and
- self.info[0]['Model'] == 4) or
- (self.info[0]['Family'] in [6, 15]))
+ return (
+ self.info[0]["Family"] == 5 and self.info[0]["Model"] == 4
+ ) or (self.info[0]["Family"] in [6, 15])
elif self.is_AMD():
- return self.info[0]['Family'] in [5, 6, 15]
+ return self.info[0]["Family"] in [5, 6, 15]
else:
return False
def _has_sse(self):
if self.is_Intel():
- return ((self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [7, 8, 9, 10, 11]) or
- self.info[0]['Family'] == 15)
+ return (
+ self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [7, 8, 9, 10, 11]
+ ) or self.info[0]["Family"] == 15
elif self.is_AMD():
- return ((self.info[0]['Family'] == 6 and
- self.info[0]['Model'] in [6, 7, 8, 10]) or
- self.info[0]['Family'] == 15)
+ return (
+ self.info[0]["Family"] == 6
+ and self.info[0]["Model"] in [6, 7, 8, 10]
+ ) or self.info[0]["Family"] == 15
else:
return False
@@ -852,25 +934,27 @@ def _has_sse2(self):
return False
def _has_3dnow(self):
- return self.is_AMD() and self.info[0]['Family'] in [5, 6, 15]
+ return self.is_AMD() and self.info[0]["Family"] in [5, 6, 15]
def _has_3dnowext(self):
- return self.is_AMD() and self.info[0]['Family'] in [6, 15]
+ return self.is_AMD() and self.info[0]["Family"] in [6, 15]
-if sys.platform.startswith('linux'): # variations: linux2,linux-i386 (any others?)
+if sys.platform.startswith(
+ "linux"
+): # variations: linux2,linux-i386 (any others?)
cpuinfo = LinuxCPUInfo
-elif sys.platform.startswith('irix'):
+elif sys.platform.startswith("irix"):
cpuinfo = IRIXCPUInfo
-elif sys.platform == 'darwin':
+elif sys.platform == "darwin":
cpuinfo = DarwinCPUInfo
-elif sys.platform[0:6] == 'netbsd':
+elif sys.platform[0:6] == "netbsd":
cpuinfo = NetBSDCPUInfo
-elif sys.platform.startswith('sunos'):
+elif sys.platform.startswith("sunos"):
cpuinfo = SunOSCPUInfo
-elif sys.platform.startswith('win32'):
+elif sys.platform.startswith("win32"):
cpuinfo = Win32CPUInfo
-elif sys.platform.startswith('cygwin'):
+elif sys.platform.startswith("cygwin"):
cpuinfo = LinuxCPUInfo
# XXX: other OS's. Eg. use _winreg on Win32. Or os.uname on unices.
else:
diff --git a/capybara/enums.py b/capybara/enums.py
index d958ffa..15bc85e 100644
--- a/capybara/enums.py
+++ b/capybara/enums.py
@@ -5,7 +5,13 @@
from .mixins import EnumCheckMixin
__all__ = [
- 'INTER', 'ROTATE', 'BORDER', 'MORPH', 'COLORSTR', 'FORMATSTR', 'IMGTYP'
+ "BORDER",
+ "COLORSTR",
+ "FORMATSTR",
+ "IMGTYP",
+ "INTER",
+ "MORPH",
+ "ROTATE",
]
diff --git a/capybara/mixins.py b/capybara/mixins.py
index 6116d61..a8a99d1 100644
--- a/capybara/mixins.py
+++ b/capybara/mixins.py
@@ -1,8 +1,9 @@
import json
from collections import OrderedDict
+from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import asdict
from enum import Enum
-from typing import Any, Callable, Dict, Mapping, Optional
+from typing import Any, TypeVar, cast
from warnings import warn
import numpy as np
@@ -11,32 +12,39 @@
from .structures import Box, Boxes, Keypoints, KeypointsList, Polygon, Polygons
__all__ = [
- "EnumCheckMixin",
"DataclassCopyMixin",
"DataclassToJsonMixin",
+ "EnumCheckMixin",
"dict_to_jsonable",
]
def dict_to_jsonable(
- d: Mapping,
- jsonable_func: Optional[Dict[str, Callable]] = None,
- dict_factory: Mapping = OrderedDict,
-) -> Any:
+ d: Mapping[str, Any],
+ jsonable_func: Mapping[str, Callable[[Any], Any]] | None = None,
+ dict_factory: Callable[[], MutableMapping[str, Any]] = OrderedDict,
+) -> MutableMapping[str, Any]:
out = dict_factory()
for k, v in d.items():
if jsonable_func is not None and k in jsonable_func:
out[k] = jsonable_func[k](v)
else:
if isinstance(v, (Box, Boxes)):
- out[k] = v.convert("XYXY").numpy().astype(float).round().tolist()
+ out[k] = (
+ v.convert("XYXY").numpy().astype(float).round().tolist()
+ )
elif isinstance(v, (Keypoints, KeypointsList, Polygon, Polygons)):
out[k] = v.numpy().astype(float).round().tolist()
elif isinstance(v, (np.ndarray, np.generic)):
# include array and scalar, if you want jsonable image please use jsonable_func
out[k] = v.tolist()
elif isinstance(v, (list, tuple)):
- out[k] = [dict_to_jsonable(x, jsonable_func) if isinstance(x, dict) else x for x in v]
+ out[k] = [
+ dict_to_jsonable(x, jsonable_func)
+ if isinstance(x, dict)
+ else x
+ for x in v
+ ]
elif isinstance(v, Enum):
out[k] = v.name
elif isinstance(v, Mapping):
@@ -47,24 +55,27 @@ def dict_to_jsonable(
try:
json.dumps(out)
except Exception as e:
- warn(e)
+ warn(str(e), stacklevel=2)
return out
+_EnumT = TypeVar("_EnumT", bound="EnumCheckMixin")
+
+
class EnumCheckMixin:
@classmethod
- def obj_to_enum(cls: Enum, obj: Any):
+ def obj_to_enum(cls: type[_EnumT], obj: Any) -> _EnumT:
if isinstance(obj, str):
try:
- return getattr(cls, obj)
+ return cast(_EnumT, getattr(cls, obj))
except AttributeError:
pass
elif isinstance(obj, cls):
return obj
elif isinstance(obj, int):
try:
- return cls(obj)
+ return cast(_EnumT, cast(Any, cls)(obj))
except ValueError:
pass
@@ -73,10 +84,18 @@ def obj_to_enum(cls: Enum, obj: Any):
class DataclassCopyMixin:
def __copy__(self):
- return self.__class__(**{field: getattr(self, field) for field in self.__dataclass_fields__})
+ dataclass_fields = getattr(self, "__dataclass_fields__", None)
+ if dataclass_fields is None:
+ raise TypeError(
+ f"{self.__class__.__name__} is not a dataclass instance."
+ )
+ field_names = cast(dict[str, Any], dataclass_fields).keys()
+ return self.__class__(
+ **{field: getattr(self, field) for field in field_names}
+ )
def __deepcopy__(self, memo):
- out = asdict(self, dict_factory=OrderedDict)
+ out = asdict(cast(Any, self), dict_factory=OrderedDict)
return from_dict(data_class=self.__class__, data=out)
@@ -84,5 +103,7 @@ class DataclassToJsonMixin:
jsonable_func = None
def be_jsonable(self, dict_factory=OrderedDict):
- d = asdict(self, dict_factory=dict_factory)
- return dict_to_jsonable(d, jsonable_func=self.jsonable_func, dict_factory=dict_factory)
+ d = asdict(cast(Any, self), dict_factory=dict_factory)
+ return dict_to_jsonable(
+ d, jsonable_func=self.jsonable_func, dict_factory=dict_factory
+ )
diff --git a/capybara/onnxengine/__init__.py b/capybara/onnxengine/__init__.py
index 18c1434..0b07016 100644
--- a/capybara/onnxengine/__init__.py
+++ b/capybara/onnxengine/__init__.py
@@ -1,8 +1,24 @@
-from .engine import ONNXEngine
-from .engine_io_binding import ONNXEngineIOBinding
-from .enum import Backend
-from .metadata import get_onnx_metadata, parse_metadata_from_onnx, write_metadata_into_onnx
-from .tools import get_onnx_input_infos, get_onnx_output_infos, get_recommended_backend, make_onnx_dynamic_axes
+from ..runtime import Backend
+from .engine import EngineConfig, ONNXEngine
+from .metadata import (
+ get_onnx_metadata,
+ parse_metadata_from_onnx,
+ write_metadata_into_onnx,
+)
+from .utils import (
+ get_onnx_input_infos,
+ get_onnx_output_infos,
+ make_onnx_dynamic_axes,
+)
-# 暫時無法使用
-# from .quantize import quantize, quantize_static
+__all__ = [
+ "Backend",
+ "EngineConfig",
+ "ONNXEngine",
+ "get_onnx_input_infos",
+ "get_onnx_metadata",
+ "get_onnx_output_infos",
+ "make_onnx_dynamic_axes",
+ "parse_metadata_from_onnx",
+ "write_metadata_into_onnx",
+]
diff --git a/capybara/onnxengine/engine.py b/capybara/onnxengine/engine.py
index bbed318..4be5202 100644
--- a/capybara/onnxengine/engine.py
+++ b/capybara/onnxengine/engine.py
@@ -1,178 +1,432 @@
+from __future__ import annotations
+
+import json
+import time
+from collections.abc import Mapping
+from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, Union
+from typing import Any
-import colored
import numpy as np
-import onnxruntime as ort
-from .enum import Backend
-from .metadata import parse_metadata_from_onnx
-from .tools import get_onnx_input_infos, get_onnx_output_infos
+try:
+ _BFLOAT16 = np.dtype("bfloat16")
+except TypeError: # pragma: no cover - older numpy
+ _BFLOAT16 = np.float16
+
+try: # pragma: no cover - optional dependency handled at runtime
+ import onnxruntime as ort # type: ignore
+except Exception as exc: # pragma: no cover
+ raise ImportError(
+ "onnxruntime is required. Install 'onnxruntime' or 'onnxruntime-gpu'."
+ ) from exc
+
+from ..runtime import Backend, ProviderSpec
+
+
+@dataclass(slots=True)
+class EngineConfig:
+ """High level configuration knobs for the ONNX engine."""
+
+ graph_optimization: str | ort.GraphOptimizationLevel = "all"
+ execution_mode: str | ort.ExecutionMode | None = None
+ intra_op_num_threads: int | None = None
+ inter_op_num_threads: int | None = None
+ log_severity_level: int | None = None
+ session_config_entries: dict[str, str] | None = None
+ provider_options: dict[str, dict[str, Any]] | None = None
+ fallback_to_cpu: bool = True
+ enable_io_binding: bool = False
+ run_config_entries: dict[str, str] | None = None
+ enable_profiling: bool = False
+
+
+@dataclass(slots=True)
+class _InputSpec:
+ name: str
+ dtype: np.dtype[Any] | None
+ dtype_str: str
+ shape: tuple[int | None, ...]
+
+
+_GRAPH_OPT_MAP = {
+ "disable": ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+ "basic": ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
+ "extended": ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
+ "all": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
+}
+
+_EXECUTION_MODE_MAP = {
+ "sequential": ort.ExecutionMode.ORT_SEQUENTIAL,
+ "parallel": ort.ExecutionMode.ORT_PARALLEL,
+}
+
+_TYPE_MAP = {
+ "tensor(float)": np.float32,
+ "tensor(float16)": np.float16,
+ "tensor(double)": np.float64,
+ "tensor(bfloat16)": _BFLOAT16,
+ "tensor(int64)": np.int64,
+ "tensor(int32)": np.int32,
+ "tensor(int16)": np.int16,
+ "tensor(int8)": np.int8,
+ "tensor(uint64)": np.uint64,
+ "tensor(uint32)": np.uint32,
+ "tensor(uint16)": np.uint16,
+ "tensor(uint8)": np.uint8,
+ "tensor(bool)": np.bool_,
+}
+
+
+# NOTE: Benchmarks are intentionally strict about negative/zero loops to avoid
+# reporting misleading throughput/latency statistics.
+def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]:
+ repeat_i = int(repeat)
+ warmup_i = int(warmup)
+ if repeat_i < 1:
+ raise ValueError("repeat must be >= 1.")
+ if warmup_i < 0:
+ raise ValueError("warmup must be >= 0.")
+ return repeat_i, warmup_i
+
+
+# Provider-specific defaults that enable CUDA graph and TensorRT caching tweaks.
+_DEFAULT_PROVIDER_OPTIONS: dict[str, dict[str, Any]] = {
+ "CUDAExecutionProvider": {
+ "cudnn_conv_algo_search": "HEURISTIC",
+ "cudnn_conv_use_max_workspace": "1",
+ "do_copy_in_default_stream": True,
+ "arena_extend_strategy": "kSameAsRequested",
+ "tunable_op_enable": False,
+ "enable_cuda_graph": False, # 試驗性功能, 推論效果不穩定
+ },
+ "TensorrtExecutionProvider": {
+ "trt_max_workspace_size": 4 * 1024 * 1024 * 1024,
+ "trt_int8_enable": False,
+ "trt_fp16_enable": False,
+ "trt_engine_cache_enable": True,
+ "trt_engine_cache_path": "trt_engine_cache",
+ "trt_timing_cache_enable": True,
+ },
+}
class ONNXEngine:
+ """A thin wrapper around onnxruntime.InferenceSession with ergonomic defaults."""
+
def __init__(
self,
- model_path: Union[str, Path],
+ model_path: str | Path,
gpu_id: int = 0,
- backend: Union[str, int, Backend] = Backend.cpu,
- session_option: Dict[str, Any] = {},
- provider_option: Dict[str, Any] = {},
- ):
- """
- Initialize an ONNX model inference engine.
-
- Args:
- model_path (Union[str, Path]):
- Filename or serialized ONNX or ORT format model in a byte string.
- gpu_id (int, optional):
- GPU ID. Defaults to 0.
- backend (Union[str, int, Backend], optional):
- Backend. Defaults to Backend.cuda.
- session_option (Dict[str, Any], optional):
- Session options. Defaults to {}.
- provider_option (Dict[str, Any], optional):
- Provider options. Defaults to {}.
- """
- # setting device info
- backend = Backend.obj_to_enum(backend)
- self.device_id = 0 if backend.name == "cpu" else gpu_id
-
- # setting provider options
- providers = self._get_providers(backend, provider_option)
-
- # setting session options
- sess_options = self._get_session_info(session_option)
-
- # setting onnxruntime session
- model_path = str(model_path) if isinstance(model_path, Path) else model_path
- self.sess = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
-
- # setting onnxruntime session info
- self.model_path = model_path
- self.metadata = parse_metadata_from_onnx(model_path)
- self.providers = self.sess.get_providers()
- self.provider_options = self.sess.get_provider_options()
-
- self.input_infos = get_onnx_input_infos(model_path)
- self.output_infos = get_onnx_output_infos(model_path)
-
- def __call__(self, **xs) -> Dict[str, np.ndarray]:
- output_names = list(self.output_infos.keys())
- outs = self.sess.run(output_names, {k: v for k, v in xs.items()})
- outs = {k: v for k, v in zip(output_names, outs)}
- return outs
-
- def _get_session_info(self, session_option: Dict[str, Any] = {}) -> ort.SessionOptions:
- """
- Ref: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
- """
- sess_opt = ort.SessionOptions()
- session_option_default = {
- "graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
- "log_severity_level": 2,
+ backend: str | Backend = Backend.cpu,
+ session_option: Mapping[str, Any] | None = None,
+ provider_option: Mapping[str, Any] | None = None,
+ config: EngineConfig | None = None,
+ ) -> None:
+ self.model_path = str(model_path)
+ self.backend = Backend.from_any(backend, runtime="onnx")
+ self.device_id = int(gpu_id)
+ self._session_overrides = dict(session_option or {})
+ self._provider_override = dict(provider_option or {})
+ self._cfg = config or EngineConfig()
+
+ self._session = self._create_session()
+ self.providers = self._session.get_providers()
+ self.provider_options = self._session.get_provider_options()
+ self.metadata = self._extract_metadata()
+
+ self._output_names = [node.name for node in self._session.get_outputs()]
+ self._input_specs = self._inspect_inputs()
+ self._binding = (
+ self._session.io_binding() if self._cfg.enable_io_binding else None
+ )
+ self._run_options = self._build_run_options()
+
+ def __call__(self, **inputs: np.ndarray) -> dict[str, np.ndarray]:
+ if len(inputs) == 1 and isinstance(
+ next(iter(inputs.values())), Mapping
+ ):
+ feed_dict = dict(next(iter(inputs.values())))
+ else:
+ feed_dict = inputs
+ feed = self._prepare_feed(feed_dict)
+ return self._run(feed)
+
+ def run(
+ self,
+ feed: Mapping[str, np.ndarray],
+ ) -> dict[str, np.ndarray]:
+ return self._run(self._prepare_feed(feed))
+
+ def summary(self) -> dict[str, Any]:
+ inputs = [
+ {
+ "name": spec.name,
+ "dtype": spec.dtype_str,
+ "shape": list(spec.shape),
+ }
+ for spec in self._input_specs
+ ]
+ outputs = [
+ {
+ "name": out.name,
+ "dtype": getattr(out, "type", ""),
+ "shape": list(getattr(out, "shape", [])),
+ }
+ for out in self._session.get_outputs()
+ ]
+ return {
+ "model": self.model_path,
+ "providers": self.providers,
+ "inputs": inputs,
+ "outputs": outputs,
}
- session_option_default.update(session_option)
- for k, v in session_option_default.items():
- setattr(sess_opt, k, v)
- return sess_opt
-
- def _get_providers(self, backend: Union[str, int, Backend], provider_option: Dict[str, Any] = {}) -> Backend:
- """
- Ref: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
- """
- if backend == Backend.cuda:
- providers = [
- (
- "CUDAExecutionProvider",
- {
- "device_id": self.device_id,
- "cudnn_conv_use_max_workspace": "1",
- **provider_option,
- },
- )
- ]
- elif backend == Backend.coreml:
- providers = [
- (
- "CoreMLExecutionProvider",
- {
- "ModelFormat": "MLProgram",
- "MLComputeUnits": "ALL",
- "RequireStaticInputShapes": "1",
- **provider_option,
- },
+
+ def benchmark(
+ self,
+ inputs: Mapping[str, np.ndarray],
+ *,
+ repeat: int = 100,
+ warmup: int = 10,
+ ) -> dict[str, Any]:
+ repeat, warmup = _validate_benchmark_params(
+ repeat=repeat, warmup=warmup
+ )
+ feed = self._prepare_feed(inputs)
+ for _ in range(warmup):
+ self._session.run(self._output_names, feed)
+
+ latencies: list[float] = []
+ t0 = time.perf_counter()
+ for _ in range(repeat):
+ start = time.perf_counter()
+ self._session.run(self._output_names, feed)
+ latencies.append((time.perf_counter() - start) * 1e3)
+ total = time.perf_counter() - t0
+ arr = np.asarray(latencies, dtype=np.float64)
+
+ return {
+ "repeat": repeat,
+ "warmup": warmup,
+ "throughput_fps": repeat / total if total else None,
+ "latency_ms": {
+ "mean": float(arr.mean()) if arr.size else None,
+ "median": float(np.median(arr)) if arr.size else None,
+ "p90": float(np.percentile(arr, 90)) if arr.size else None,
+ "p95": float(np.percentile(arr, 95)) if arr.size else None,
+ "min": float(arr.min()) if arr.size else None,
+ "max": float(arr.max()) if arr.size else None,
+ },
+ }
+
+ def _run(self, feed: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]:
+ if self._binding is not None:
+ binding = self._binding
+ binding.clear_binding_inputs()
+ binding.clear_binding_outputs()
+ for name, array in feed.items():
+ binding.bind_cpu_input(name, array)
+ for name in self._output_names:
+ binding.bind_output(name)
+ if self._run_options is not None:
+ self._session.run_with_iobinding(binding, self._run_options)
+ else:
+ self._session.run_with_iobinding(binding)
+ outputs = binding.copy_outputs_to_cpu()
+ else:
+ if self._run_options is not None:
+ outputs = self._session.run(
+ self._output_names, feed, self._run_options
)
+ else:
+ outputs = self._session.run(self._output_names, feed)
+ converted: list[np.ndarray] = []
+ for out in outputs:
+ toarray = getattr(out, "toarray", None)
+ if callable(toarray):
+ out = toarray()
+ converted.append(np.asarray(out))
+ return dict(zip(self._output_names, converted, strict=False))
+
+ def _prepare_feed(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]:
+ prepared: dict[str, np.ndarray] = {}
+ for spec in self._input_specs:
+ if spec.name not in feed:
+ raise KeyError(f"Missing required input '{spec.name}'.")
+ array = np.asarray(feed[spec.name])
+ if spec.dtype is not None and array.dtype != spec.dtype:
+ array = array.astype(spec.dtype, copy=False)
+ prepared[spec.name] = array
+ return prepared
+
+ def _create_session(self) -> ort.InferenceSession:
+ session_options = self._build_session_options()
+ provider_tuples = self._resolve_providers()
+ provider_names = [name for name, _ in provider_tuples]
+ provider_options = [opts for _, opts in provider_tuples]
+
+ available = set(ort.get_available_providers())
+ missing = [
+ name
+ for name in provider_names
+ if name != "CPUExecutionProvider" and name not in available
+ ]
+ if missing and self._cfg.fallback_to_cpu:
+ provider_names = ["CPUExecutionProvider"]
+ provider_options = [
+ self._build_provider_options("CPUExecutionProvider")
]
- elif backend == Backend.cpu:
- providers = [("CPUExecutionProvider", {})]
- # "CPUExecutionProvider" is different from everything else.
- # provider_option = None
- else:
- raise ValueError(f"backend={backend} is not supported.")
- return providers
-
- def __repr__(self) -> str:
- import re
-
- def strip_ansi_codes(text):
- """Remove ANSI escape codes from a string."""
- ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
- return ansi_escape.sub("", text)
-
- def format_nested_dict(dict_data, indent=0):
- """Recursively format nested dictionaries with indentation."""
- info = []
- prefix = " " * indent
- for key, value in dict_data.items():
- if isinstance(value, dict):
- info.append(f"{prefix}{key}:")
- info.append(format_nested_dict(value, indent + 1))
- elif isinstance(value, str) and value.startswith("{") and value.endswith("}"):
- try:
- nested_dict = eval(value)
- if isinstance(nested_dict, dict):
- info.append(f"{prefix}{key}:")
- info.append(format_nested_dict(nested_dict, indent + 1))
- else:
- info.append(f"{prefix}{key}: {value}")
- except Exception:
- info.append(f"{prefix}{key}: {value}")
+
+ return ort.InferenceSession(
+ self.model_path,
+ sess_options=session_options,
+ providers=provider_names,
+ provider_options=provider_options,
+ )
+
+ def _build_session_options(self) -> ort.SessionOptions:
+ opts = ort.SessionOptions()
+ cfg = self._cfg
+
+ opts.enable_profiling = bool(cfg.enable_profiling)
+ opts.graph_optimization_level = self._resolve_graph_optimization(
+ cfg.graph_optimization
+ )
+ if cfg.execution_mode is not None:
+ opts.execution_mode = self._resolve_execution_mode(
+ cfg.execution_mode
+ )
+ if cfg.intra_op_num_threads is not None:
+ opts.intra_op_num_threads = int(cfg.intra_op_num_threads)
+ if cfg.inter_op_num_threads is not None:
+ opts.inter_op_num_threads = int(cfg.inter_op_num_threads)
+ if cfg.log_severity_level is not None:
+ opts.log_severity_level = int(cfg.log_severity_level)
+
+ for key, value in (cfg.session_config_entries or {}).items():
+ opts.add_session_config_entry(str(key), str(value))
+ for key, value in self._session_overrides.items():
+ if hasattr(opts, key):
+ setattr(opts, key, value)
+ else:
+ opts.add_session_config_entry(str(key), str(value))
+
+ return opts
+
+ def _resolve_graph_optimization(
+ self,
+ option: str | ort.GraphOptimizationLevel,
+ ) -> ort.GraphOptimizationLevel:
+ if isinstance(option, ort.GraphOptimizationLevel):
+ return option
+ normalized = str(option).lower()
+ return _GRAPH_OPT_MAP.get(
+ normalized, ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ )
+
+ def _resolve_execution_mode(
+ self,
+ option: str | ort.ExecutionMode,
+ ) -> ort.ExecutionMode:
+ if isinstance(option, ort.ExecutionMode):
+ return option
+ normalized = str(option).lower()
+ return _EXECUTION_MODE_MAP.get(
+ normalized, ort.ExecutionMode.ORT_SEQUENTIAL
+ )
+
+ def _resolve_providers(self) -> list[tuple[str, dict[str, str]]]:
+ cfg_providers = self._cfg.provider_options or {}
+ provider_specs = self.backend.providers or (
+ ProviderSpec("CPUExecutionProvider"),
+ )
+ merged: list[tuple[str, dict[str, str]]] = []
+ for spec in provider_specs:
+ opts = self._build_provider_options(
+ spec.name, include_device=spec.include_device
+ )
+ merged_opts = dict(opts)
+ for key, value in cfg_providers.get(spec.name, {}).items():
+ merged_opts[str(key)] = self._provider_value(value)
+ merged.append((spec.name, merged_opts))
+ return merged
+
+ def _build_provider_options(
+ self,
+ name: str,
+ *,
+ include_device: bool = False,
+ ) -> dict[str, str]:
+ opts: dict[str, Any] = dict(_DEFAULT_PROVIDER_OPTIONS.get(name, {}))
+ if include_device:
+ opts["device_id"] = self.device_id
+ opts.update(
+ self._provider_override
+ if (include_device or name == "CPUExecutionProvider")
+ else {}
+ )
+ return {str(k): self._provider_value(v) for k, v in opts.items()}
+
+ def _provider_value(self, value: Any) -> str:
+ if isinstance(value, bool):
+ return "True" if value else "False"
+ return str(value)
+
+ def _build_run_options(self) -> ort.RunOptions | None:
+ entries = {
+ **(self._cfg.run_config_entries or {}),
+ }
+ if not entries:
+ return None
+ opts = ort.RunOptions()
+ for key, value in entries.items():
+ opts.add_run_config_entry(str(key), str(value))
+ return opts
+
+ def _extract_metadata(self) -> dict[str, Any] | None:
+ try:
+ model_meta = self._session.get_modelmeta()
+ except Exception:
+ return None
+
+ custom = getattr(model_meta, "custom_metadata_map", None)
+ if not isinstance(custom, Mapping):
+ return None
+
+ parsed: dict[str, Any] = {}
+ for key, value in custom.items():
+ if isinstance(value, str):
+ try:
+ parsed[key] = json.loads(value)
+ except Exception:
+ parsed[key] = value
+ else:
+ parsed[key] = value
+ return parsed or None
+
+ def _inspect_inputs(self) -> list[_InputSpec]:
+ specs: list[_InputSpec] = []
+ for node in self._session.get_inputs():
+ dtype = _TYPE_MAP.get(getattr(node, "type", ""))
+ raw_shape = list(getattr(node, "shape", []))
+ shape: list[int | None] = []
+ for dim in raw_shape:
+ if isinstance(dim, (int, np.integer)):
+ shape.append(int(dim))
else:
- info.append(f"{prefix}{key}: {value}")
- return "\n".join(info)
-
- title = "DOCSAID X ONNXRUNTIME"
- divider_length = 50
- divider = f"+{'-' * divider_length}+"
- styled_title = colored.stylize(title, [colored.fg("blue"), colored.attr("bold")])
-
- def center_text(text, width):
- """Center text within a fixed width, handling ANSI escape codes."""
- plain_text = strip_ansi_codes(text)
- text_length = len(plain_text)
- left_padding = (width - text_length) // 2
- right_padding = width - text_length - left_padding
- return f"|{' ' * left_padding}{text}{' ' * right_padding}|"
-
- path = f"Model Path: {self.model_path}"
- input_info = format_nested_dict(self.input_infos)
- output_info = format_nested_dict(self.output_infos)
- metadata = format_nested_dict({"metadata": self.metadata})
- providers = f"Provider: {', '.join(self.providers)}"
- provider_options = format_nested_dict(self.provider_options)
-
- sections = [
- divider,
- center_text(styled_title, divider_length),
- divider,
- path,
- input_info,
- output_info,
- metadata,
- providers,
- provider_options,
- divider,
- ]
+ shape.append(None)
+ specs.append(
+ _InputSpec(
+ name=node.name,
+ dtype=dtype,
+ dtype_str=getattr(node, "type", ""),
+ shape=tuple(shape),
+ )
+ )
+        return specs
+
-        return "\n\n".join(sections)
+    def __repr__(self) -> str: # pragma: no cover - human-readable summary
+ return (
+ f"ONNXEngine(model='{self.model_path}', "
+ f"backend='{self.backend.name}', providers={self.providers})"
+ )
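Illustrative usage sketch, not authoritative: the constructor arguments below are assumptions (ONNXEngine's __init__ is defined earlier in this patch); only benchmark(), the providers attribute and the repr shown in this hunk are relied on, and 'model.onnx' plus the input name 'images' are placeholders.

    import numpy as np
    from capybara.onnxengine import ONNXEngine  # import path assumed

    engine = ONNXEngine("model.onnx")  # constructor signature assumed
    print(engine)            # ONNXEngine(model='model.onnx', backend=..., providers=[...])
    print(engine.providers)  # resolved execution providers

    feed = {"images": np.zeros((1, 3, 224, 224), dtype=np.float32)}
    stats = engine.benchmark(feed, repeat=50, warmup=5)
    print(stats["throughput_fps"], stats["latency_ms"]["p95"])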
diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py
deleted file mode 100644
index cc8f90c..0000000
--- a/capybara/onnxengine/engine_io_binding.py
+++ /dev/null
@@ -1,203 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict, Union
-
-import colored
-import numpy as np
-import onnxruntime as ort
-
-from .metadata import get_onnx_metadata
-from .tools import get_onnx_input_infos, get_onnx_output_infos
-
-
-class ONNXEngineIOBinding:
- def __init__(
- self,
- model_path: Union[str, Path],
- input_initializer: Dict[str, np.ndarray],
- gpu_id: int = 0,
- session_option: Dict[str, Any] = {},
- provider_option: Dict[str, Any] = {},
- ):
- """
- Initialize an ONNX model inference engine.
-
- Args:
- model_path (Union[str, Path]):
- Filename or serialized ONNX or ORT format model in a byte string.
- gpu_id (int, optional):
- GPU ID. Defaults to 0.
- session_option (Dict[str, Any], optional):
- Session options. Defaults to {}.
- provider_option (Dict[str, Any], optional):
- Provider options. Defaults to {}.
- """
- self.device_id = gpu_id
- providers = ["CUDAExecutionProvider"]
- provider_options = [
- {
- "device_id": self.device_id,
- "cudnn_conv_use_max_workspace": "1",
- "enable_cuda_graph": "1",
- **provider_option,
- }
- ]
-
- # setting session options
- sess_options = self._get_session_info(session_option)
-
- # setting onnxruntime session
- model_path = str(model_path) if isinstance(model_path, Path) else model_path
- self.sess = ort.InferenceSession(
- model_path,
- sess_options=sess_options,
- providers=providers,
- provider_options=provider_options,
- )
- self.device = "cuda" if "CUDAExecutionProvider" in self.sess.get_providers() else "cpu"
-
- # setting onnxruntime session info
- self.model_path = model_path
- self.metadata = get_onnx_metadata(model_path)
- self.providers = self.sess.get_providers()
- self.provider_options = self.sess.get_provider_options()
-
- input_infos, output_infos = self._init_io_infos(model_path, input_initializer)
-
- io_binding, x_ortvalues, y_ortvalues = self._setup_io_binding(input_infos, output_infos)
- self.io_binding = io_binding
- self.x_ortvalues = x_ortvalues
- self.y_ortvalues = y_ortvalues
- self.input_infos = input_infos
- self.output_infos = output_infos
- # # Pass gpu_graph_id to RunOptions through RunConfigs
- # ro = ort.RunOptions()
- # # gpu_graph_id is optional if the session uses only one cuda graph
- # ro.add_run_config_entry("gpu_graph_id", "1")
- # self.run_option = ro
-
- def __call__(self, **xs) -> Dict[str, np.ndarray]:
- self._update_x_ortvalues(xs)
- # self.sess.run_with_iobinding(self.io_binding, self.run_option)
- self.sess.run_with_iobinding(self.io_binding)
- return {k: v.numpy() for k, v in self.y_ortvalues.items()}
-
- def _get_session_info(
- self,
- session_option: Dict[str, Any] = {},
- ) -> ort.SessionOptions:
- """
- Ref: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
- """
- sess_opt = ort.SessionOptions()
- session_option_default = {
- "graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
- "log_severity_level": 2,
- }
- session_option_default.update(session_option)
- for k, v in session_option_default.items():
- setattr(sess_opt, k, v)
- return sess_opt
-
- def _init_io_infos(self, model_path, input_initializer: dict):
- sess = ort.InferenceSession(
- model_path,
- providers=["CPUExecutionProvider"],
- )
- outs = sess.run(None, input_initializer)
- input_shapes = {k: v.shape for k, v in input_initializer.items()}
- output_shapes = {x.name: o.shape for x, o in zip(sess.get_outputs(), outs)}
- input_infos = get_onnx_input_infos(model_path)
- output_infos = get_onnx_output_infos(model_path)
- for k, v in input_infos.items():
- v["shape"] = input_shapes[k]
- for k, v in output_infos.items():
- v["shape"] = output_shapes[k]
- del sess
- return input_infos, output_infos
-
- def _setup_io_binding(self, input_infos, output_infos):
- x_ortvalues = {}
- y_ortvalues = {}
- for k, v in input_infos.items():
- m = np.zeros(**v)
- x_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type=self.device, device_id=self.device_id)
- for k, v in output_infos.items():
- m = np.zeros(**v)
- y_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type=self.device, device_id=self.device_id)
-
- io_binding = self.sess.io_binding()
- for k, v in x_ortvalues.items():
- io_binding.bind_ortvalue_input(k, v)
- for k, v in y_ortvalues.items():
- io_binding.bind_ortvalue_output(k, v)
-
- return io_binding, x_ortvalues, y_ortvalues
-
- def _update_x_ortvalues(self, xs: dict):
- for k, v in self.x_ortvalues.items():
- v.update_inplace(xs[k])
-
- def __repr__(self) -> str:
- import re
-
- def strip_ansi_codes(text):
- """Remove ANSI escape codes from a string."""
- ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
- return ansi_escape.sub("", text)
-
- def format_nested_dict(dict_data, indent=0):
- """Recursively format nested dictionaries with indentation."""
- info = []
- prefix = " " * indent
- for key, value in dict_data.items():
- if isinstance(value, dict):
- info.append(f"{prefix}{key}:")
- info.append(format_nested_dict(value, indent + 1))
- elif isinstance(value, str) and value.startswith("{") and value.endswith("}"):
- try:
- nested_dict = eval(value)
- if isinstance(nested_dict, dict):
- info.append(f"{prefix}{key}:")
- info.append(format_nested_dict(nested_dict, indent + 1))
- else:
- info.append(f"{prefix}{key}: {value}")
- except Exception:
- info.append(f"{prefix}{key}: {value}")
- else:
- info.append(f"{prefix}{key}: {value}")
- return "\n".join(info)
-
- title = "DOCSAID X ONNXRUNTIME"
- divider_length = 50
- divider = f"+{'-' * divider_length}+"
- styled_title = colored.stylize(title, [colored.fg("blue"), colored.attr("bold")])
-
- def center_text(text, width):
- """Center text within a fixed width, handling ANSI escape codes."""
- plain_text = strip_ansi_codes(text)
- text_length = len(plain_text)
- left_padding = (width - text_length) // 2
- right_padding = width - text_length - left_padding
- return f"|{' ' * left_padding}{text}{' ' * right_padding}|"
-
- path = f"Model Path: {self.model_path}"
- input_info = format_nested_dict(self.input_infos)
- output_info = format_nested_dict(self.output_infos)
- metadata = format_nested_dict({"metadata": self.metadata})
- providers = f"Provider: {', '.join(self.providers)}"
- provider_options = format_nested_dict(self.provider_options)
-
- sections = [
- divider,
- center_text(styled_title, divider_length),
- divider,
- path,
- input_info,
- output_info,
- metadata,
- providers,
- provider_options,
- divider,
- ]
-
- return "\n\n".join(sections)
diff --git a/capybara/onnxengine/enum.py b/capybara/onnxengine/enum.py
deleted file mode 100644
index a697e12..0000000
--- a/capybara/onnxengine/enum.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from enum import Enum
-
-from ..enums import EnumCheckMixin
-
-
-class Backend(EnumCheckMixin, Enum):
- cpu = 0
- cuda = 1
- coreml = 2
diff --git a/capybara/onnxengine/metadata.py b/capybara/onnxengine/metadata.py
index a7d31dc..9e4f52f 100644
--- a/capybara/onnxengine/metadata.py
+++ b/capybara/onnxengine/metadata.py
@@ -1,48 +1,73 @@
+from __future__ import annotations
+
import json
-from typing import Union
+from pathlib import Path
+from typing import Any
import onnx
-import onnxruntime as ort
-from ..utils import Path, now
+from ..utils.time import now
+
+try: # pragma: no cover - optional dependency
+ import onnxruntime as ort # type: ignore
+except Exception: # pragma: no cover - handled lazily
+ ort = None # type: ignore[assignment]
+
+def _require_ort():
+ if ort is None: # pragma: no cover - depends on optional dep
+ raise ImportError(
+ "onnxruntime is required to read/write ONNX metadata. "
+ "Install 'onnxruntime' or 'onnxruntime-gpu'."
+ )
+ return ort
-def get_onnx_metadata(
- onnx_path: Union[str, Path],
-) -> dict:
- onnx_path = str(onnx_path)
- sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
- metadata = sess.get_modelmeta().custom_metadata_map
- del sess
- return metadata
+
+def get_onnx_metadata(onnx_path: str | Path) -> dict[str, Any]:
+ ort_mod = _require_ort()
+ sess = ort_mod.InferenceSession(
+ str(onnx_path), providers=["CPUExecutionProvider"]
+ )
+ try:
+ custom = sess.get_modelmeta().custom_metadata_map
+ return dict(custom)
+ finally:
+ del sess
def write_metadata_into_onnx(
- onnx_path: Union[str, Path],
- out_path: Union[str, Path],
+ onnx_path: str | Path,
+ out_path: str | Path,
drop_old_meta: bool = False,
- **kwargs,
-):
- onnx_path = str(onnx_path)
- onnx_model = onnx.load(onnx_path)
- meta_data = parse_metadata_from_onnx(onnx_path) if not drop_old_meta else {}
+ **kwargs: Any,
+) -> None:
+ onnx_model = onnx.load(str(onnx_path))
+ meta_data: dict[str, Any] = (
+ {} if drop_old_meta else parse_metadata_from_onnx(onnx_path)
+ )
meta_data.update({"Date": now(fmt="%Y-%m-%d %H:%M:%S"), **kwargs})
onnx.helper.set_model_props(
onnx_model,
- {k: json.dumps(v) for k, v in meta_data.items()},
+ {str(k): json.dumps(v) for k, v in meta_data.items()},
+ )
+ onnx.save(onnx_model, str(out_path))
+
+
+def parse_metadata_from_onnx(onnx_path: str | Path) -> dict[str, Any]:
+ ort_mod = _require_ort()
+ sess = ort_mod.InferenceSession(
+ str(onnx_path), providers=["CPUExecutionProvider"]
)
- onnx.save(onnx_model, out_path)
-
-
-def parse_metadata_from_onnx(
- onnx_path: Union[str, Path],
-) -> dict:
- onnx_path = str(onnx_path)
- sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
- metadata = {
- k: json.loads(v) for k, v in sess.get_modelmeta().custom_metadata_map.items()
- }
- del sess
- return metadata
+ try:
+ metadata_map = sess.get_modelmeta().custom_metadata_map
+ parsed: dict[str, Any] = {}
+ for key, raw in metadata_map.items():
+ if isinstance(raw, str):
+ parsed[key] = json.loads(raw)
+ else:
+ parsed[key] = raw
+ return parsed
+ finally:
+ del sess
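A minimal round trip with the rewritten metadata helpers; the file names are placeholders, and every value passed via **kwargs must be JSON-serializable because it goes through json.dumps before being stored.

    from capybara.onnxengine.metadata import (
        get_onnx_metadata,
        parse_metadata_from_onnx,
        write_metadata_into_onnx,
    )

    # Stamp custom metadata into a copy of the model; a "Date" entry is added automatically.
    write_metadata_into_onnx("model.onnx", "model_meta.onnx", author="capybara", version=2)

    # Raw strings as stored in the model vs. the JSON-decoded Python values.
    print(get_onnx_metadata("model_meta.onnx"))
    print(parse_metadata_from_onnx("model_meta.onnx"))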
diff --git a/capybara/onnxengine/quantize.py b/capybara/onnxengine/quantize.py
deleted file mode 100644
index 15e1aa0..0000000
--- a/capybara/onnxengine/quantize.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# from enum import Enum
-# from typing import List, Union
-
-# import onnx
-# from onnxruntime.quantization import CalibrationDataReader, quantize_static
-# from onnxruntime.quantization.calibrate import CalibrationMethod
-# from onnxruntime.quantization.quant_utils import QuantFormat
-
-# from ..enums import Enum, EnumCheckMixin
-# from ..utils import Path
-
-
-# class DstDevice(EnumCheckMixin, Enum):
-# mobile = 0
-# x86 = 1
-
-
-# def _get_exclude_names_from_op_type(onnx_fpath, op_type):
-# onnx_model = onnx.load(str(onnx_fpath))
-# return [node.name for node in onnx_model.graph.node if node.op_type == op_type]
-
-
-# def quantize(
-# onnx_fpath: Union[str, Path],
-# calibration_data_reader: CalibrationDataReader,
-# dst_device: Union[str, DstDevice] = DstDevice.mobile,
-# op_type_to_exclude: List[str] = [],
-# **quant_cfg,
-# ) -> str:
-# dst_device = DstDevice.obj_to_enum(dst_device)
-# onnx_fpath = Path(onnx_fpath)
-
-# print(f'\nStart to quantize onnx model to {dst_device.name}')
-
-# quant_fpath = str(onnx_fpath).replace('fp32', '').replace(
-# '.onnx', f'_int8_{dst_device.name}.onnx')
-
-# quant_format = QuantFormat.QOperator \
-# if dst_device == DstDevice.mobile else QuantFormat.QDQ
-# quant_format = quant_cfg.get('quant_format', quant_format)
-# calibrate_method = quant_cfg.get(
-# 'calibrate_method', CalibrationMethod.MinMax)
-# per_channel = quant_cfg.get('per_channel', True)
-# reduce_range = quant_cfg.get('reduce_range', True)
-# nodes_to_exclude = quant_cfg.get('nodes_to_exclude', [])
-
-# for op_type in op_type_to_exclude:
-# nodes_to_exclude.extend(
-# _get_exclude_names_from_op_type(onnx_fpath, op_type))
-
-# quantize_static(
-# onnx_fpath,
-# quant_fpath,
-# calibration_data_reader,
-# quant_format=quant_format,
-# calibrate_method=calibrate_method,
-# per_channel=per_channel,
-# reduce_range=reduce_range,
-# nodes_to_exclude=nodes_to_exclude
-# )
-
-# return quant_fpath
diff --git a/capybara/onnxengine/tools.py b/capybara/onnxengine/tools.py
deleted file mode 100644
index 73bffaa..0000000
--- a/capybara/onnxengine/tools.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-
-import onnx
-import onnxruntime as ort
-import onnxslim
-from onnx.helper import make_graph, make_model, make_opsetid, tensor_dtype_to_np_dtype
-
-from .enum import Backend
-
-__all__ = [
- "get_onnx_input_infos",
- "get_onnx_output_infos",
- "make_onnx_dynamic_axes",
- "get_recommended_backend",
-]
-
-
-def get_onnx_input_infos(model: Union[str, Path, onnx.ModelProto]) -> Dict[str, List[int]]:
- if not isinstance(model, onnx.ModelProto):
- model = onnx.load(model)
- return {
- x.name: {
- "shape": [d.dim_value if d.dim_value != 0 else -1 for d in x.type.tensor_type.shape.dim],
- "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type),
- }
- for x in model.graph.input
- }
-
-
-def get_onnx_output_infos(model: Union[str, Path, onnx.ModelProto]) -> Dict[str, List[int]]:
- if not isinstance(model, onnx.ModelProto):
- model = onnx.load(model)
- return {
- x.name: {
- "shape": [d.dim_value if d.dim_value != 0 else -1 for d in x.type.tensor_type.shape.dim],
- "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type),
- }
- for x in model.graph.output
- }
-
-
-def make_onnx_dynamic_axes(
- model_fpath: Union[str, Path],
- output_fpath: Union[str, Path],
- input_dims: Dict[str, Dict[int, str]],
- output_dims: Dict[str, Dict[int, str]],
- opset_version: Optional[int] = None,
-) -> None:
- onnx_model = onnx.load(model_fpath)
-
- new_graph = make_graph(
- nodes=onnx_model.graph.node,
- name=onnx_model.graph.name,
- inputs=onnx_model.graph.input,
- outputs=onnx_model.graph.output,
- initializer=onnx_model.graph.initializer,
- value_info=None,
- )
-
- if not any(opset.domain == "" for opset in onnx_model.opset_import):
- onnx_model.opset_import.append(make_opsetid(domain="", version=opset_version))
-
- new_model = make_model(new_graph, opset_imports=onnx_model.opset_import, ir_version=onnx_model.ir_version)
-
- for x in new_model.graph.input:
- for name, v in input_dims.items():
- if x.name == name:
- for k, d in v.items():
- x.type.tensor_type.shape.dim[k].dim_param = d
-
- for x in new_model.graph.output:
- for name, v in output_dims.items():
- if x.name == name:
- for k, d in v.items():
- x.type.tensor_type.shape.dim[k].dim_param = d
-
- for x in new_model.graph.node:
- if x.op_type == "Reshape":
- raise ValueError("Reshape cannot be trasformed to dynamic axes")
-
- new_model = onnxslim.slim(new_model)
- onnx.save(new_model, output_fpath)
-
-
-def get_recommended_backend() -> Backend:
- providers = ort.get_available_providers()
- device = ort.get_device()
- if "CUDAExecutionProvider" in providers and device == "GPU":
- return Backend.cuda
- elif "CoreMLExecutionProvider" in providers:
- return Backend.coreml
- else:
- return Backend.cpu
diff --git a/capybara/onnxengine/utils.py b/capybara/onnxengine/utils.py
new file mode 100644
index 0000000..717b7e3
--- /dev/null
+++ b/capybara/onnxengine/utils.py
@@ -0,0 +1,113 @@
+from pathlib import Path
+from typing import Any, cast
+
+import onnx
+import onnxslim
+from onnx.helper import (
+ make_graph,
+ make_model,
+ make_opsetid,
+ tensor_dtype_to_np_dtype,
+)
+
+__all__ = [
+ "get_onnx_input_infos",
+ "get_onnx_output_infos",
+]
+
+
+def get_onnx_input_infos(
+ model: str | Path | onnx.ModelProto,
+) -> dict[str, dict[str, Any]]:
+ if not isinstance(model, onnx.ModelProto):
+ model = onnx.load(model)
+
+ def _dim_to_value(dim: Any) -> int | str:
+ dim_param = getattr(dim, "dim_param", "") or ""
+ if dim_param:
+ return str(dim_param)
+ dim_value = int(getattr(dim, "dim_value", 0) or 0)
+ return dim_value if dim_value != 0 else -1
+
+ return {
+ x.name: {
+ "shape": [_dim_to_value(d) for d in x.type.tensor_type.shape.dim],
+ "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type),
+ }
+ for x in model.graph.input
+ }
+
+
+def get_onnx_output_infos(
+ model: str | Path | onnx.ModelProto,
+) -> dict[str, dict[str, Any]]:
+ if not isinstance(model, onnx.ModelProto):
+ model = onnx.load(model)
+
+ def _dim_to_value(dim: Any) -> int | str:
+ dim_param = getattr(dim, "dim_param", "") or ""
+ if dim_param:
+ return str(dim_param)
+ dim_value = int(getattr(dim, "dim_value", 0) or 0)
+ return dim_value if dim_value != 0 else -1
+
+ return {
+ x.name: {
+ "shape": [_dim_to_value(d) for d in x.type.tensor_type.shape.dim],
+ "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type),
+ }
+ for x in model.graph.output
+ }
+
+
+def make_onnx_dynamic_axes(
+ model_fpath: str | Path,
+ output_fpath: str | Path,
+ input_dims: dict[str, dict[int, str]],
+ output_dims: dict[str, dict[int, str]],
+ opset_version: int | None = None,
+) -> None:
+ onnx_model = onnx.load(model_fpath)
+
+ new_graph = make_graph(
+ nodes=onnx_model.graph.node,
+ name=onnx_model.graph.name,
+ inputs=onnx_model.graph.input,
+ outputs=onnx_model.graph.output,
+ initializer=onnx_model.graph.initializer,
+ value_info=None,
+ )
+
+ if not any(opset.domain == "" for opset in onnx_model.opset_import):
+ if opset_version is None:
+ opset_version = int(onnx.defs.onnx_opset_version())
+ onnx_model.opset_import.append(
+ make_opsetid(domain="", version=opset_version)
+ )
+
+ new_model = make_model(new_graph, opset_imports=onnx_model.opset_import)
+
+ for x in new_model.graph.input:
+ for name, v in input_dims.items():
+ if x.name == name:
+ for k, d in v.items():
+ x.type.tensor_type.shape.dim[k].dim_param = d
+
+ for x in new_model.graph.output:
+ for name, v in output_dims.items():
+ if x.name == name:
+ for k, d in v.items():
+ x.type.tensor_type.shape.dim[k].dim_param = d
+
+ for x in new_model.graph.node:
+ if x.op_type == "Reshape":
+            raise ValueError("Reshape cannot be transformed to dynamic axes")
+
+ simplify = getattr(onnxslim, "simplify", None)
+ if callable(simplify):
+ simplified = simplify(new_model)
+ if isinstance(simplified, tuple):
+ new_model = cast(onnx.ModelProto, simplified[0])
+ else:
+ new_model = cast(onnx.ModelProto, simplified)
+ onnx.save(new_model, output_fpath)
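A sketch of how these helpers compose; 'input' and 'logits' stand in for whatever names the model actually declares, and the file names are placeholders.

    from capybara.onnxengine.utils import (
        get_onnx_input_infos,
        get_onnx_output_infos,
        make_onnx_dynamic_axes,
    )

    # Symbolic dims come back as strings, unknown static dims as -1.
    print(get_onnx_input_infos("model.onnx"))
    print(get_onnx_output_infos("model.onnx"))

    # Rewrite axis 0 on both ends as a symbolic 'batch' dimension.
    make_onnx_dynamic_axes(
        "model.onnx",
        "model_dynamic.onnx",
        input_dims={"input": {0: "batch"}},
        output_dims={"logits": {0: "batch"}},
    )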
diff --git a/capybara/openvinoengine/__init__.py b/capybara/openvinoengine/__init__.py
new file mode 100644
index 0000000..d2ca439
--- /dev/null
+++ b/capybara/openvinoengine/__init__.py
@@ -0,0 +1,3 @@
+from .engine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine
+
+__all__ = ["OpenVINOConfig", "OpenVINODevice", "OpenVINOEngine"]
diff --git a/capybara/openvinoengine/engine.py b/capybara/openvinoengine/engine.py
new file mode 100644
index 0000000..febd359
--- /dev/null
+++ b/capybara/openvinoengine/engine.py
@@ -0,0 +1,626 @@
+from __future__ import annotations
+
+import contextlib
+import queue
+import threading
+import time
+import warnings
+from collections import deque
+from collections.abc import Mapping
+from concurrent.futures import Future
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+__all__ = ["OpenVINOConfig", "OpenVINODevice", "OpenVINOEngine"]
+
+try:
+ _BFLOAT16 = np.dtype("bfloat16")
+except TypeError: # pragma: no cover
+ _BFLOAT16 = np.float16
+
+
+class OpenVINODevice(str, Enum):
+ auto = "AUTO"
+ cpu = "CPU"
+ gpu = "GPU"
+ npu = "NPU"
+ hetero = "HETERO"
+ auto_batch = "AUTO_BATCH"
+
+ @classmethod
+ def from_any(cls, value: Any) -> OpenVINODevice:
+ if isinstance(value, cls):
+ return value
+ if isinstance(value, str):
+ normalized = value.upper()
+ for member in cls:
+ if member.value == normalized:
+ return member
+ raise ValueError(
+ f"Unsupported OpenVINO device '{value}'. "
+ f"Available: {[member.value for member in cls]}"
+ )
+
+
+@dataclass(slots=True)
+class OpenVINOConfig:
+ compile_properties: dict[str, Any] | None = None
+ core_properties: dict[str, Any] | None = None
+ cache_dir: str | Path | None = None
+ num_streams: int | None = None
+ num_threads: int | None = None
+ # None => use engine defaults
+ # 0 => OpenVINO auto (async queue only; sync pool falls back to 1)
+ # >=1 => fixed number of requests
+ num_requests: int | None = None
+ # When disabled, outputs may share OpenVINO-owned buffers and can be
+ # overwritten by subsequent inference calls.
+ copy_outputs: bool = True
+
+
+@dataclass(slots=True)
+class _InputSpec:
+ name: str
+ dtype: np.dtype[Any] | None
+ shape: tuple[int | None, ...]
+
+
+def _lazy_import_openvino():
+ try:
+ import openvino.runtime as ov_runtime # type: ignore
+ except Exception as exc: # pragma: no cover - missing optional dep
+ raise ImportError(
+ "OpenVINO is required. Install it via 'pip install openvino-dev'."
+ ) from exc
+ return ov_runtime
+
+
+def _normalize_device(device: str | OpenVINODevice) -> str:
+ if isinstance(device, OpenVINODevice):
+ return device.value
+ return str(device).upper()
+
+
+def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]:
+ repeat_i = int(repeat)
+ warmup_i = int(warmup)
+ if repeat_i < 1:
+ raise ValueError("repeat must be >= 1.")
+ if warmup_i < 0:
+ raise ValueError("warmup must be >= 0.")
+ return repeat_i, warmup_i
+
+
+class OpenVINOEngine:
+ def __init__(
+ self,
+ model_path: str | Path,
+ *,
+ device: str | OpenVINODevice = OpenVINODevice.auto,
+ config: OpenVINOConfig | None = None,
+ core: Any | None = None,
+ input_shapes: dict[str, Any] | None = None,
+ ) -> None:
+ ov = _lazy_import_openvino()
+
+ self.model_path = str(model_path)
+ self.device = _normalize_device(device)
+ self._cfg = config or OpenVINOConfig()
+ self._core = core or ov.Core()
+ self._input_shapes = input_shapes
+ self._ov = ov
+
+ if self._cfg.core_properties:
+ self._core.set_property(self._cfg.core_properties)
+
+ self._type_map = self._build_type_map(ov)
+ self._compiled_model = self._compile_model()
+ self._input_specs = self._inspect_inputs()
+ self._output_ports = list(self._compiled_model.outputs)
+ self._output_names = [
+ port.get_any_name() for port in self._output_ports
+ ]
+ self._copy_outputs = bool(self._cfg.copy_outputs)
+ self._request_pool = self._create_request_pool()
+
+ def __call__(self, **inputs: np.ndarray) -> dict[str, np.ndarray]:
+ if len(inputs) == 1 and isinstance(
+ next(iter(inputs.values())), Mapping
+ ):
+ feed_dict = dict(next(iter(inputs.values())))
+ else:
+ feed_dict = inputs
+ return self.run(feed_dict)
+
+ def run(self, feed: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]:
+ prepared = self._prepare_feed(feed)
+ infer_request = self._request_pool.get()
+ try:
+ infer_request.infer(prepared)
+ except Exception:
+ try:
+ replacement = self._compiled_model.create_infer_request()
+ except Exception: # pragma: no cover - unexpected runtime failure
+ replacement = infer_request
+ self._request_pool.put(replacement)
+ raise
+ else:
+ try:
+ return self._collect_outputs(
+ infer_request, copy_outputs=self._copy_outputs
+ )
+ finally:
+ self._request_pool.put(infer_request)
+
+ def create_async_queue(
+ self,
+ *,
+ num_requests: int | None = None,
+ copy_outputs: bool | None = None,
+ ) -> OpenVINOAsyncQueue:
+ """Create an async inference queue backed by AsyncInferQueue.
+
+ This enables pipelining: while the device is running inference for one
+ request, your Python code can prepare the next input(s).
+
+ `num_requests` semantics:
+ - None: use `OpenVINOConfig.num_requests` (or a default)
+ - 0: let OpenVINO decide the number of requests (AUTO)
+ - >=1: fixed number of requests
+ """
+ return OpenVINOAsyncQueue(
+ self,
+ num_requests=num_requests,
+ copy_outputs=copy_outputs,
+ )
+
+ def summary(self) -> dict[str, Any]:
+ inputs = [
+ {
+ "name": spec.name,
+ "dtype": str(spec.dtype),
+ "shape": list(spec.shape),
+ }
+ for spec in self._input_specs
+ ]
+ outputs = [
+ {
+ "name": port.get_any_name(),
+ "dtype": str(port.get_element_type()),
+ "shape": list(
+ self._partial_shape_to_tuple(port.get_partial_shape())
+ ),
+ }
+ for port in self._output_ports
+ ]
+ return {
+ "model": self.model_path,
+ "device": self.device,
+ "inputs": inputs,
+ "outputs": outputs,
+ }
+
+ def benchmark(
+ self,
+ feed: Mapping[str, np.ndarray],
+ *,
+ repeat: int = 100,
+ warmup: int = 10,
+ ) -> dict[str, Any]:
+ repeat, warmup = _validate_benchmark_params(
+ repeat=repeat, warmup=warmup
+ )
+ prepared = self._prepare_feed(feed)
+ infer_request = self._compiled_model.create_infer_request()
+
+ for _ in range(warmup):
+ infer_request.infer(prepared)
+
+ latencies: list[float] = []
+ t0 = time.perf_counter()
+ for _ in range(repeat):
+ start = time.perf_counter()
+ infer_request.infer(prepared)
+ latencies.append((time.perf_counter() - start) * 1e3)
+ total = time.perf_counter() - t0
+ arr = np.asarray(latencies, dtype=np.float64)
+ return {
+ "repeat": repeat,
+ "warmup": warmup,
+ "throughput_fps": repeat / total if total else None,
+ "latency_ms": {
+ "mean": float(arr.mean()) if arr.size else None,
+ "median": float(np.median(arr)) if arr.size else None,
+ "p90": float(np.percentile(arr, 90)) if arr.size else None,
+ "p95": float(np.percentile(arr, 95)) if arr.size else None,
+ "min": float(arr.min()) if arr.size else None,
+ "max": float(arr.max()) if arr.size else None,
+ },
+ }
+
+ def benchmark_async(
+ self,
+ feed: Mapping[str, np.ndarray],
+ *,
+ repeat: int = 100,
+ warmup: int = 10,
+ num_requests: int | None = None,
+ ) -> dict[str, Any]:
+ """Benchmark throughput using AsyncInferQueue."""
+ repeat, warmup = _validate_benchmark_params(
+ repeat=repeat, warmup=warmup
+ )
+ prepared = self._prepare_feed(feed)
+ default_jobs = self._resolve_async_jobs(
+ self._cfg.num_requests, default=2
+ )
+ jobs = self._resolve_async_jobs(num_requests, default=default_jobs)
+
+ # Warmup with synchronous inference to keep behaviour deterministic.
+ warm_req = self._compiled_model.create_infer_request()
+ for _ in range(warmup):
+ warm_req.infer(prepared)
+
+ latencies: list[float] = []
+ with self.create_async_queue(num_requests=jobs) as async_queue:
+ in_flight_limit = jobs if jobs > 0 else 2
+ in_flight_limit = max(1, in_flight_limit)
+ in_flight: deque[tuple[Future[dict[str, np.ndarray]], float]] = (
+ deque()
+ )
+
+ t0 = time.perf_counter()
+ for _ in range(repeat):
+ if len(in_flight) >= in_flight_limit:
+ fut, start = in_flight.popleft()
+ fut.result()
+ latencies.append((time.perf_counter() - start) * 1e3)
+
+ start = time.perf_counter()
+ fut = async_queue.submit(prepared)
+ in_flight.append((fut, start))
+
+ while in_flight:
+ fut, start = in_flight.popleft()
+ fut.result()
+ latencies.append((time.perf_counter() - start) * 1e3)
+ total = time.perf_counter() - t0
+
+ arr = np.asarray(latencies, dtype=np.float64)
+ return {
+ "repeat": repeat,
+ "warmup": warmup,
+ "num_requests": jobs,
+ "throughput_fps": repeat / total if total else None,
+ "latency_ms": {
+ "mean": float(arr.mean()) if arr.size else None,
+ "median": float(np.median(arr)) if arr.size else None,
+ "p90": float(np.percentile(arr, 90)) if arr.size else None,
+ "p95": float(np.percentile(arr, 95)) if arr.size else None,
+ "min": float(arr.min()) if arr.size else None,
+ "max": float(arr.max()) if arr.size else None,
+ },
+ }
+
+ def _compile_model(self):
+ model = self._core.read_model(self.model_path)
+ self._maybe_reshape_model(model)
+ properties = dict(self._cfg.compile_properties or {})
+ if self._cfg.cache_dir is not None:
+ properties["CACHE_DIR"] = str(self._cfg.cache_dir)
+ if self._cfg.num_streams is not None:
+ properties["NUM_STREAMS"] = str(self._cfg.num_streams)
+ if self._cfg.num_threads is not None:
+ properties["INFERENCE_NUM_THREADS"] = str(self._cfg.num_threads)
+ return self._core.compile_model(model, self.device, properties)
+
+ def _create_request_pool(self) -> queue.Queue[Any]:
+ pool_size = self._resolve_pool_size(self._cfg.num_requests)
+ pool: queue.Queue[Any] = queue.Queue(maxsize=pool_size)
+ for _ in range(pool_size):
+ pool.put(self._compiled_model.create_infer_request())
+ return pool
+
+ def _resolve_pool_size(self, value: Any | None) -> int:
+ if value is None:
+ return 1
+ requests = int(value)
+ if requests < 0:
+ raise ValueError("num_requests must be >= 0.")
+ if requests == 0:
+ return 1
+ return requests
+
+ def _resolve_async_jobs(self, value: Any | None, *, default: int) -> int:
+ if value is None:
+ return int(default)
+ requests = int(value)
+ if requests < 0:
+ raise ValueError("num_requests must be >= 0.")
+ return requests
+
+ def _collect_outputs(
+ self,
+ infer_request: Any,
+ *,
+ copy_outputs: bool,
+ ) -> dict[str, np.ndarray]:
+ """Return output tensors as numpy arrays.
+
+ When `copy_outputs=False`, the returned arrays may share OpenVINO-owned
+ buffers and can be overwritten by subsequent inference calls.
+ """
+ outputs: dict[str, np.ndarray] = {}
+ for port, name in zip(
+ self._output_ports, self._output_names, strict=True
+ ):
+ tensor = infer_request.get_tensor(port)
+ outputs[name] = np.array(tensor.data, copy=copy_outputs)
+ return outputs
+
+ def _maybe_reshape_model(self, model: Any) -> None:
+ if not self._input_shapes:
+ return
+ if not hasattr(model, "reshape"):
+ raise RuntimeError(
+ "OpenVINO model does not support reshape(), "
+ "but input_shapes were provided."
+ )
+
+ reshape_map: dict[str, Any] = {}
+ for name, shape in self._input_shapes.items():
+ reshape_map[name] = self._normalize_shape_dims(shape)
+
+ model.reshape(reshape_map)
+
+ def _normalize_shape_dims(self, shape: Any) -> Any:
+ if isinstance(shape, (list, tuple)):
+ dims: list[int] = []
+ for dim in shape:
+ if dim is None:
+ raise ValueError(
+ "input_shapes must use concrete dimensions; got None."
+ )
+ dims.append(int(dim))
+ return tuple(dims)
+ return shape
+
+ def _inspect_inputs(self) -> list[_InputSpec]:
+ specs: list[_InputSpec] = []
+ for port in self._compiled_model.inputs:
+ specs.append(
+ _InputSpec(
+ name=port.get_any_name(),
+ dtype=self._type_map.get(port.get_element_type()),
+ shape=self._partial_shape_to_tuple(
+ port.get_partial_shape()
+ ),
+ )
+ )
+ return specs
+
+ def _prepare_feed(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]:
+ prepared: dict[str, np.ndarray] = {}
+ for spec in self._input_specs:
+ if spec.name not in feed:
+ raise KeyError(f"Missing required input '{spec.name}'.")
+ array = np.asarray(feed[spec.name])
+ if spec.dtype is not None and array.dtype != spec.dtype:
+ array = array.astype(spec.dtype, copy=False)
+ prepared[spec.name] = array
+ return prepared
+
+ def _partial_shape_to_tuple(self, shape) -> tuple[int | None, ...]:
+ dims: list[int | None] = []
+ for dim in shape:
+ if hasattr(dim, "is_static") and dim.is_static:
+ dims.append(int(dim.get_length()))
+ elif isinstance(dim, int):
+ dims.append(int(dim))
+ else:
+ dims.append(None)
+ return tuple(dims)
+
+ def _build_type_map(self, ov_module) -> dict[Any, np.dtype[Any]]:
+ type_cls = getattr(ov_module, "Type", None)
+ if type_cls is None:
+ return {}
+ mapping: dict[Any, np.dtype[Any]] = {}
+ pairs = {
+ "f32": np.float32,
+ "f16": np.float16,
+ "bf16": _BFLOAT16,
+ "i64": np.int64,
+ "i32": np.int32,
+ "i16": np.int16,
+ "i8": np.int8,
+ "u64": np.uint64,
+ "u32": np.uint32,
+ "u16": np.uint16,
+ "u8": np.uint8,
+ "boolean": np.bool_,
+ }
+ for attr, dtype in pairs.items():
+ if hasattr(type_cls, attr):
+ mapping[getattr(type_cls, attr)] = dtype
+ return mapping
+
+ def __repr__(self) -> str: # pragma: no cover - helper for debugging
+ return (
+ f"OpenVINOEngine(model='{self.model_path}', device='{self.device}')"
+ )
+
+
+class OpenVINOAsyncQueue:
+ """A small wrapper around OpenVINO AsyncInferQueue.
+
+ Use this to overlap preprocessing/postprocessing with device execution.
+ """
+
+ def __init__(
+ self,
+ engine: OpenVINOEngine,
+ *,
+ num_requests: int | None = None,
+ copy_outputs: bool | None = None,
+ ) -> None:
+ self._engine = engine
+ self._ov = engine._ov
+ self._copy_outputs = (
+ engine._copy_outputs if copy_outputs is None else bool(copy_outputs)
+ )
+ self._completion_queue_warned = False
+ default_jobs = engine._resolve_async_jobs(
+ engine._cfg.num_requests, default=2
+ )
+ jobs = engine._resolve_async_jobs(num_requests, default=default_jobs)
+
+ infer_queue_cls = getattr(self._ov, "AsyncInferQueue", None)
+ if infer_queue_cls is None:
+ raise RuntimeError(
+ "Your OpenVINO installation does not expose AsyncInferQueue. "
+ "Please upgrade 'openvino' / 'openvino-dev' to a newer version."
+ )
+
+ self._queue = infer_queue_cls(engine._compiled_model, jobs)
+ self._queue.set_callback(self._callback)
+ self._closed = False
+ self._request_id_lock = threading.Lock()
+ self._request_id_seq = 0
+
+ def _allocate_request_id(self) -> int:
+ with self._request_id_lock:
+ request_id = self._request_id_seq
+ self._request_id_seq += 1
+ return request_id
+
+ def submit(
+ self,
+ feed: Mapping[str, np.ndarray],
+ *,
+ request_id: Any | None = None,
+ completion_queue: Any | None = None,
+ ) -> Future[dict[str, np.ndarray]]:
+ """Start an async request and return a Future of raw model outputs.
+
+ If ``completion_queue`` is provided, this enqueues ``(request_id, outputs)``
+ in the OpenVINO callback (non-blocking; events are dropped if the queue is
+ full or incompatible).
+
+ Note: ``outputs`` are the raw model outputs (before any decoding). When
+ ``copy_outputs=False``, the returned arrays may share OpenVINO-owned
+ buffers and can be overwritten by subsequent inference calls; consume them
+ immediately.
+
+ The returned Future includes a ``request_id`` attribute that matches the
+ ID used for submission. If ``request_id`` is omitted, an integer is
+ auto-generated. You can also pass any correlation ID (e.g. a string).
+ """
+ if self._closed:
+ raise RuntimeError("Async queue is closed.")
+ resolved_request_id = request_id
+ if resolved_request_id is None:
+ resolved_request_id = self._allocate_request_id()
+ future: Future[dict[str, np.ndarray]] = Future()
+ with contextlib.suppress(Exception): # pragma: no cover - defensive
+ future.request_id = resolved_request_id # type: ignore[attr-defined]
+ if completion_queue is not None:
+ with contextlib.suppress(Exception): # pragma: no cover - defensive
+ future.completion_queue = completion_queue # type: ignore[attr-defined]
+ prepared = self._engine._prepare_feed(feed)
+ self._queue.start_async(
+ prepared,
+ _AsyncUserdata(
+ future=future,
+ request_id=resolved_request_id,
+ completion_queue=completion_queue,
+ ),
+ )
+ return future
+
+ def wait_all(self) -> None:
+ self._queue.wait_all()
+
+ def close(self) -> None:
+ if self._closed:
+ return
+ self.wait_all()
+ self._closed = True
+
+ def __enter__(self) -> OpenVINOAsyncQueue:
+ return self
+
+ def __exit__(self, exc_type, exc, tb) -> None:
+ self.close()
+
+ def _callback(self, infer_request: Any, userdata: Any) -> None:
+ completion_queue = None
+ request_id = None
+ if isinstance(userdata, _AsyncUserdata):
+ future = userdata.future
+ completion_queue = userdata.completion_queue
+ request_id = userdata.request_id
+ elif isinstance(
+ userdata, Future
+ ): # pragma: no cover - legacy defensive
+ future = userdata
+ completion_queue = getattr(future, "completion_queue", None)
+ request_id = getattr(future, "request_id", None)
+ else: # pragma: no cover - defensive
+ return
+ try:
+ outputs = self._engine._collect_outputs(
+ infer_request, copy_outputs=self._copy_outputs
+ )
+ except Exception as exc: # pragma: no cover - surfaced via future
+ future.set_exception(exc)
+ else:
+ future.set_result(outputs)
+ if completion_queue is not None and request_id is not None:
+ item = (request_id, dict(outputs))
+ try:
+ put_nowait = getattr(completion_queue, "put_nowait", None)
+ if callable(put_nowait):
+ put_nowait(item)
+ else:
+ completion_queue.put(item, block=False)
+ except queue.Full:
+ if not self._completion_queue_warned:
+ warnings.warn(
+ "completion_queue is full; dropping async completion "
+ "events to avoid blocking the OpenVINO callback thread.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ self._completion_queue_warned = True
+ except (TypeError, AttributeError):
+ if not self._completion_queue_warned:
+ warnings.warn(
+ "completion_queue does not support non-blocking put; "
+ "dropping async completion events to avoid blocking "
+ "the OpenVINO callback thread.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ self._completion_queue_warned = True
+ except Exception: # pragma: no cover - defensive
+ if not self._completion_queue_warned:
+ warnings.warn(
+ "completion_queue enqueue failed; dropping async "
+ "completion events to avoid blocking the OpenVINO "
+ "callback thread.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ self._completion_queue_warned = True
+
+
+@dataclass(slots=True)
+class _AsyncUserdata:
+ future: Future[dict[str, np.ndarray]]
+ request_id: Any | None
+ completion_queue: Any | None
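Usage sketch for the new engine, covering both the blocking path and the async queue; the model file and the input name 'images' are placeholders.

    import numpy as np
    from capybara.openvinoengine import OpenVINOConfig, OpenVINOEngine

    engine = OpenVINOEngine(
        "model.xml",
        device="CPU",
        config=OpenVINOConfig(num_requests=2, cache_dir=".ov_cache"),
    )
    feed = {"images": np.zeros((1, 3, 224, 224), dtype=np.float32)}

    # Blocking call: an infer request is borrowed from the pool and returned afterwards.
    outputs = engine.run(feed)

    # Pipelined path: submit() returns a Future of the raw model outputs.
    with engine.create_async_queue(num_requests=2) as q:
        futures = [q.submit(feed) for _ in range(8)]
        results = [f.result() for f in futures]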
diff --git a/capybara/runtime.py b/capybara/runtime.py
new file mode 100644
index 0000000..3ed4ab0
--- /dev/null
+++ b/capybara/runtime.py
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from importlib import import_module
+from typing import Any, ClassVar, cast
+
+__all__ = ["Backend", "ProviderSpec", "Runtime"]
+
+
+def _normalize_key(value: str | Runtime) -> str:
+ if isinstance(value, Runtime):
+ return value.name
+ return str(value).strip().lower()
+
+
+@dataclass(frozen=True)
+class ProviderSpec:
+ name: str
+ include_device: bool = False
+
+
+@dataclass(frozen=True)
+class Backend:
+ name: str
+ runtime_key: str
+ providers: tuple[ProviderSpec, ...] = ()
+ device: str | None = None
+ description: str | None = None
+
+ _REGISTRY: ClassVar[dict[str, dict[str, Backend]]] = {}
+ cpu: ClassVar[Backend]
+ cuda: ClassVar[Backend]
+ tensorrt: ClassVar[Backend]
+ tensorrt_rtx: ClassVar[Backend]
+ ov_cpu: ClassVar[Backend]
+ ov_gpu: ClassVar[Backend]
+ ov_npu: ClassVar[Backend]
+
+ def __post_init__(self) -> None:
+ normalized_runtime = self.runtime_key.strip().lower()
+ normalized_name = self.name.strip().lower()
+ object.__setattr__(self, "runtime_key", normalized_runtime)
+ object.__setattr__(self, "name", normalized_name)
+ namespace = self._REGISTRY.setdefault(normalized_runtime, {})
+ if normalized_name in namespace:
+ raise ValueError(
+ f"Backend '{self.name}' already registered for runtime '{self.runtime_key}'."
+ )
+ namespace[normalized_name] = self
+
+ @property
+ def runtime(self) -> str:
+ return self.runtime_key
+
+ @classmethod
+ def available(cls, runtime: str | Runtime) -> tuple[Backend, ...]:
+ runtime_key = _normalize_key(runtime)
+ namespace = cls._REGISTRY.get(runtime_key, {})
+ return tuple(namespace.values())
+
+ @classmethod
+ def from_any(
+ cls,
+ value: Any,
+ *,
+ runtime: str | Runtime | None = None,
+ ) -> Backend:
+ runtime_key = cls._resolve_runtime_key(runtime)
+ namespace = cls._REGISTRY.get(runtime_key, {})
+ if isinstance(value, Backend):
+ if value.runtime_key != runtime_key:
+ raise ValueError(
+ f"Backend '{value.name}' is not registered for runtime '{runtime_key}'."
+ )
+ return value
+ normalized = str(value).strip().lower()
+ try:
+ return namespace[normalized]
+ except KeyError: # pragma: no cover - defensive guard
+ options = ", ".join(namespace.keys()) or ""
+ raise ValueError(
+ f"Unsupported backend '{value}' for runtime '{runtime_key}'. "
+ f"Pick from [{options}]"
+ ) from None
+
+ @classmethod
+ def _resolve_runtime_key(cls, runtime: str | Runtime | None) -> str:
+ if runtime is None:
+ if len(cls._REGISTRY) == 1:
+ return next(iter(cls._REGISTRY))
+ raise ValueError(
+ "runtime must be specified when resolving backends."
+ )
+ if isinstance(runtime, Runtime):
+ return runtime.name
+ return str(runtime).strip().lower()
+
+
+@dataclass(frozen=True)
+class Runtime:
+ name: str
+ backend_names: tuple[str, ...]
+ default_backend_name: str
+ description: str | None = None
+
+ _REGISTRY: ClassVar[dict[str, Runtime]] = {}
+ onnx: ClassVar[Runtime]
+ openvino: ClassVar[Runtime]
+ pt: ClassVar[Runtime]
+
+ def __post_init__(self) -> None:
+ normalized = self.name.strip().lower()
+ object.__setattr__(self, "name", normalized)
+ normalized_backends = tuple(
+ name.strip().lower() for name in self.backend_names
+ )
+ object.__setattr__(self, "backend_names", normalized_backends)
+ default_backend = self.default_backend_name.strip().lower()
+ object.__setattr__(self, "default_backend_name", default_backend)
+
+ if normalized in self._REGISTRY:
+ raise ValueError(f"Runtime '{self.name}' is already registered.")
+ self._REGISTRY[normalized] = self
+
+ available = {backend.name for backend in Backend.available(normalized)}
+ requested = set(normalized_backends)
+ missing = requested - available
+ if missing:
+ raise ValueError(
+ f"Runtime '{self.name}' references unknown backend(s): {sorted(missing)}."
+ )
+ if default_backend not in requested:
+ raise ValueError(
+ f"Default backend '{self.default_backend_name}' is not tracked "
+ f"for runtime '{self.name}'."
+ )
+
+ @classmethod
+ def from_any(cls, value: Any) -> Runtime:
+ if isinstance(value, Runtime):
+ return value
+ normalized = str(value).strip().lower()
+ try:
+ return cls._REGISTRY[normalized]
+ except KeyError: # pragma: no cover - defensive
+ options = ", ".join(cls._REGISTRY) or ""
+ raise ValueError(
+ f"Unsupported runtime '{value}'. Pick from [{options}]"
+ ) from None
+
+ def available_backends(self) -> tuple[Backend, ...]:
+ namespace = {
+ backend.name: backend for backend in Backend.available(self)
+ }
+ order = []
+ for name in self.backend_names:
+ order.append(namespace[name])
+ return tuple(order)
+
+ def normalize_backend(self, backend: Backend | str | None) -> Backend:
+ if backend is None:
+ return Backend.from_any(self.default_backend_name, runtime=self)
+ resolved = Backend.from_any(backend, runtime=self)
+ return resolved
+
+ def auto_backend_name(self) -> str:
+ if self.name == "onnx":
+ providers = _get_available_onnx_providers()
+ for backend_name, provider_name in _ONNX_AUTO_PRIORITY:
+ if provider_name in providers:
+ return backend_name
+ return self.default_backend_name
+ if self.name == "pt":
+ has_torch, has_cuda = _get_torch_capabilities()
+ if has_torch and has_cuda:
+ return "cuda"
+ return self.default_backend_name
+ if self.name == "openvino":
+ devices = _get_openvino_devices()
+ for backend_name, device_prefix in _OPENVINO_AUTO_PRIORITY:
+ if any(str(dev).startswith(device_prefix) for dev in devices):
+ return backend_name
+ return self.default_backend_name
+ return self.default_backend_name
+
+ def __str__(self) -> str: # pragma: no cover
+ return self.name
+
+
+_ONNX_BACKENDS: tuple[Backend, ...] = (
+ Backend(
+ name="cpu",
+ runtime_key="onnx",
+ providers=(ProviderSpec("CPUExecutionProvider"),),
+ description="Pure CPU execution provider.",
+ ),
+ Backend(
+ name="cuda",
+ runtime_key="onnx",
+ providers=(
+ ProviderSpec("CUDAExecutionProvider", include_device=True),
+ ProviderSpec("CPUExecutionProvider"),
+ ),
+ description="CUDA with CPU fallback.",
+ ),
+ Backend(
+ name="tensorrt",
+ runtime_key="onnx",
+ providers=(
+ ProviderSpec("TensorrtExecutionProvider", include_device=True),
+ ProviderSpec("CUDAExecutionProvider", include_device=True),
+ ProviderSpec("CPUExecutionProvider"),
+ ),
+ description="TensorRT backed by CUDA and CPU providers.",
+ ),
+ Backend(
+ name="tensorrt_rtx",
+ runtime_key="onnx",
+ providers=(
+ ProviderSpec("NvTensorRTRTXExecutionProvider", include_device=True),
+ ProviderSpec("CUDAExecutionProvider", include_device=True),
+ ProviderSpec("CPUExecutionProvider"),
+ ),
+ description="TensorRT RTX provider chain.",
+ ),
+)
+
+_OPENVINO_BACKENDS: tuple[Backend, ...] = (
+ Backend(
+ name="cpu",
+ runtime_key="openvino",
+ device="CPU",
+ description="Intel CPU device for OpenVINO.",
+ ),
+ Backend(
+ name="gpu",
+ runtime_key="openvino",
+ device="GPU",
+ description="Intel GPU device for OpenVINO.",
+ ),
+ Backend(
+ name="npu",
+ runtime_key="openvino",
+ device="NPU",
+ description="Intel NPU device for OpenVINO.",
+ ),
+)
+
+_PT_BACKENDS: tuple[Backend, ...] = (
+ Backend(
+ name="cpu",
+ runtime_key="pt",
+ device="cpu",
+ description="PyTorch CPU execution.",
+ ),
+ Backend(
+ name="cuda",
+ runtime_key="pt",
+ device="cuda",
+ description="PyTorch CUDA execution.",
+ ),
+)
+
+Runtime.onnx = Runtime(
+ name="onnx",
+ backend_names=tuple(backend.name for backend in _ONNX_BACKENDS),
+ default_backend_name="cpu",
+ description="ONNXRuntime execution backend.",
+)
+
+Runtime.openvino = Runtime(
+ name="openvino",
+ backend_names=tuple(backend.name for backend in _OPENVINO_BACKENDS),
+ default_backend_name="cpu",
+ description="Intel OpenVINO runtime.",
+)
+
+Runtime.pt = Runtime(
+ name="pt",
+ backend_names=tuple(backend.name for backend in _PT_BACKENDS),
+ default_backend_name="cpu",
+ description="TorchScript/PT runtime.",
+)
+
+# Convenience handles for legacy-style access.
+Backend.cpu = Backend.from_any("cpu", runtime="onnx")
+Backend.cuda = Backend.from_any("cuda", runtime="onnx")
+Backend.tensorrt = Backend.from_any("tensorrt", runtime="onnx")
+Backend.tensorrt_rtx = Backend.from_any("tensorrt_rtx", runtime="onnx")
+Backend.ov_cpu = Backend.from_any("cpu", runtime="openvino")
+Backend.ov_gpu = Backend.from_any("gpu", runtime="openvino")
+Backend.ov_npu = Backend.from_any("npu", runtime="openvino")
+
+
+_ONNX_AUTO_PRIORITY: tuple[tuple[str, str], ...] = (
+ # Prefer pure CUDA when available to avoid TensorRT dependency issues.
+ ("cuda", "CUDAExecutionProvider"),
+ ("tensorrt_rtx", "NvTensorRTRTXExecutionProvider"),
+ ("tensorrt", "TensorrtExecutionProvider"),
+)
+
+_OPENVINO_AUTO_PRIORITY: tuple[tuple[str, str], ...] = (
+ # Prefer GPU when available, then NPU; fall back to CPU/default otherwise.
+ ("gpu", "GPU"),
+ ("npu", "NPU"),
+)
+
+
+def _get_available_onnx_providers() -> set[str]:
+ try: # pragma: no cover - optional dependency
+ import onnxruntime as ort # type: ignore
+ except Exception:
+ return set()
+ try:
+ return set(ort.get_available_providers())
+ except Exception: # pragma: no cover - runtime query failure
+ return set()
+
+
+def _get_torch_capabilities() -> tuple[bool, bool]:
+ try: # pragma: no cover - optional dependency
+ import torch # type: ignore
+ except Exception:
+ return False, False
+ try:
+ return True, bool(torch.cuda.is_available())
+ except Exception: # pragma: no cover - runtime query failure
+ return True, False
+
+
+def _get_openvino_devices() -> set[str]:
+ try: # pragma: no cover - optional dependency
+ ov = cast(Any, import_module("openvino.runtime"))
+ except Exception:
+ return set()
+ try:
+ core = ov.Core()
+ return {str(dev) for dev in getattr(core, "available_devices", [])}
+ except Exception: # pragma: no cover - runtime query failure
+ return set()
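The registry lends itself to lookups like the following; the printed provider lists match the definitions above, while auto_backend_name() depends on what is installed locally.

    from capybara.runtime import Backend, Runtime

    rt = Runtime.from_any("onnx")
    print([b.name for b in rt.available_backends()])  # ['cpu', 'cuda', 'tensorrt', 'tensorrt_rtx']
    print(rt.normalize_backend(None).name)            # 'cpu' (the default backend)
    print(rt.auto_backend_name())                     # depends on available providers

    # Backends can also be resolved directly, scoped to a runtime.
    cuda = Backend.from_any("cuda", runtime="onnx")
    print([spec.name for spec in cuda.providers])     # ['CUDAExecutionProvider', 'CPUExecutionProvider']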
diff --git a/capybara/structures/__init__.py b/capybara/structures/__init__.py
index 513e8a4..499894f 100644
--- a/capybara/structures/__init__.py
+++ b/capybara/structures/__init__.py
@@ -1,4 +1,35 @@
-from .boxes import *
-from .functionals import *
-from .keypoints import *
-from .polygons import *
+from __future__ import annotations
+
+from .boxes import Box, Boxes, BoxMode
+from .functionals import (
+ calc_angle,
+ is_inside_box,
+ jaccard_index,
+ pairwise_intersection,
+ pairwise_ioa,
+ pairwise_iou,
+ poly_angle,
+ polygon_iou,
+)
+from .keypoints import Keypoints, KeypointsList
+from .polygons import JOIN_STYLE, Polygon, Polygons, order_points_clockwise
+
+__all__ = [
+ "JOIN_STYLE",
+ "Box",
+ "BoxMode",
+ "Boxes",
+ "Keypoints",
+ "KeypointsList",
+ "Polygon",
+ "Polygons",
+ "calc_angle",
+ "is_inside_box",
+ "jaccard_index",
+ "order_points_clockwise",
+ "pairwise_intersection",
+ "pairwise_ioa",
+ "pairwise_iou",
+ "poly_angle",
+ "polygon_iou",
+]
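With the star imports replaced by an explicit __all__, downstream code imports the same names directly; a tiny sketch (coordinates are arbitrary):

    from capybara.structures import Box, BoxMode

    box = Box([10, 20, 50, 80], box_mode=BoxMode.XYXY)
    print(box.convert("CXCYWH"))  # Box([30. 50. 40. 60.]), BoxMode.CXCYWH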
diff --git a/capybara/structures/boxes.py b/capybara/structures/boxes.py
index 2735ade..30bb57e 100644
--- a/capybara/structures/boxes.py
+++ b/capybara/structures/boxes.py
@@ -1,25 +1,17 @@
+from collections.abc import Sequence
from enum import Enum, unique
-from typing import Any, List, Tuple, Union
+from typing import Any, Union, overload
from warnings import warn
import numpy as np
from ..typing import _Number
-__all__ = ['BoxMode', 'Box', 'Boxes']
+__all__ = ["Box", "BoxMode", "Boxes"]
_BoxMode = Union["BoxMode", int, str]
-_Box = Union[
- np.ndarray,
- Tuple[_Number, _Number, _Number, _Number],
- "Box"
-]
-
-_Boxes = Union[
- np.ndarray,
- List[_Box],
- "Boxes"
-]
+_Box = Union[np.ndarray, Sequence[_Number], "Box"]
+_Boxes = Union[np.ndarray, Sequence[_Box], "Boxes"]
@unique
@@ -49,7 +41,9 @@ class BoxMode(Enum):
"""
@staticmethod
- def convert(box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode) -> np.ndarray:
+ def convert(
+ box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode
+ ) -> np.ndarray:
"""
Convert function for box format converting
@@ -82,8 +76,9 @@ def convert(box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode) -> np.ndarr
elif from_mode == BoxMode.CXCYWH and to_mode == BoxMode.XYWH:
arr[..., :2] -= arr[..., 2:] / 2
else:
- raise NotImplementedError(
- f"Conversion from BoxMode {str(from_mode)} to {str(to_mode)} is not supported yet")
+ raise NotImplementedError( # pragma: no cover
+ f"Conversion from BoxMode {from_mode!s} to {to_mode!s} is not supported yet"
+ )
return arr
@staticmethod
@@ -95,16 +90,15 @@ def align_code(box_mode: _BoxMode):
elif isinstance(box_mode, BoxMode):
return box_mode
else:
- raise TypeError(f'Given `box_mode` is not int, str, or BoxMode.')
+ raise TypeError("Given `box_mode` is not int, str, or BoxMode.")
class Box:
-
def __init__(
self,
array: _Box,
box_mode: _BoxMode = BoxMode.XYXY,
- is_normalized: bool = False
+ is_normalized: bool = False,
):
"""
Args:
@@ -121,12 +115,20 @@ def __init__(
self._xywh = BoxMode.convert(self._array, self.box_mode, BoxMode.XYWH)
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._array)}), {str(BoxMode(self.box_mode))}"
+ return f"{self.__class__.__name__}({self._array!s}), {BoxMode(self.box_mode)!s}"
def __len__(self):
return self._array.shape[0]
- def __getitem__(self, item) -> float:
+ @overload
+ def __getitem__(self, item: int) -> float: ...
+
+ @overload
+ def __getitem__(self, item: slice) -> np.ndarray: ...
+
+ def __getitem__(self, item: int | slice) -> float | np.ndarray:
+ if isinstance(item, int):
+ return float(self._array[item])
return self._array[item]
def __eq__(self, value: object) -> bool:
@@ -135,20 +137,20 @@ def __eq__(self, value: object) -> bool:
return np.allclose(self._array, value._array)
def _check_valid_array(self, array: Any) -> np.ndarray:
- cond1 = isinstance(array, tuple) and len(array) == 4
- cond2 = isinstance(array, list) and len(array) == 4
- cond3 = isinstance(
- array, np.ndarray) and array.ndim == 1 and len(array) == 4
- cond4 = isinstance(array, self.__class__)
- if not (cond1 or cond2 or cond3 or cond4):
- raise TypeError(f'Input array must be {_Box}, but got {type(array)}.')
- if cond3:
- array = array.astype('float32')
- elif cond4:
+ if isinstance(array, Box):
array = array.numpy()
- else:
- array = np.array(array, dtype='float32')
- return array
+
+ if isinstance(array, np.ndarray):
+ if array.ndim != 1 or len(array) != 4:
+ raise TypeError(
+ f"Input array must be {_Box}, but got shape {array.shape}."
+ )
+ return array.astype("float32")
+
+ if isinstance(array, (tuple, list)) and len(array) == 4:
+ return np.array(array, dtype="float32")
+
+ raise TypeError(f"Input array must be {_Box}, but got {type(array)}.")
def convert(self, to_mode: _BoxMode) -> "Box":
"""
@@ -161,21 +163,33 @@ def convert(self, to_mode: _BoxMode) -> "Box":
Converted Box object.
"""
transed = BoxMode.convert(self._array, self.box_mode, to_mode)
- return self.__class__(transed, to_mode)
+ return self.__class__(
+ transed,
+ to_mode,
+ is_normalized=self.is_normalized,
+ )
def copy(self) -> Any:
- """ Create a copy of the Box object. """
- return self.__class__(self._array, self.box_mode)
+ """Create a copy of the Box object."""
+ return self.__class__(
+ self._array,
+ self.box_mode,
+ is_normalized=self.is_normalized,
+ )
def numpy(self) -> np.ndarray:
- """ Convert the Box object to a numpy array. """
+ """Convert the Box object to a numpy array."""
return self._array.copy()
def square(self) -> "Box":
- """ Convert the box to a square box. """
- arr = self.convert('CXCYWH').numpy()
+ """Convert the box to a square box."""
+ arr = self.convert("CXCYWH").numpy()
arr[2:] = arr[2:].min()
- return self.__class__(arr, 'CXCYWH').convert(self.box_mode)
+ return self.__class__(
+ arr,
+ "CXCYWH",
+ is_normalized=self.is_normalized,
+ ).convert(self.box_mode)
def normalize(self, w: int, h: int) -> "Box":
"""
@@ -189,7 +203,7 @@ def normalize(self, w: int, h: int) -> "Box":
Normalized Box object.
"""
if self.is_normalized:
- warn(f'Normalized box is forced to do normalization.')
+ warn("Normalized box is forced to do normalization.", stacklevel=2)
arr = self._array.copy()
arr[::2] = arr[::2] / w
arr[1::2] = arr[1::2] / h
@@ -207,7 +221,10 @@ def denormalize(self, w: int, h: int) -> "Box":
Denormalized Box object.
"""
if not self.is_normalized:
- warn(f'Non-normalized box is forced to do denormalization.')
+ warn(
+ "Non-normalized box is forced to do denormalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[::2] = arr[::2] * w
arr[1::2] = arr[1::2] * h
@@ -233,7 +250,12 @@ def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box":
arr = BoxMode.convert(self._array, self.box_mode, BoxMode.XYXY)
arr[0::2] = np.clip(arr[0::2], max(xmin, 0), xmax)
arr[1::2] = np.clip(arr[1::2], max(ymin, 0), ymax)
- return self.__class__(arr, self.box_mode)
+ clipped = self.__class__(
+ arr,
+ BoxMode.XYXY,
+ is_normalized=self.is_normalized,
+ )
+ return clipped.convert(self.box_mode)
def shift(self, shift_x: float, shift_y: float) -> "Box":
"""
@@ -248,9 +270,18 @@ def shift(self, shift_x: float, shift_y: float) -> "Box":
"""
arr = self._xywh.copy()
arr[:2] += (shift_x, shift_y)
- return self.__class__(arr, "XYWH").convert(self.box_mode)
+ return self.__class__(
+ arr,
+ "XYWH",
+ is_normalized=self.is_normalized,
+ ).convert(self.box_mode)
- def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = None) -> "Box":
+ def scale(
+ self,
+ dsize: tuple[int, int] | None = None,
+ fx: float | None = None,
+ fy: float | None = None,
+ ) -> "Box":
"""
Method to scale Box with a given scale.
@@ -282,13 +313,13 @@ def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = Non
arr[3] += dy
else:
if fx is not None:
- fx = arr[2] * (fx - 1)
- arr[0] -= fx / 2
- arr[2] += fx
+ delta_x = arr[2] * (fx - 1)
+ arr[0] -= delta_x / 2
+ arr[2] += delta_x
if fy is not None:
- fy = arr[3] * (fy - 1)
- arr[1] -= fy / 2
- arr[3] += fy
+ delta_y = arr[3] * (fy - 1)
+ arr[1] -= delta_y / 2
+ arr[3] += delta_y
return self.__class__(arr, "XYWH").convert(self.box_mode)
@@ -296,16 +327,17 @@ def to_list(self) -> list:
return self._array.tolist()
def tolist(self) -> list:
- """ Alias of `to_list` (numpy style) """
+ """Alias of `to_list` (numpy style)"""
return self.to_list()
def to_polygon(self):
from .polygons import Polygon
+
arr = self._xywh.copy()
if (arr[2:] <= 0).any():
raise ValueError(
- 'Some element in Box has invaild value, which width or '
- 'height is smaller than zero or other unexpected reasons.'
+                "Some element in Box has an invalid value: its width or "
+                "height is smaller than zero, or another unexpected issue occurred."
)
p1 = arr[:2]
p2 = np.stack([arr[0::2].sum(), arr[1]])
@@ -313,59 +345,60 @@ def to_polygon(self):
p4 = np.stack([arr[0], arr[1::2].sum()])
return Polygon(np.stack([p1, p2, p3, p4]), self.is_normalized)
- @ property
+ @property
def width(self) -> np.ndarray:
- """ Get width of the box. """
+ """Get width of the box."""
return self._xywh[2]
- @ property
+ @property
def height(self) -> np.ndarray:
- """ Get height of the box. """
+ """Get height of the box."""
return self._xywh[3]
- @ property
+ @property
def left_top(self) -> np.ndarray:
- """ Get the left-top point of the box. """
+ """Get the left-top point of the box."""
return self._xywh[0:2]
- @ property
+ @property
def right_bottom(self) -> np.ndarray:
- """ Get the right-bottom point of the box. """
+ """Get the right-bottom point of the box."""
return self._xywh[0:2] + self._xywh[2:4]
- @ property
+ @property
def left_bottom(self) -> np.ndarray:
- """ Get the left_bottom point of the box. """
- return self._xywh[0:2] + [0, self._xywh[3]]
+ """Get the left_bottom point of the box."""
+ xywh = np.asarray(self._xywh)
+ return xywh[0:2] + np.array([0, xywh[3]], dtype=xywh.dtype)
- @ property
+ @property
def right_top(self) -> np.ndarray:
- """ Get the right_top point of the box. """
- return self._xywh[0:2] + [self._xywh[2], 0]
+ """Get the right_top point of the box."""
+ xywh = np.asarray(self._xywh)
+ return xywh[0:2] + np.array([xywh[2], 0], dtype=xywh.dtype)
- @ property
+ @property
def area(self) -> np.ndarray:
- """ Get the area of the boxes. """
+ """Get the area of the boxes."""
return self._xywh[2] * self._xywh[3]
- @ property
+ @property
def aspect_ratio(self) -> np.ndarray:
- """ Compute the aspect ratios (widths / heights) of the boxes. """
+ """Compute the aspect ratios (widths / heights) of the boxes."""
return self._xywh[2] / self._xywh[3]
- @ property
+ @property
def center(self) -> np.ndarray:
- """ Compute the center of the box. """
+ """Compute the center of the box."""
return self._xywh[:2] + self._xywh[2:] / 2
class Boxes:
-
def __init__(
self,
array: _Boxes,
box_mode: _BoxMode = BoxMode.XYXY,
- is_normalized: bool = False
+ is_normalized: bool = False,
):
self.box_mode = BoxMode.align_code(box_mode)
self.is_normalized = is_normalized
@@ -373,16 +406,33 @@ def __init__(
self._xywh = BoxMode.convert(self._array, box_mode, BoxMode.XYWH)
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._array)}), {str(BoxMode(self.box_mode))}"
+ return f"{self.__class__.__name__}({self._array!s}), {BoxMode(self.box_mode)!s}"
def __len__(self):
return self._array.shape[0]
+ @overload
+ def __getitem__(self, item: int) -> "Box": ...
+
+ @overload
+ def __getitem__(self, item: list[int] | slice | np.ndarray) -> "Boxes": ...
+
def __getitem__(self, item) -> Union["Box", "Boxes"]:
if isinstance(item, int):
- return Box(self._array[item], self.box_mode, is_normalized=self.is_normalized)
+ return Box(
+ self._array[item],
+ self.box_mode,
+ is_normalized=self.is_normalized,
+ )
if isinstance(item, (list, slice, np.ndarray)):
- return self.__class__(self._array[item], self.box_mode, is_normalized=self.is_normalized)
+ return self.__class__(
+ self._array[item],
+ self.box_mode,
+ is_normalized=self.is_normalized,
+ )
+ raise TypeError(
+ "Boxes indices must be int, slice, list[int], or numpy array."
+ )
def __iter__(self) -> Any:
for i in range(len(self)):
@@ -395,22 +445,34 @@ def __eq__(self, value: object) -> bool:
def _check_valid_array(self, array: Any) -> np.ndarray:
cond1 = isinstance(array, list)
- cond2 = isinstance(
- array, np.ndarray) and array.ndim == 2 and array.shape[-1] == 4
- cond3 = isinstance(
- array, np.ndarray) and array.ndim == 1 and len(array) == 0
+ cond2 = (
+ isinstance(array, np.ndarray)
+ and array.ndim == 2
+ and array.shape[-1] == 4
+ )
+ cond3 = (
+ isinstance(array, np.ndarray)
+ and array.ndim == 1
+ and len(array) == 0
+ )
cond4 = isinstance(array, self.__class__)
if not (cond1 or cond2 or cond3 or cond4):
- raise TypeError(f'Input array must be {_Boxes}.')
+ raise TypeError(f"Input array must be {_Boxes}.")
if cond1:
for i, x in enumerate(array):
try:
- array[i] = Box(x, box_mode=self.box_mode, is_normalized=self.is_normalized).numpy()
- except TypeError:
- raise TypeError(f'Input array[{i}] must be {_Box}.')
+ array[i] = Box(
+ x,
+ box_mode=self.box_mode,
+ is_normalized=self.is_normalized,
+ ).numpy()
+ except TypeError as exc:
+ raise TypeError(
+ f"Input array[{i}] must be {_Box}."
+ ) from exc
if cond4:
array = [box.convert(self.box_mode).numpy() for box in array]
- array = np.array(array, dtype='float32').copy()
+ array = np.array(array, dtype="float32").copy()
return array
def convert(self, to_mode: _BoxMode) -> "Boxes":
@@ -424,20 +486,34 @@ def convert(self, to_mode: _BoxMode) -> "Boxes":
Converted Box object.
"""
transed = BoxMode.convert(self._array, self.box_mode, to_mode)
- return self.__class__(transed, to_mode)
+ return self.__class__(
+ transed,
+ to_mode,
+ is_normalized=self.is_normalized,
+ )
def copy(self) -> Any:
- """ Create a copy of the Box object. """
- return self.__class__(self._array, self.box_mode)
+ """Create a copy of the Box object."""
+ return self.__class__(
+ self._array,
+ self.box_mode,
+ is_normalized=self.is_normalized,
+ )
def numpy(self) -> np.ndarray:
- """ Convert the Box object to a numpy array. """
+ """Convert the Box object to a numpy array."""
return self._array.copy()
def square(self) -> "Boxes":
- arr = self.convert('CXCYWH').numpy()
- arr[..., 2:] = arr[..., 2:].max(1)
- return self.__class__(arr, 'CXCYWH').convert(self.box_mode)
+ arr = self.convert("CXCYWH").numpy()
+ # Use per-box maximum side length, keeping each box centered.
+ side = arr[..., 2:].max(axis=1, keepdims=True)
+ arr[..., 2:] = side
+ return self.__class__(
+ arr,
+ "CXCYWH",
+ is_normalized=self.is_normalized,
+ ).convert(self.box_mode)
def normalize(self, w: int, h: int) -> "Boxes":
"""
@@ -451,7 +527,7 @@ def normalize(self, w: int, h: int) -> "Boxes":
Normalized Box object.
"""
if self.is_normalized:
- warn(f'Normalized box is forced to do normalization.')
+ warn("Normalized box is forced to do normalization.", stacklevel=2)
arr = self._array.copy()
arr[:, ::2] = arr[:, ::2] / w
arr[:, 1::2] = arr[:, 1::2] / h
@@ -469,13 +545,16 @@ def denormalize(self, w: int, h: int) -> "Boxes":
Denormalized Boxes object.
"""
if not self.is_normalized:
- warn(f'Non-normalized box is forced to do denormalization.')
+ warn(
+ "Non-normalized box is forced to do denormalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[:, ::2] = arr[:, ::2] * w
arr[:, 1::2] = arr[:, 1::2] * h
return self.__class__(arr, self.box_mode, is_normalized=False)
- def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box":
+ def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Boxes":
"""
Method to clip the box by limiting x coordinates to the range [xmin, xmax]
and y coordinates to the range [ymin, ymax].
@@ -495,7 +574,12 @@ def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box":
arr = BoxMode.convert(self._array, self.box_mode, BoxMode.XYXY)
arr[:, 0::2] = np.clip(arr[:, 0::2], max(xmin, 0), xmax)
arr[:, 1::2] = np.clip(arr[:, 1::2], max(ymin, 0), ymax)
- return self.__class__(arr, self.box_mode)
+ clipped = self.__class__(
+ arr,
+ BoxMode.XYXY,
+ is_normalized=self.is_normalized,
+ )
+ return clipped.convert(self.box_mode)
def shift(self, shift_x: float, shift_y: float) -> "Boxes":
"""
@@ -510,9 +594,18 @@ def shift(self, shift_x: float, shift_y: float) -> "Boxes":
"""
arr = self._xywh.copy()
arr[:, :2] += (shift_x, shift_y)
- return self.__class__(arr, "XYWH").convert(self.box_mode)
+ return self.__class__(
+ arr,
+ "XYWH",
+ is_normalized=self.is_normalized,
+ ).convert(self.box_mode)
- def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = None) -> "Boxes":
+ def scale(
+ self,
+ dsize: tuple[int, int] | None = None,
+ fx: float | None = None,
+ fy: float | None = None,
+ ) -> "Boxes":
"""
Method to scale Box with a given scale.
@@ -544,38 +637,42 @@ def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = Non
arr[:, 3] += dy
else:
if fx is not None:
- fx = arr[:, 2] * (fx - 1)
- arr[:, 0] -= fx / 2
- arr[:, 2] += fx
+ delta_x = arr[:, 2] * (fx - 1)
+ arr[:, 0] -= delta_x / 2
+ arr[:, 2] += delta_x
if fy is not None:
- fy = arr[3] * (fy - 1)
- arr[:, 1] -= fy / 2
- arr[:, 3] += fy
+ delta_y = arr[:, 3] * (fy - 1)
+ arr[:, 1] -= delta_y / 2
+ arr[:, 3] += delta_y
return self.__class__(arr, "XYWH").convert(self.box_mode)
def get_empty_index(self) -> np.ndarray:
- """ Get the index of empty boxes. """
+ """Get the index of empty boxes."""
return np.where((self._xywh[:, 2] <= 0) | (self._xywh[:, 3] <= 0))[0]
def drop_empty(self) -> "Boxes":
- """ Drop the empty boxes. """
- return self.__class__(self._array[(self._xywh[:, 2] > 0) & (self._xywh[:, 3] > 0)], self.box_mode)
+ """Drop the empty boxes."""
+ return self.__class__(
+ self._array[(self._xywh[:, 2] > 0) & (self._xywh[:, 3] > 0)],
+ self.box_mode,
+ )
def to_list(self) -> list:
return self._array.tolist()
def tolist(self) -> list:
- """ Alias of `to_list` (numpy style) """
+ """Alias of `to_list` (numpy style)"""
return self.to_list()
def to_polygons(self):
from .polygons import Polygons
+
arr = self._xywh.copy()
if (arr[:, 2:] <= 0).any():
raise ValueError(
- 'Some element in Boxes has invaild value, which width or '
- 'height is smaller than zero or other unexpected reasons.'
+                "Some element in Boxes has an invalid value: its width or "
+                "height is smaller than zero, or another unexpected issue occurred."
)
p1 = arr[:, :2]
@@ -584,37 +681,37 @@ def to_polygons(self):
p4 = np.stack([arr[:, 0], arr[:, 1::2].sum(1)], axis=1)
return Polygons(np.stack([p1, p2, p3, p4], axis=1), self.is_normalized)
- @ property
+ @property
def width(self) -> np.ndarray:
- """ Get width of the box. """
+ """Get width of the box."""
return self._xywh[:, 2]
- @ property
+ @property
def height(self) -> np.ndarray:
- """ Get height of the box. """
+ """Get height of the box."""
return self._xywh[:, 3]
- @ property
+ @property
def left_top(self) -> np.ndarray:
- """ Get the left-top point of the box. """
+ """Get the left-top point of the box."""
return self._xywh[:, :2]
- @ property
+ @property
def right_bottom(self) -> np.ndarray:
- """ Get the right-bottom point of the box. """
+ """Get the right-bottom point of the box."""
return self._xywh[:, :2] + self._xywh[:, 2:4]
- @ property
+ @property
def area(self) -> np.ndarray:
- """ Get the area of the boxes. """
+ """Get the area of the boxes."""
return self._xywh[:, 2] * self._xywh[:, 3]
- @ property
+ @property
def aspect_ratio(self) -> np.ndarray:
- """ Compute the aspect ratios (widths / heights) of the boxes. """
+ """Compute the aspect ratios (widths / heights) of the boxes."""
return self._xywh[:, 2] / self._xywh[:, 3]
- @ property
+ @property
def center(self) -> np.ndarray:
- """ Compute the center of the box. """
+ """Compute the center of the box."""
return self._xywh[:, :2] + self._xywh[:, 2:] / 2
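
The main behavioral change in this file is that convert(), copy(), square(), clip(), and shift() now carry the is_normalized flag through instead of dropping it. A minimal sketch of the intended effect, with toy values (illustrative only, not part of the diff):

    import numpy as np
    from capybara.structures import Boxes

    boxes = Boxes(np.array([[0.1, 0.2, 0.5, 0.9]]), "XYXY", is_normalized=True)
    clipped = boxes.clip(0, 0, 1, 1)
    assert clipped.is_normalized               # previously the flag was reset to False
    assert clipped.box_mode == boxes.box_mode  # mode is restored after the internal XYXY round-trip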
diff --git a/capybara/structures/functionals.py b/capybara/structures/functionals.py
index acc880b..091495c 100644
--- a/capybara/structures/functionals.py
+++ b/capybara/structures/functionals.py
@@ -1,5 +1,3 @@
-from typing import Optional, Tuple, Union
-
import cv2
import numpy as np
from shapely.geometry import Polygon as ShapelyPolygon
@@ -9,9 +7,14 @@
from .polygons import Polygon
__all__ = [
- 'pairwise_intersection', 'pairwise_iou', 'pairwise_ioa',
- 'jaccard_index', 'polygon_iou', 'is_inside_box', 'calc_angle',
- 'poly_angle',
+ "calc_angle",
+ "is_inside_box",
+ "jaccard_index",
+ "pairwise_intersection",
+ "pairwise_ioa",
+ "pairwise_iou",
+ "poly_angle",
+ "polygon_iou",
]
@@ -27,10 +30,10 @@ def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> np.ndarray:
ndarray: intersection, sized [N, M].
"""
if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes):
- raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.')
+ raise TypeError("Input type of boxes1 and boxes2 must be Boxes.")
- boxes1_ = boxes1.convert('XYXY').numpy()
- boxes2_ = boxes2.convert('XYXY').numpy()
+ boxes1_ = boxes1.convert("XYXY").numpy()
+ boxes2_ = boxes2.convert("XYXY").numpy()
lt = np.maximum(boxes1_[:, None, :2], boxes2_[:, :2])
rb = np.minimum(boxes1_[:, None, 2:], boxes2_[:, 2:])
width_height = (rb - lt).clip(min=0)
@@ -51,12 +54,12 @@ def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> np.ndarray:
ndarray: IoU, sized [N,M].
"""
if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes):
- raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.')
+ raise TypeError("Input type of boxes1 and boxes2 must be Boxes.")
if np.any(boxes1._xywh[:, 2:] <= 0) or np.any(boxes2._xywh[:, 2:] <= 0):
raise ValueError(
- 'Some boxes in Boxes has invaild value, which width or '
- 'height is smaller than zero or other unexpected reasons, '
+            "Some boxes in Boxes have an invalid value: width or "
+            "height is smaller than zero, or another unexpected issue occurred; "
'try to run "drop_empty()" at first.'
)
@@ -77,12 +80,12 @@ def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> np.ndarray:
ndarray: IoA, sized [N,M].
"""
if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes):
- raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.')
+ raise TypeError("Input type of boxes1 and boxes2 must be Boxes.")
if np.any(boxes1._xywh[:, 2:] <= 0) or np.any(boxes2._xywh[:, 2:] <= 0):
raise ValueError(
- 'Some boxes in Boxes has invaild value, which width or '
- 'height is smaller than zero or other unexpected reasons, '
+            "Some boxes in Boxes have an invalid value: width or "
+            "height is smaller than zero, or another unexpected issue occurred; "
'try to run "drop_empty()" at first.'
)
@@ -95,7 +98,7 @@ def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> np.ndarray:
def jaccard_index(
pred_poly: np.ndarray,
gt_poly: np.ndarray,
- image_size: Tuple[int, int],
+ image_size: tuple[int, int],
) -> float:
"""
Reference : https://github.com/jchazalon/smartdoc15-ch1-eval
@@ -114,29 +117,26 @@ def jaccard_index(
"""
if pred_poly.shape != (4, 2) or gt_poly.shape != (4, 2):
- raise ValueError(f'Input polygon must be 4-point polygon.')
+ raise ValueError("Input polygon must be 4-point polygon.")
if image_size is None:
- raise ValueError(f'Input image size must be provided.')
+ raise ValueError("Input image size must be provided.")
pred_poly = pred_poly.astype(np.float32)
gt_poly = gt_poly.astype(np.float32)
height, width = image_size
- object_coord_target = np.array([
- [0, 0],
- [width, 0],
- [width, height],
- [0, height]]
+ object_coord_target = np.array(
+ [[0, 0], [width, 0], [width, height], [0, height]]
).astype(np.float32)
- M = cv2.getPerspectiveTransform(
+ matrix = cv2.getPerspectiveTransform(
gt_poly.reshape(-1, 1, 2),
object_coord_target[None, ...],
)
transformed_pred_coords = cv2.perspectiveTransform(
- pred_poly.reshape(-1, 1, 2), M
+ pred_poly.reshape(-1, 1, 2), matrix
)
try:
@@ -157,10 +157,10 @@ def jaccard_index(
area_inter = area_min
jaccard_index = area_inter / area_union
-    except:
-        # 通常錯誤來自於:
+    except Exception:
+        # The error usually comes from:
         # TopologyException: Input geom 1 is invalid: Ring Self-intersection
-        # 表示多邊形自己交叉了,這時候就直接給 0
+        # which means the polygon self-intersects; in that case just return 0
jaccard_index = 0
return jaccard_index
@@ -177,18 +177,18 @@ def polygon_iou(poly1: Polygon, poly2: Polygon):
float: IoU.
"""
if not isinstance(poly1, Polygon) or not isinstance(poly2, Polygon):
- raise TypeError(f'Input type of poly1 and poly2 must be Polygon.')
+ raise TypeError("Input type of poly1 and poly2 must be Polygon.")
- poly1 = poly1.numpy().astype(np.float32)
- poly2 = poly2.numpy().astype(np.float32)
+ poly1_arr = poly1.numpy().astype(np.float32)
+ poly2_arr = poly2.numpy().astype(np.float32)
try:
- poly1 = ShapelyPolygon(poly1)
- poly2 = ShapelyPolygon(poly2)
- poly_inter = poly1 & poly2
+ poly1_shape = ShapelyPolygon(poly1_arr)
+ poly2_shape = ShapelyPolygon(poly2_arr)
+ poly_inter = poly1_shape.intersection(poly2_shape)
- area_target = poly1.area
- area_test = poly2.area
+ area_target = poly1_shape.area
+ area_test = poly2_shape.area
area_inter = poly_inter.area
area_union = area_test + area_target - area_inter
@@ -200,16 +200,16 @@ def polygon_iou(poly1: Polygon, poly2: Polygon):
area_inter = area_min
iou = area_inter / area_union
-    except:
-        # 通常錯誤來自於:
+    except Exception:
+        # The error usually comes from:
         # TopologyException: Input geom 1 is invalid: Ring Self-intersection
-        # 表示多邊形自己交叉了,這時候就直接給 0
+        # which means the polygon self-intersects; in that case just return 0
iou = 0
return iou
-def is_inside_box(x: Union[Box, Keypoints, Polygon], box: Box) -> np.bool_:
+def is_inside_box(x: Box | Keypoints | Polygon, box: Box) -> np.bool_:
cond1 = x._array >= box.left_top
cond2 = x._array <= box.right_bottom
return np.concatenate((cond1, cond2), axis=-1).all()
@@ -238,8 +238,8 @@ def calc_angle(v1, v2):
def poly_angle(
poly1: Polygon,
- poly2: Optional[Polygon] = None,
- base_vector: Tuple[int, int] = (0, 1)
+ poly2: Polygon | None = None,
+ base_vector: tuple[int, int] = (0, 1),
) -> float:
"""
Calculate the angle between two polygons or a polygon and a base vector.
@@ -252,7 +252,10 @@ def _get_angle(poly):
return vector1 + vector2
v1 = _get_angle(poly1)
- v2 = _get_angle(poly2) if poly2 is not None else np.array(
- base_vector, dtype='float32')
+ v2 = (
+ _get_angle(poly2)
+ if poly2 is not None
+ else np.array(base_vector, dtype="float32")
+ )
return calc_angle(v1, v2)
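
A minimal usage sketch for the pairwise helpers, with toy boxes (illustrative only, not part of the diff):

    import numpy as np
    from capybara.structures import Boxes, pairwise_iou

    a = Boxes(np.array([[0, 0, 10, 10]]), "XYXY")
    b = Boxes(np.array([[5, 5, 15, 15], [20, 20, 30, 30]]), "XYXY")
    print(pairwise_iou(a, b))  # approx. [[0.1429, 0.0]]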
diff --git a/capybara/structures/keypoints.py b/capybara/structures/keypoints.py
index ae8e5d9..f621e45 100644
--- a/capybara/structures/keypoints.py
+++ b/capybara/structures/keypoints.py
@@ -1,7 +1,8 @@
-from typing import Any, List, Tuple, Union
+import colorsys
+from collections.abc import Sequence
+from typing import Any, TypeAlias, Union
from warnings import warn
-import matplotlib
import numpy as np
from ..typing import _Number
@@ -9,17 +10,40 @@
__all__ = ["Keypoints", "KeypointsList"]
-_Keypoints = Union[
+def _colormap_bytes(cmap: str, steps: np.ndarray) -> np.ndarray:
+ try: # pragma: no cover - optional dependency
+ import matplotlib # type: ignore
+
+ try:
+ color_map = matplotlib.colormaps[cmap]
+ except Exception:
+ color_map = matplotlib.colormaps["rainbow"]
+ return np.asarray(color_map(steps, bytes=True), dtype=np.uint8)
+ except Exception:
+ out = np.empty((len(steps), 4), dtype=np.uint8)
+ for i, t in enumerate(steps):
+ hue = float(t)
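+            # Avoid hue 1.0, which wraps around to the same color as hue 0.0.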
+ if len(steps) > 1 and hue >= 1.0:
+ hue = 1.0 - (1.0 / len(steps))
+ r, g, b = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
+ out[i, 0] = int(r * 255)
+ out[i, 1] = int(g * 255)
+ out[i, 2] = int(b * 255)
+ out[i, 3] = 255
+ return out
+
+
+_Keypoints: TypeAlias = Union[
np.ndarray,
- List[np.ndarray],
- List[Tuple[_Number, _Number]],
- List[Tuple[_Number, _Number, _Number]],
+ Sequence[np.ndarray],
+ Sequence[tuple[_Number, _Number]],
+ Sequence[tuple[_Number, _Number, _Number]],
"Keypoints",
]
-
-_KeypointsList = Union[
+_KeypointsList: TypeAlias = Union[
np.ndarray,
- List[_Keypoints],
+ Sequence[_Keypoints],
+ "KeypointsList",
]
@@ -32,18 +56,19 @@ class Keypoints:
* v=2: labeled and visible
"""
- def __init__(self, array: _Keypoints, cmap="rainbow", is_normalized: bool = False):
+ def __init__(
+ self, array: _Keypoints, cmap="rainbow", is_normalized: bool = False
+ ):
self._array = self._check_valid_array(array)
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
- color_map = matplotlib.colormaps[cmap]
- self._point_colors = np.array(color_map(steps, bytes=True))[..., :3].tolist()
+ self._point_colors = _colormap_bytes(str(cmap), steps)[..., :3].tolist()
self._is_normalized = is_normalized
def __len__(self) -> int:
return self._array.shape[0]
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._array)})"
+ return f"{self.__class__.__name__}({self._array!s})"
def __eq__(self, value: object) -> bool:
if not isinstance(value, self.__class__):
@@ -52,25 +77,34 @@ def __eq__(self, value: object) -> bool:
def _check_valid_array(self, array: Any) -> np.ndarray:
cond1 = isinstance(array, np.ndarray)
- cond2 = isinstance(array, list) and all(isinstance(x, (tuple, np.ndarray)) for x in array)
+ cond2 = isinstance(array, list) and all(
+ isinstance(x, (tuple, np.ndarray)) for x in array
+ )
cond3 = isinstance(array, self.__class__)
if not (cond1 or cond2 or cond3):
- raise TypeError(f"Input array is not {_Keypoints}, but got {type(array)}.")
+ raise TypeError(
+ f"Input array is not {_Keypoints}, but got {type(array)}."
+ )
- if cond3:
- array = array.numpy()
- else:
- array = np.array(array, dtype="float32")
+ array = array.numpy() if cond3 else np.array(array, dtype="float32")
if not array.ndim == 2:
- raise ValueError(f"Input array ndim = {array.ndim} is not 2, which is invalid.")
+ raise ValueError(
+ f"Input array ndim = {array.ndim} is not 2, which is invalid."
+ )
if array.shape[-1] not in [2, 3]:
- raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not in [2, 3], which is invalid.")
+ raise ValueError(
+ f"Input array's shape[-1] = {array.shape[-1]} is not in [2, 3], which is invalid."
+ )
- if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()):
- raise ValueError("Given array is invalid because of its labels. (array[..., 2])")
+ if array.shape[-1] == 3 and not (
+ (array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()
+ ):
+ raise ValueError(
+ "Given array is invalid because of its labels. (array[..., 2])"
+ )
return array.copy()
def numpy(self) -> np.ndarray:
@@ -91,7 +125,10 @@ def scale(self, fx: float, fy: float) -> "Keypoints":
def normalize(self, w: float, h: float) -> "Keypoints":
if self.is_normalized:
- warn("Normalized keypoints are forced to do normalization.")
+ warn(
+ "Normalized keypoints are forced to do normalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[..., :2] = arr[..., :2] / (w, h)
kpts = self.__class__(arr)
@@ -100,7 +137,10 @@ def normalize(self, w: float, h: float) -> "Keypoints":
def denormalize(self, w: float, h: float) -> "Keypoints":
if not self.is_normalized:
- warn("Non-normalized keypoints is forced to do denormalization.")
+ warn(
+                "Non-normalized keypoints are forced to do denormalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[..., :2] = arr[..., :2] * (w, h)
kpts = self.__class__(arr)
@@ -112,22 +152,29 @@ def is_normalized(self) -> bool:
return self._is_normalized
@property
- def point_colors(self) -> List[Tuple[int, int, int]]:
- return [tuple([int(x) for x in cs]) for cs in self._point_colors]
+ def point_colors(self) -> list[tuple[int, int, int]]:
+ return [
+ (int(cs[0]), int(cs[1]), int(cs[2])) for cs in self._point_colors
+ ]
- @point_colors.setter
- def set_point_colors(self, cmap: str):
+ def set_point_colors(self, cmap: str) -> None:
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
- self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+ self._point_colors = _colormap_bytes(str(cmap), steps)[..., :3].tolist()
+
+ @point_colors.setter
+ def point_colors(self, cmap: str) -> None:
+ self.set_point_colors(cmap)
class KeypointsList:
- def __init__(self, array: _KeypointsList, cmap="rainbow", is_normalized: bool = False) -> None:
+ def __init__(
+ self, array: _KeypointsList, cmap="rainbow", is_normalized: bool = False
+ ) -> None:
self._array = self._check_valid_array(array).copy()
self._is_normalized = is_normalized
if len(self._array):
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
- self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+ self._point_colors = _colormap_bytes(str(cmap), steps)
else:
self._point_colors = None
@@ -136,8 +183,12 @@ def __len__(self) -> int:
def __getitem__(self, item) -> Any:
if isinstance(item, int):
- return Keypoints(self._array[item], is_normalized=self.is_normalized)
- return KeypointsList(self._array[item], is_normalized=self.is_normalized)
+ return Keypoints(
+ self._array[item], is_normalized=self.is_normalized
+ )
+ return KeypointsList(
+ self._array[item], is_normalized=self.is_normalized
+ )
def __setitem__(self, item, value):
if not isinstance(value, (Keypoints, KeypointsList)):
@@ -151,7 +202,7 @@ def __iter__(self) -> Any:
yield self[i]
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._array)})"
+ return f"{self.__class__.__name__}({self._array!s})"
def __eq__(self, value: object) -> bool:
if not isinstance(value, self.__class__):
@@ -161,17 +212,16 @@ def __eq__(self, value: object) -> bool:
def _check_valid_array(self, array: Any) -> np.ndarray:
cond1 = isinstance(array, np.ndarray)
cond2 = isinstance(array, list) and len(array) == 0
- cond3 = (
- isinstance(array, list)
- and (
- all(isinstance(x, (np.ndarray, Keypoints)) for x in array)
- or all(isinstance(y, tuple) for x in array for y in x)
- )
+ cond3 = isinstance(array, list) and (
+ all(isinstance(x, (np.ndarray, Keypoints)) for x in array)
+ or all(isinstance(y, tuple) for x in array for y in x)
)
cond4 = isinstance(array, self.__class__)
if not (cond1 or cond2 or cond3 or cond4):
- raise TypeError(f"Input array is not {_KeypointsList}, but got {type(array)}.")
+ raise TypeError(
+ f"Input array is not {_KeypointsList}, but got {type(array)}."
+ )
if cond4:
array = array.numpy()
@@ -184,13 +234,21 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
return array
if array.ndim != 3:
- raise ValueError(f"Input array's ndim = {array.ndim} is not 3, which is invalid.")
+ raise ValueError(
+ f"Input array's ndim = {array.ndim} is not 3, which is invalid."
+ )
if array.shape[-1] not in [2, 3]:
- raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not 2 or 3, which is invalid.")
+ raise ValueError(
+ f"Input array's shape[-1] = {array.shape[-1]} is not 2 or 3, which is invalid."
+ )
- if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()):
- raise ValueError("Given array is invalid because of its labels. (array[..., 2])")
+ if array.shape[-1] == 3 and not (
+ (array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()
+ ):
+ raise ValueError(
+ "Given array is invalid because of its labels. (array[..., 2])"
+ )
return array
@@ -212,7 +270,10 @@ def scale(self, fx: float, fy: float) -> Any:
def normalize(self, w: float, h: float) -> "KeypointsList":
if self.is_normalized:
- warn("Normalized keypoints_list is forced to do normalization.")
+ warn(
+ "Normalized keypoints_list is forced to do normalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[..., :2] = arr[..., :2] / (w, h)
kpts_list = self.__class__(arr)
@@ -221,7 +282,10 @@ def normalize(self, w: float, h: float) -> "KeypointsList":
def denormalize(self, w: float, h: float) -> "KeypointsList":
if not self.is_normalized:
- warn("Non-normalized box is forced to do denormalization.")
+ warn(
+                "Non-normalized keypoints_list is forced to do denormalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr[..., :2] = arr[..., :2] * (w, h)
kpts_list = self.__class__(arr)
@@ -233,16 +297,24 @@ def is_normalized(self) -> bool:
return self._is_normalized
@property
- def point_colors(self):
- return [tuple(c) for c in self._point_colors[..., :3].tolist()]
+ def point_colors(self) -> list[tuple[int, int, int]]:
+ if self._point_colors is None:
+ return []
+ colors = self._point_colors[..., :3].tolist()
+ return [(int(c[0]), int(c[1]), int(c[2])) for c in colors]
+
+ def set_point_colors(self, cmap: str) -> None:
+ if self._point_colors is None:
+ return
+ steps = np.linspace(0.0, 1.0, self._array.shape[-2])
+ self._point_colors = _colormap_bytes(str(cmap), steps)
@point_colors.setter
- def set_point_colors(self, cmap: str):
- steps = np.linspace(0.0, 1.0, self._array.shape[-2])
- self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+ def point_colors(self, cmap: str) -> None:
+ self.set_point_colors(cmap)
@classmethod
- def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":
+ def cat(cls, keypoints_lists: list["KeypointsList"]) -> "KeypointsList":
"""
Concatenates a list of KeypointsList into a single KeypointsList
@@ -260,7 +332,17 @@ def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":
if len(keypoints_lists) == 0:
raise ValueError("Given keypoints_list is empty.")
- if not all(isinstance(keypoints_list, KeypointsList) for keypoints_list in keypoints_lists):
- raise TypeError("All type of elements in keypoints_lists must be KeypointsList.")
+ if not all(
+ isinstance(keypoints_list, KeypointsList)
+ for keypoints_list in keypoints_lists
+ ):
+ raise TypeError(
+                "All elements in keypoints_lists must be KeypointsList instances."
+ )
- return cls(np.concatenate([keypoints_list.numpy() for keypoints_list in keypoints_lists], axis=0))
+ return cls(
+ np.concatenate(
+ [keypoints_list.numpy() for keypoints_list in keypoints_lists],
+ axis=0,
+ )
+ )
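
A minimal sketch of the new colormap handling: point_colors is populated even when matplotlib is unavailable (via the colorsys fallback), and assigning to the property now delegates to set_point_colors(). Toy coordinates, illustrative only:

    import numpy as np
    from capybara.structures import Keypoints

    kpts = Keypoints(np.array([[10, 20], [30, 40], [50, 60]], dtype="float32"))
    print(kpts.point_colors)       # three (R, G, B) tuples
    kpts.point_colors = "viridis"  # equivalent to kpts.set_point_colors("viridis")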
diff --git a/capybara/structures/polygons.py b/capybara/structures/polygons.py
index 5105912..3365e2c 100644
--- a/capybara/structures/polygons.py
+++ b/capybara/structures/polygons.py
@@ -1,5 +1,6 @@
+from collections.abc import Sequence
from copy import deepcopy
-from typing import Any, List, Tuple, Union
+from typing import Any, Union, cast, overload
from warnings import warn
import cv2
@@ -10,26 +11,20 @@
from ..typing import _Number
__all__ = [
- 'Polygon', 'Polygons', 'order_points_clockwise', 'JOIN_STYLE',
+ "JOIN_STYLE",
+ "Polygon",
+ "Polygons",
+ "order_points_clockwise",
]
-_Polygon = Union[
- np.ndarray,
- List[Tuple[_Number, _Number]],
- "Polygon"
-]
-
-_Polygons = Union[
- np.ndarray,
- List["Polygon"],
- List[np.ndarray],
- List[List[Tuple[_Number, _Number]]],
- "Polygons"
-]
+_Polygon = Union[np.ndarray, Sequence[Sequence[_Number]], "Polygon"]
+_Polygons = Union[np.ndarray, Sequence[_Polygon], "Polygons"]
-def order_points_clockwise(pts: np.ndarray, inverse: bool = False) -> np.ndarray:
- """ Order the 4 points clockwise.
+def order_points_clockwise(
+ pts: np.ndarray, inverse: bool = False
+) -> np.ndarray:
+ """Order the 4 points clockwise.
Args:
pts (np.ndarray):
@@ -45,7 +40,7 @@ def order_points_clockwise(pts: np.ndarray, inverse: bool = False) -> np.ndarray
"""
if pts.shape != (4, 2):
- raise ValueError('Input array `pts` must be of shape (4, 2).')
+ raise ValueError("Input array `pts` must be of shape (4, 2).")
x_sorted = pts[np.argsort(pts[:, 0]), :]
left_most = x_sorted[:2, :]
@@ -69,11 +64,7 @@ class Polygon:
has shape (K, 2) where K is the number of points per instance.
"""
- def __init__(
- self,
- array: _Polygon,
- is_normalized: bool = False
- ):
+ def __init__(self, array: _Polygon, is_normalized: bool = False):
"""
Args:
array (Union[np.ndarray, list]): A Nx2 or Nx1x2 matrix.
@@ -82,12 +73,18 @@ def __init__(
self.is_normalized = is_normalized
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._array)})"
+ return f"{self.__class__.__name__}({self._array!s})"
def __len__(self) -> int:
return self._array.shape[0]
- def __getitem__(self, item) -> float:
+ @overload
+ def __getitem__(self, item: int) -> np.ndarray: ...
+
+ @overload
+ def __getitem__(self, item: slice) -> np.ndarray: ...
+
+ def __getitem__(self, item):
return self._array[item]
def __eq__(self, value: object) -> bool:
@@ -97,32 +94,27 @@ def __eq__(self, value: object) -> bool:
def _check_valid_array(self, array: _Polygon) -> np.ndarray:
if isinstance(array, (list, tuple)):
- array = np.array(array, dtype='float32')
+ array = np.array(array, dtype="float32")
if isinstance(array, Polygon):
array = array.numpy()
- cond1 = isinstance(
- array, np.ndarray) and array.ndim == 3 and array.shape[1] == 1
- cond2 = isinstance(
- array, np.ndarray) and array.ndim == 2 and array.shape[1] == 2
- cond3 = isinstance(
- array, np.ndarray) and array.ndim == 1 and len(array) == 0
- cond4 = isinstance(
- array, np.ndarray) and array.ndim == 1 and len(array) == 2
- cond5 = isinstance(array, self.__class__)
- if not (cond1 or cond2 or cond3 or cond4 or cond5):
- raise TypeError(f'Input array must be {_Polygon}.')
- if cond3 or cond4:
+ if not isinstance(array, np.ndarray):
+ raise TypeError(f"Input array must be {_Polygon}.")
+ if array.ndim == 1 and len(array) in {0, 2}:
array = array[None]
- if cond1:
+ if array.ndim == 3 and array.shape[1] == 1:
array = np.squeeze(array, axis=1)
- return array.astype('float32')
+ if array.ndim == 2 and array.shape[1] == 2:
+ return array.astype("float32")
+ if array.ndim == 2 and array.shape == (1, 0):
+ return array.astype("float32")
+ raise TypeError(f"Input array must be {_Polygon}.")
def copy(self) -> "Polygon":
- """ Create a copy of the Polygon object. """
+ """Create a copy of the Polygon object."""
return self.__class__(self._array)
def numpy(self) -> np.ndarray:
- """ Convert the Polygon object to a numpy array. """
+ """Convert the Polygon object to a numpy array."""
return self._array.copy()
def normalize(self, w: float, h: float) -> "Polygon":
@@ -137,7 +129,10 @@ def normalize(self, w: float, h: float) -> "Polygon":
Normalized Polygon object.
"""
if self.is_normalized:
- warn(f'Normalized polygon is forced to do normalization.')
+ warn(
+ "Normalized polygon is forced to do normalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr = arr / (w, h)
return self.__class__(arr, is_normalized=True)
@@ -154,7 +149,10 @@ def denormalize(self, w: float, h: float) -> "Polygon":
Denormalized Polygon object.
"""
if not self.is_normalized:
- warn(f'Non-normalized polygon is forced to do denormalization.')
+ warn(
+ "Non-normalized polygon is forced to do denormalization.",
+ stacklevel=2,
+ )
arr = self._array.copy()
arr = arr * (w, h)
return self.__class__(arr, is_normalized=False)
@@ -198,7 +196,7 @@ def shift(self, shift_x: float, shift_y: float) -> "Polygon":
def scale(
self,
distance: int,
- join_style: JOIN_STYLE = JOIN_STYLE.mitre
+ join_style: int = JOIN_STYLE.mitre,
) -> "Polygon":
"""
Returns an approximate representation of all points within a given distance
@@ -211,14 +209,15 @@ def scale(
These values are also enumerated by the object shapely.geometry.JOIN_STYLE
"""
poly = _Polygon_shapely(self._array).buffer(
- distance, join_style=join_style)
+ distance, join_style=cast(Any, join_style)
+ )
if isinstance(poly, MultiPolygon):
poly = max(poly.geoms, key=lambda p: p.area)
if isinstance(poly, _Polygon_shapely) and not poly.exterior.is_empty:
pts = np.zeros_like(self._array)
- for x, y in zip(*poly.exterior.xy):
+ for x, y in zip(*poly.exterior.xy, strict=True):
pt = np.array([x, y])
dist = np.linalg.norm(pt - self._array, axis=1)
pts[dist.argmin()] = pt
@@ -237,15 +236,18 @@ def to_convexhull(self) -> "Polygon":
return self.__class__(hull)
def to_min_boxpoints(self) -> "Polygon":
- """ Converts polygon to the min area bounding box. """
+ """Converts polygon to the min area bounding box."""
min_box = cv2.boxPoints(self.min_box).round(4)
min_box = order_points_clockwise(np.array(min_box))
return self.__class__(min_box)
- def to_box(self, box_mode: str = 'xyxy'):
- """ Converts polygon to the bounding box. """
+ def to_box(self, box_mode: str = "xyxy"):
+ """Converts polygon to the bounding box."""
from .boxes import Box
- return Box(self.boundingbox, "xywh", self.is_normalized).convert(box_mode)
+
+ return Box(self.boundingbox, "xywh", self.is_normalized).convert(
+ box_mode
+ )
def to_list(self, flatten: bool = False) -> list:
if flatten:
@@ -254,7 +256,7 @@ def to_list(self, flatten: bool = False) -> list:
return self._array.tolist()
def tolist(self, flatten: bool = False) -> list:
- """ Alias of `to_list` (numpy style) """
+ """Alias of `to_list` (numpy style)"""
return self.to_list(flatten=flatten)
def is_empty(self, threshold: int = 3) -> bool:
@@ -264,86 +266,95 @@ def is_empty(self, threshold: int = 3) -> bool:
"""
if not isinstance(threshold, int):
raise TypeError(
- f'Input threshold type error, expected "int", got "{type(threshold)}".')
+ f'Input threshold type error, expected "int", got "{type(threshold)}".'
+ )
return len(self) < threshold
@property
def moments(self) -> dict:
- """ Get the moment of area. """
+ """Get the moment of area."""
return cv2.moments(self._array)
@property
def area(self) -> float:
- """ Get the region area. """
- return self.moments['m00']
+ """Get the region area."""
+ return self.moments["m00"]
@property
def arclength(self) -> float:
- """ Get the region arc length. """
+ """Get the region arc length."""
return cv2.arcLength(self._array, closed=True)
@property
def centroid(self) -> np.ndarray:
- """ Get the mass centers. """
- return np.array([
- self.moments['m10'] / (self.moments['m00'] + 1e-5),
- self.moments['m01'] / (self.moments['m00'] + 1e-5)
- ])
+ """Get the mass centers."""
+ return np.array(
+ [
+ self.moments["m10"] / (self.moments["m00"] + 1e-5),
+ self.moments["m01"] / (self.moments["m00"] + 1e-5),
+ ]
+ )
@property
- def boundingbox(self) -> np.ndarray:
- """ Get the bounding box. """
+ def boundingbox(self):
+ """Get the bounding box."""
from .boxes import Box
- bbox = cv2.boundingRect(self._array)
+
+ bbox = np.array(cv2.boundingRect(self._array), dtype="float32")
if not self.is_normalized:
- bbox = bbox - np.array([0, 0, 1, 1])
- return Box(bbox, 'xywh')
+ bbox = bbox - np.array([0, 0, 1, 1], dtype="float32")
+ return Box(bbox, "xywh", is_normalized=self.is_normalized)
@property
- def min_circle(self) -> Tuple[Tuple[int, int], int]:
- """ Get the min closed circle. """
+ def min_circle(self) -> tuple[Sequence[float], float]:
+        """Get the minimum enclosing circle."""
return cv2.minEnclosingCircle(self._array)
@property
- def min_box(self) -> Tuple[Tuple[int, int], Tuple[int, int], int]:
- """ Get the min area rectangle. """
+ def min_box(
+ self,
+ ) -> tuple[Sequence[float], Sequence[float], float]:
+ """Get the min area rectangle."""
return cv2.minAreaRect(self._array)
@property
def orientation(self) -> float:
- """ Get the min area rectangle. """
+        """Get the rotation angle of the min area rectangle."""
_, _, angle = self.min_box
return angle
@property
- def min_box_wh(self) -> Tuple[float, float]:
- """ Get the min area rectangle. """
+ def min_box_wh(self) -> tuple[float, float]:
+        """Get the width and height of the min area rectangle."""
_, (w, h), _ = self.min_box
return w, h
@property
def extent(self) -> float:
- """ Ratio of pixels in the region to pixels in the total bounding box. """
+ """Ratio of pixels in the region to pixels in the total bounding box."""
_, _, w, h = self.boundingbox
return self.area / (w * h)
@property
def solidity(self) -> float:
- """ Ratio of pixels in the region to pixels of the convex hull image. """
+ """Ratio of pixels in the region to pixels of the convex hull image."""
return self.area / (self.to_convexhull().area + 1e-5)
class Polygons:
-
def __init__(self, polygons: _Polygons, is_normalized: bool = False):
- if not isinstance(polygons, (list, np.ndarray)):
+ if isinstance(polygons, Polygons):
+ polygons = polygons._polygons
+ elif not isinstance(polygons, (list, tuple, np.ndarray)):
raise TypeError(
- f'Input type error: "{polygons}", must be list or np.ndarray type.')
+                f"Input must be list, tuple, np.ndarray, or Polygons, but got {type(polygons)}."
+ )
self.is_normalized = is_normalized
+ self._polygons: list[Polygon]
self._polygons = [Polygon(p, is_normalized) for p in polygons]
def __repr__(self) -> str:
- return f"{self.__class__.__name__}({str(self._polygons)})"
+ return f"{self.__class__.__name__}({self._polygons!s})"
def __len__(self) -> int:
return len(self._polygons)
@@ -355,8 +366,17 @@ def __iter__(self) -> Any:
def __eq__(self, value: object) -> bool:
if not isinstance(value, self.__class__):
return False
- is_same = sum([t == x for t, x in zip(value, self)]) == len(self)
- return is_same
+ if len(value) != len(self):
+ return False
+ return all(t == x for t, x in zip(value, self, strict=True))
+
+ @overload
+ def __getitem__(self, item: int) -> Polygon: ...
+
+ @overload
+ def __getitem__(
+ self, item: list[int] | slice | np.ndarray
+ ) -> "Polygons": ...
def __getitem__(self, item) -> Union["Polygons", "Polygon"]:
"""
@@ -383,12 +403,13 @@ def __getitem__(self, item) -> Union["Polygons", "Polygon"]:
elif isinstance(item, slice):
output = Polygons(self._polygons[item])
elif isinstance(item, np.ndarray):
- if item.dtype == 'bool':
+ if item.dtype == "bool":
item = np.argwhere(item).flatten()
output = Polygons([self._polygons[i] for i in item])
else:
raise TypeError(
- 'Input item type error, expected to be int, list, ndarray or slice.')
+ "Input item type error, expected to be int, list, ndarray or slice."
+ )
return output
def is_empty(self, threshold: int = 3) -> np.ndarray:
@@ -400,102 +421,117 @@ def to_min_boxpoints(self) -> "Polygons":
def to_convexhull(self) -> "Polygons":
return Polygons([poly.to_convexhull() for poly in self._polygons])
- def to_boxes(self, box_mode: str = 'xyxy'):
+ def to_boxes(self, box_mode: str = "xyxy"):
from .boxes import Boxes
- return Boxes(self.boundingbox, 'xywh', self.is_normalized).convert(box_mode)
+
+ return Boxes(self.boundingbox, "xywh", self.is_normalized).convert(
+ box_mode
+ )
def drop_empty(self, threshold: int = 3) -> "Polygons":
- return Polygons([p for p in self._polygons if not p.is_empty(threshold)])
+ return Polygons(
+ [p for p in self._polygons if not p.is_empty(threshold)]
+ )
def copy(self):
return self.__class__(deepcopy(self._polygons))
def normalize(self, w: float, h: float) -> "Polygons":
if self.is_normalized:
- warn(f'Normalized polygons are forced to do normalization.')
+ warn(
+ "Normalized polygons are forced to do normalization.",
+ stacklevel=2,
+ )
_polygons = [x.normalize(w, h) for x in self._polygons]
polygons = self.__class__(_polygons, is_normalized=True)
return polygons
def denormalize(self, w: float, h: float) -> "Polygons":
if not self.is_normalized:
- warn(f'Non-normalized polygons are forced to do denormalization.')
+ warn(
+ "Non-normalized polygons are forced to do denormalization.",
+ stacklevel=2,
+ )
_polygons = [x.denormalize(w, h) for x in self._polygons]
polygons = self.__class__(_polygons, is_normalized=False)
return polygons
def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Polygons":
- return Polygons([p.clip(xmin, ymin, xmax, ymax) for p in self._polygons])
+ return Polygons(
+ [p.clip(xmin, ymin, xmax, ymax) for p in self._polygons]
+ )
def shift(self, shift_x: float, shift_y: float) -> "Polygons":
return Polygons([p.shift(shift_x, shift_y) for p in self._polygons])
def scale(self, distance: int) -> "Polygons":
- return Polygons([p.scale(distance) for p in self._polygons]).drop_empty()
+ return Polygons(
+ [p.scale(distance) for p in self._polygons]
+ ).drop_empty()
def numpy(self, flatten: bool = False):
len_polys = np.array([len(p) for p in self._polygons], dtype=np.int32)
if (len_polys == len_polys.mean()).all():
- return np.array(self.to_list(flatten=flatten)).astype('float32')
+ return np.array(self.to_list(flatten=flatten)).astype("float32")
else:
return np.array(self._polygons, dtype=object)
def to_list(self, flatten: bool = False) -> list:
- """ Convert boxes to list.
+ """Convert boxes to list.
- Args:
+ Args:
- is_flatten (bool):
- True -> Output format (Nx(Mx2)):
+ is_flatten (bool):
+ True -> Output format (Nx(Mx2)):
+ [
+ [p11, p12, p13, p14],
+ [p21, p22, p23, p24],
+ ...,
+ ].
+ False -> Output format (NxMx2):
+ [
[
- [p11, p12, p13, p14],
- [p21, p22, p23, p24],
- ...,
- ].
- False -> Output format (NxMx2):
+ [p11],
+ [p12],
+ [p13],
+ [p14]
+ ],
[
- [
- [p11],
- [p12],
- [p13],
- [p14]
- ],
- [
- [p21],
- [p22],
- [p23],
- [p24]
- ],
- ...,
- ].
+ [p21],
+ [p22],
+ [p23],
+ [p24]
+ ],
+ ...,
+ ].
"""
return [p.to_list(flatten) for p in self._polygons]
def tolist(self, flatten: bool = False) -> list:
- """ Alias of `to_list` (numpy style) """
+ """Alias of `to_list` (numpy style)"""
return self.to_list(flatten=flatten)
- @ property
+ @property
def moments(self) -> list:
return [poly.moments for poly in self._polygons]
- @ property
+ @property
def min_circle(self) -> list:
return [poly.min_circle for poly in self._polygons]
- @ property
+ @property
def min_box(self) -> list:
return [poly.min_box for poly in self._polygons]
- @ property
+ @property
def area(self) -> np.ndarray:
return np.array([poly.area for poly in self._polygons])
- @ property
+ @property
def arclength(self) -> np.ndarray:
return np.array([poly.arclength for poly in self._polygons])
- @ property
+ @property
def centroid(self) -> np.ndarray:
return np.array([poly.centroid for poly in self._polygons])
@@ -503,15 +539,15 @@ def centroid(self) -> np.ndarray:
def boundingbox(self) -> np.ndarray:
return np.array([poly.boundingbox for poly in self._polygons])
- @ property
+ @property
def extent(self) -> np.ndarray:
return np.array([poly.extent for poly in self._polygons])
- @ property
+ @property
def solidity(self) -> np.ndarray:
return np.array([poly.solidity for poly in self._polygons])
- @ property
+ @property
def orientation(self) -> np.ndarray:
return np.array([poly.orientation for poly in self._polygons])
@@ -524,31 +560,34 @@ def from_image(
cls,
image: np.ndarray,
mode: int = cv2.RETR_EXTERNAL,
- method: int = cv2.CHAIN_APPROX_SIMPLE
+ method: int = cv2.CHAIN_APPROX_SIMPLE,
) -> "Polygons":
if not isinstance(image, np.ndarray):
- raise TypeError('Input image must be a np.ndarray.')
+ raise TypeError("Input image must be a np.ndarray.")
contours, _ = cv2.findContours(image, mode=mode, method=method)
if len(contours) > 0:
contours = [c for c in contours if c.shape[0] > 1]
return cls(list(contours))
@classmethod
- def cat(cls, polygons_list: List["Polygons"]) -> "Polygons":
+ def cat(cls, polygons_list: list["Polygons"]) -> "Polygons":
"""
Concatenates a list of Polygon into a single Polygons.
Returns:
Polygon: the concatenated Polygon
"""
if not isinstance(polygons_list, list):
- raise TypeError('Given polygon_list should be a list.')
+ raise TypeError("Given polygon_list should be a list.")
if len(polygons_list) == 0:
- raise ValueError('Given polygon_list is empty.')
+ raise ValueError("Given polygon_list is empty.")
- if not all(isinstance(polygons, Polygons) for polygons in polygons_list):
+ if not all(
+ isinstance(polygons, Polygons) for polygons in polygons_list
+ ):
raise TypeError(
- 'All type of elements in polygon_list must be Polygon.')
+                "All elements in polygons_list must be Polygons instances."
+ )
_polygons = []
for polys in polygons_list:
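
A minimal sketch of the Polygon API touched here, with a toy rectangle (illustrative only, not part of the diff); to_box() and boundingbox now propagate is_normalized:

    import numpy as np
    from capybara.structures import Polygon

    poly = Polygon(np.array([[0, 0], [10, 0], [10, 5], [0, 5]], dtype="float32"))
    box = poly.to_box("xyxy")
    print(poly.area, box.numpy())  # contour area (~50.0) and the enclosing box in XYXY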
diff --git a/capybara/torchengine/__init__.py b/capybara/torchengine/__init__.py
new file mode 100644
index 0000000..49b4b95
--- /dev/null
+++ b/capybara/torchengine/__init__.py
@@ -0,0 +1,3 @@
+from .engine import TorchEngine, TorchEngineConfig
+
+__all__ = ["TorchEngine", "TorchEngineConfig"]
diff --git a/capybara/torchengine/engine.py b/capybara/torchengine/engine.py
new file mode 100644
index 0000000..44e65eb
--- /dev/null
+++ b/capybara/torchengine/engine.py
@@ -0,0 +1,242 @@
+from __future__ import annotations
+
+import time
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+__all__ = ["TorchEngine", "TorchEngineConfig"]
+
+
+def _lazy_import_torch():
+ try: # pragma: no cover - optional dependency
+ import torch # type: ignore
+ except Exception as exc: # pragma: no cover - surfaced in init
+ raise ImportError(
+ "PyTorch is required to use Runtime.pt. Please install torch>=2.0."
+ ) from exc
+ return torch
+
+
+def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]:
+ repeat_i = int(repeat)
+ warmup_i = int(warmup)
+ if repeat_i < 1:
+ raise ValueError("repeat must be >= 1.")
+ if warmup_i < 0:
+ raise ValueError("warmup must be >= 0.")
+ return repeat_i, warmup_i
+
+
+@dataclass(slots=True)
+class TorchEngineConfig:
+ dtype: str | Any | None = None
+ cuda_sync: bool = True
+
+
+class TorchEngine:
+ """Thin wrapper around torch.jit.ScriptModule for AngiDetection."""
+
+ def __init__(
+ self,
+ model_path: str | Path,
+ *,
+ device: str | Any = "cuda",
+ output_names: Sequence[str] | None = None,
+ config: TorchEngineConfig | None = None,
+ ) -> None:
+ self.model_path = str(model_path)
+ self._torch = _lazy_import_torch()
+ self._cfg = config or TorchEngineConfig()
+ self.device = self._normalize_device(device)
+ self.output_names = tuple(output_names or ())
+ self.dtype = self._normalize_dtype(self._cfg.dtype)
+ self._model = self._load_model()
+
+ def __call__(self, **inputs: Any) -> dict[str, np.ndarray]:
+ if len(inputs) == 1 and isinstance(
+ next(iter(inputs.values())), Mapping
+ ):
+ feed_dict = dict(next(iter(inputs.values())))
+ else:
+ feed_dict = inputs
+ return self.run(feed_dict)
+
+ def run(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]:
+ prepared = self._prepare_feed(feed)
+ outputs = self._forward(prepared)
+ return self._format_outputs(outputs)
+
+ def benchmark(
+ self,
+ inputs: Mapping[str, Any],
+ *,
+ repeat: int = 100,
+ warmup: int = 10,
+ cuda_sync: bool | None = None,
+ ) -> dict[str, Any]:
+ repeat, warmup = _validate_benchmark_params(
+ repeat=repeat, warmup=warmup
+ )
+ prepared = self._prepare_feed(inputs)
+ sync = self._should_sync(cuda_sync)
+ with self._torch.inference_mode():
+ for _ in range(warmup):
+ self._forward(prepared)
+ if sync:
+ self._sync()
+
+ latencies: list[float] = []
+ t0 = time.perf_counter()
+ with self._torch.inference_mode():
+ for _ in range(repeat):
+ if sync:
+ self._sync()
+ start = time.perf_counter()
+ self._forward(prepared)
+ if sync:
+ self._sync()
+ latencies.append((time.perf_counter() - start) * 1e3)
+ total = time.perf_counter() - t0
+ arr = np.asarray(latencies, dtype=np.float64)
+ return {
+ "repeat": repeat,
+ "warmup": warmup,
+ "throughput_fps": repeat / total if total else None,
+ "latency_ms": {
+ "mean": float(arr.mean()) if arr.size else None,
+ "median": float(np.median(arr)) if arr.size else None,
+ "p90": float(np.percentile(arr, 90)) if arr.size else None,
+ "p95": float(np.percentile(arr, 95)) if arr.size else None,
+ "min": float(arr.min()) if arr.size else None,
+ "max": float(arr.max()) if arr.size else None,
+ },
+ }
+
+ def summary(self) -> dict[str, Any]:
+ return {
+ "model": self.model_path,
+ "device": str(self.device),
+ "dtype": str(self.dtype),
+ "outputs": list(self.output_names),
+ }
+
+ # Internal helpers -----------------------------------------------------
+ def _load_model(self):
+ torch = self._torch
+ model = torch.jit.load(self.model_path, map_location=self.device)
+ model.eval()
+ with torch.no_grad():
+ model.to(self.device)
+ if self.dtype == torch.float16:
+ model.half()
+ elif self.dtype == torch.float32:
+ model.float()
+ else:
+ model.to(dtype=self.dtype)
+ return model
+
+ def _prepare_feed(self, feed: Mapping[str, Any]) -> list[Any]:
+ if not isinstance(feed, Mapping):
+ raise TypeError("TorchEngine feed must be a mapping.")
+ prepared: list[Any] = []
+ for value in feed.values():
+ tensor = self._as_tensor(value)
+ tensor = tensor.to(self.device)
+ tensor = tensor.to(self.dtype)
+ prepared.append(tensor)
+ return prepared
+
+ def _forward(self, inputs: Sequence[Any]) -> Any:
+ if len(inputs) == 1:
+ return self._model(inputs[0])
+ return self._model(*inputs)
+
+ def _format_outputs(self, outputs: Any) -> dict[str, np.ndarray]:
+ torch = self._torch
+ if isinstance(outputs, Mapping):
+ return {
+ str(key): self._tensor_to_numpy(value)
+ for key, value in outputs.items()
+ }
+ if isinstance(outputs, (list, tuple)):
+ names = self._normalize_output_names(len(outputs))
+ return {
+ names[idx]: self._tensor_to_numpy(value)
+ for idx, value in enumerate(outputs)
+ }
+ if torch.is_tensor(outputs):
+ name = self.output_names[0] if self.output_names else "output"
+ return {name: self._tensor_to_numpy(outputs)}
+ raise TypeError(
+ "Unsupported TorchScript output. Expected tensor/dict/sequence."
+ )
+
+ def _normalize_output_names(self, count: int) -> tuple[str, ...]:
+ if self.output_names:
+ if len(self.output_names) != count:
+ raise ValueError(
+ f"output_names has {len(self.output_names)} entries but "
+ f"model produced {count} outputs."
+ )
+ return self.output_names
+ return tuple(f"output_{idx}" for idx in range(count))
+
+ def _tensor_to_numpy(self, tensor: Any) -> np.ndarray:
+ torch = self._torch
+ if not torch.is_tensor(tensor):
+ raise TypeError("Model outputs must be torch.Tensor instances.")
+ array = tensor.detach().to("cpu")
+ if array.dtype != torch.float32:
+ array = array.to(torch.float32)
+ return array.contiguous().numpy()
+
+ def _as_tensor(self, value: Any):
+ torch = self._torch
+ if torch.is_tensor(value):
+ return value
+ arr = np.asarray(value, dtype=np.float32)
+ return torch.from_numpy(arr)
+
+ def _normalize_device(self, device: Any):
+ torch = self._torch
+ torch_device_type = getattr(torch, "device", None)
+ if isinstance(torch_device_type, type) and isinstance(
+ device, torch_device_type
+ ):
+ return device
+ return torch.device(device)
+
+ def _normalize_dtype(self, dtype: Any | None):
+ torch = self._torch
+ if dtype is None or (
+ isinstance(dtype, str) and dtype.strip().lower() == "auto"
+ ):
+ token = Path(self.model_path).name.lower()
+ if "fp16" in token and self._device_is_cuda():
+ return torch.float16
+ return torch.float32
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ normalized = dtype.strip().lower()
+ if normalized in {"fp16", "float16", "half"}:
+ return torch.float16
+ if normalized in {"fp32", "float32"}:
+ return torch.float32
+ raise ValueError(f"Unsupported dtype specification '{dtype}'.")
+
+ def _device_is_cuda(self) -> bool:
+ return getattr(self.device, "type", "") == "cuda"
+
+ def _should_sync(self, override: bool | None) -> bool:
+ if override is not None:
+ return bool(override and self._device_is_cuda())
+ return bool(self._cfg.cuda_sync and self._device_is_cuda())
+
+ def _sync(self) -> None:
+ if self._device_is_cuda():
+ self._torch.cuda.synchronize(self.device)
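
Illustrative usage of the TorchEngine wrapper above (not part of the patch; the import path and the TorchScript filename are placeholders):

import numpy as np

from capybara import TorchEngine  # hypothetical import path; adjust to the actual export location

# "model_fp16.torchscript" is a made-up file name; the "fp16" token triggers the auto dtype rule on CUDA.
engine = TorchEngine("model_fp16.torchscript", device="cuda", output_names=("logits",))
feed = {"images": np.random.rand(1, 3, 224, 224).astype(np.float32)}

outputs = engine.run(feed)                            # dict[str, np.ndarray] keyed by output_names
stats = engine.benchmark(feed, repeat=50, warmup=5)   # warmup runs excluded from the latency stats
print(engine.summary(), stats["latency_ms"]["p95"], stats["throughput_fps"])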
diff --git a/capybara/typing.py b/capybara/typing.py
index 542d665..3bfdaaa 100644
--- a/capybara/typing.py
+++ b/capybara/typing.py
@@ -1,5 +1,3 @@
-from typing import Union
-
import numpy as np
-_Number = Union[np.number, int, float]
+_Number = np.number | int | float
diff --git a/capybara/utils/__init__.py b/capybara/utils/__init__.py
index 1a31110..1f9cdc0 100644
--- a/capybara/utils/__init__.py
+++ b/capybara/utils/__init__.py
@@ -1,7 +1,21 @@
-from .custom_path import *
-from .custom_tqdm import *
-from .files_utils import *
-from .powerdict import *
-from .system_info import *
-from .time import *
-from .utils import *
+from __future__ import annotations
+
+from pathlib import Path
+
+from .custom_path import copy_path, get_curdir, rm_path
+from .powerdict import PowerDict
+from .time import Timer, now
+from .utils import colorstr, download_from_google, make_batch
+
+__all__ = [
+ "Path",
+ "PowerDict",
+ "Timer",
+ "colorstr",
+ "copy_path",
+ "download_from_google",
+ "get_curdir",
+ "make_batch",
+ "now",
+ "rm_path",
+]
diff --git a/capybara/utils/custom_path.py b/capybara/utils/custom_path.py
index 0a32200..b4a2857 100644
--- a/capybara/utils/custom_path.py
+++ b/capybara/utils/custom_path.py
@@ -1,14 +1,10 @@
import shutil
from pathlib import Path
-from typing import Union
-__all__ = ['Path', 'get_curdir', 'rm_path']
+__all__ = ["Path", "get_curdir", "rm_path"]
-def get_curdir(
- path: Union[str, Path],
- absolute: bool = True
-) -> Path:
+def get_curdir(path: str | Path, absolute: bool = True) -> Path:
"""
Function to get the path of current workspace.
@@ -23,15 +19,15 @@ def get_curdir(
return path.parent.resolve() if absolute else path.parent
-def rm_path(path: Union[str, Path]):
+def rm_path(path: str | Path):
pth = Path(path)
- if pth.is_dir():
- pth.rmdir()
- else:
- pth.unlink()
+ if pth.is_dir() and not pth.is_symlink():
+ shutil.rmtree(pth)
+ return
+ pth.unlink()
-def copy_path(path_src: Union[str, Path], path_dst: Union[str, Path]):
+def copy_path(path_src: str | Path, path_dst: str | Path):
if not Path(path_src).is_file():
raise ValueError(f'Input path: "{path_src}" is invaild.')
shutil.copy(path_src, path_dst)
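
A minimal sketch of the revised rm_path behaviour (non-empty directories are now removed recursively; the scratch path below is made up):

from pathlib import Path

from capybara.utils import copy_path, rm_path

scratch = Path("demo_dir")                      # hypothetical scratch directory
(scratch / "nested").mkdir(parents=True, exist_ok=True)
(scratch / "nested" / "a.txt").write_text("x")

copy_path(scratch / "nested" / "a.txt", scratch / "b.txt")
rm_path(scratch)    # previously rmdir() would fail on a non-empty directory
assert not scratch.exists()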
diff --git a/capybara/utils/custom_tqdm.py b/capybara/utils/custom_tqdm.py
index 747bedd..67e0e56 100644
--- a/capybara/utils/custom_tqdm.py
+++ b/capybara/utils/custom_tqdm.py
@@ -1,24 +1,23 @@
from collections.abc import Sized
+from typing import Any, cast
from tqdm import tqdm as _tqdm
-__all__ = ['Tqdm']
+__all__ = ["Tqdm"]
class Tqdm(_tqdm):
-
def __init__(self, iterable=None, desc=None, smoothing=0, **kwargs):
-
- if 'total' in kwargs:
- total = kwargs.pop('total', None)
+ if "total" in kwargs:
+ total = kwargs.pop("total", None)
else:
total = len(iterable) if isinstance(iterable, Sized) else None
super().__init__(
- iterable=iterable,
+ iterable=cast(Any, iterable),
desc=desc,
total=total,
smoothing=smoothing,
dynamic_ncols=True,
- **kwargs
+ **kwargs,
)
diff --git a/capybara/utils/files_utils.py b/capybara/utils/files_utils.py
index 47821ea..31a7e35 100644
--- a/capybara/utils/files_utils.py
+++ b/capybara/utils/files_utils.py
@@ -1,7 +1,7 @@
import errno
import hashlib
import os
-from typing import Any, List, Tuple, Union
+from typing import Any
import dill
import numpy as np
@@ -13,12 +13,19 @@
from .custom_tqdm import Tqdm
__all__ = [
- 'gen_md5', 'get_files', 'load_json', 'dump_json', 'load_pickle',
- 'dump_pickle', 'load_yaml', 'dump_yaml', 'img_to_md5',
+ "dump_json",
+ "dump_pickle",
+ "dump_yaml",
+ "gen_md5",
+ "get_files",
+ "img_to_md5",
+ "load_json",
+ "load_pickle",
+ "load_yaml",
]
-def gen_md5(file: Union[str, Path], block_size: int = 256 * 128) -> str:
+def gen_md5(file: str | Path, block_size: int = 256 * 128) -> str:
"""
This function is to gen md5 based on given file.
@@ -32,9 +39,9 @@ def gen_md5(file: Union[str, Path], block_size: int = 256 * 128) -> str:
Returns:
md5 (str)
"""
- with open(str(file), 'rb') as f:
+ with open(str(file), "rb") as f:
md5 = hashlib.md5()
- for chunk in iter(lambda: f.read(block_size), b''):
+ for chunk in iter(lambda: f.read(block_size), b""):
md5.update(chunk)
return str(md5.hexdigest())
@@ -47,7 +54,7 @@ def img_to_md5(img: np.ndarray) -> str:
return str(md5_hash.hexdigest())
-def load_json(path: Union[Path, str], **kwargs) -> dict:
+def load_json(path: Path | str, **kwargs) -> dict:
"""
Function to read ujson.
@@ -57,12 +64,12 @@ def load_json(path: Union[Path, str], **kwargs) -> dict:
Returns:
dict: ujson load to dictionary
"""
- with open(str(path), 'r') as f:
+ with open(str(path)) as f:
data = ujson.load(f, **kwargs)
return data
-def dump_json(obj: Any, path: Union[str, Path] = None, **kwargs) -> None:
+def dump_json(obj: Any, path: str | Path | None = None, **kwargs) -> None:
"""
Function to write obj to ujson
@@ -71,23 +78,23 @@ def dump_json(obj: Any, path: Union[str, Path] = None, **kwargs) -> None:
path (Union[str, Path]): ujson file's path
"""
dump_options = {
- 'sort_keys': False,
- 'indent': 2,
- 'ensure_ascii': False,
- 'escape_forward_slashes': False,
+ "sort_keys": False,
+ "indent": 2,
+ "ensure_ascii": False,
+ "escape_forward_slashes": False,
}
dump_options.update(kwargs)
if path is None:
- path = Path.cwd() / 'tmp.json'
+ path = Path.cwd() / "tmp.json"
- with open(str(path), 'w') as f:
+ with open(str(path), "w") as f:
ujson.dump(obj, f, **dump_options)
def get_files(
- folder: Union[str, Path],
- suffix: Union[str, List[str], Tuple[str]] = None,
+ folder: str | Path,
+ suffix: str | list[str] | tuple[str, ...] | None = None,
recursive: bool = True,
return_pathlib: bool = True,
sort_path: bool = True,
@@ -126,25 +133,28 @@ def get_files(
folder = Path(folder)
if not folder.is_dir():
raise FileNotFoundError(
- errno.ENOENT, os.strerror(errno.ENOENT), str(folder))
+ errno.ENOENT, os.strerror(errno.ENOENT), str(folder)
+ )
if not isinstance(suffix, (str, list, tuple)) and suffix is not None:
- raise TypeError('suffix must be a string, list or tuple.')
+ raise TypeError("suffix must be a string, list or tuple.")
# checking suffix
suffix = [suffix] if isinstance(suffix, str) else suffix
if suffix is not None and ignore_letter_case:
suffix = [s.lower() for s in suffix]
- if recursive:
- files_gen = folder.rglob('*')
- else:
- files_gen = folder.glob('*')
+ files_gen = folder.rglob("*") if recursive else folder.glob("*")
files = []
for f in Tqdm(files_gen, leave=False):
- if suffix is None or (ignore_letter_case and f.suffix.lower() in suffix) \
- or (not ignore_letter_case and f.suffix in suffix):
+ if not f.is_file():
+ continue
+ if (
+ suffix is None
+ or (ignore_letter_case and f.suffix.lower() in suffix)
+ or (not ignore_letter_case and f.suffix in suffix)
+ ):
files.append(f.absolute())
if not return_pathlib:
@@ -156,7 +166,7 @@ def get_files(
return files
-def load_pickle(path: Union[str, Path]):
+def load_pickle(path: str | Path):
"""
Function to load a pickle.
@@ -166,11 +176,11 @@ def load_pickle(path: Union[str, Path]):
Returns:
loaded_pickle (dict): loaded pickle.
"""
- with open(str(path), 'rb') as f:
+ with open(str(path), "rb") as f:
return dill.load(f)
-def dump_pickle(obj, path: Union[str, Path]):
+def dump_pickle(obj, path: str | Path):
"""
Function to dump an obj to a pickle file.
@@ -178,11 +188,11 @@ def dump_pickle(obj, path: Union[str, Path]):
obj: object to be dump.
path (Union[str, Path]): file path.
"""
- with open(str(path), 'wb') as f:
+ with open(str(path), "wb") as f:
dill.dump(obj, f)
-def load_yaml(path: Union[Path, str]) -> dict:
+def load_yaml(path: Path | str) -> dict:
"""
Function to read yaml.
@@ -192,12 +202,12 @@ def load_yaml(path: Union[Path, str]) -> dict:
Returns:
dict: yaml load to dictionary
"""
- with open(str(path), 'r') as f:
+ with open(str(path)) as f:
data = yaml.load(f, Loader=yaml.FullLoader)
return data
-def dump_yaml(obj, path: Union[str, Path] = None, **kwargs):
+def dump_yaml(obj, path: str | Path | None = None, **kwargs) -> None:
"""
Function to dump an obj to a yaml file.
@@ -205,14 +215,11 @@ def dump_yaml(obj, path: Union[str, Path] = None, **kwargs):
obj: object to be dump.
path (Union[str, Path]): file path.
"""
- dump_options = {
- 'indent': 2,
- 'sort_keys': True
- }
+ dump_options = {"indent": 2, "sort_keys": True}
dump_options.update(kwargs)
if path is None:
- path = Path.cwd() / 'tmp.yaml'
+ path = Path.cwd() / "tmp.yaml"
- with open(str(path), 'w') as f:
+ with open(str(path), "w") as f:
yaml.dump(obj, f, **dump_options)
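
Usage sketch for the updated file helpers (get_files now skips directory entries; the paths are illustrative):

from capybara.utils.files_utils import dump_json, get_files, load_json

images = get_files("data/images", suffix=[".jpg", ".png"], recursive=True)
dump_json({"count": len(images)}, "inventory.json")
print(load_json("inventory.json")["count"])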
diff --git a/capybara/utils/powerdict.py b/capybara/utils/powerdict.py
index c85a08c..a71b15f 100644
--- a/capybara/utils/powerdict.py
+++ b/capybara/utils/powerdict.py
@@ -1,14 +1,22 @@
from collections.abc import Mapping
from pprint import pprint
+from typing import Any
-from .files_utils import (dump_json, dump_pickle, dump_yaml, load_json,
- load_pickle, load_yaml)
+from .files_utils import (
+ dump_json,
+ dump_pickle,
+ dump_yaml,
+ load_json,
+ load_pickle,
+ load_yaml,
+)
-__all__ = ['PowerDict']
+__all__ = ["PowerDict"]
+_MISSING = object()
-class PowerDict(dict):
+class PowerDict(dict):
def __init__(self, d=None, **kwargs):
"""
This class is used to create a namespace dictionary with freeze and melt functions.
@@ -25,12 +33,23 @@ def __init__(self, d=None, **kwargs):
self._frozen = False
+ def __getattr__(self, key: str) -> Any:
+ try:
+ return self[key]
+ except KeyError as exc:
+ raise AttributeError(key) from exc
+
def __set(self, key, value):
if not self.is_frozen:
- if isinstance(value, Mapping) and not isinstance(value, self.__class__):
+ if isinstance(value, Mapping) and not isinstance(
+ value, self.__class__
+ ):
value = self.__class__(value)
if isinstance(value, (list, tuple)):
- value = [self.__class__(v) if isinstance(v, dict) else v for v in value]
+ value = [
+ self.__class__(v) if isinstance(v, dict) else v
+ for v in value
+ ]
super().__setattr__(key, value)
super().__setitem__(key, value)
else:
@@ -44,49 +63,57 @@ def __del(self, key):
raise ValueError(f"PowerDict is frozen. '{key}' cannot be del.")
def __setattr__(self, key, value):
- if key == '_frozen':
+ if key == "_frozen":
super().__setattr__(key, value)
else:
self.__set(key, value)
def __delattr__(self, key):
- if key == '_frozen':
+ if key == "_frozen":
raise KeyError("Can not del '_frozen'.")
else:
self.__del(key)
def __setitem__(self, key, value):
- if key == '_frozen':
+ if key == "_frozen":
raise KeyError("Can not set '_frozen' as an item.")
else:
self.__set(key, value)
def __delitem__(self, key):
- if key == '_frozen':
+ if key == "_frozen":
raise KeyError("There is not _frozen in items.")
else:
self.__del(key)
def update(self, e=None, **f):
if self.is_frozen:
- raise Warning(f'PowerDict is frozen and cannot be update.')
+ raise ValueError("PowerDict is frozen. Update is not allowed.")
+
+ if e is None:
+ d: dict = {}
else:
- d = e or dict()
- d.update(f)
- for k in d:
- setattr(self, k, d[k])
+ d = dict(e)
+ d.update(f)
+ for key, value in d.items():
+ setattr(self, key, value)
- def pop(self, key, d=None):
+ def pop(self, key, default=_MISSING):
if self.is_frozen:
- raise Warning(f"PowerDict is frozen and cannot be pop.")
- else:
- d = getattr(self, key, d)
- delattr(self, key)
- return d
+ raise ValueError("PowerDict is frozen. Pop is not allowed.")
+
+ if key in self:
+ value = self[key]
+ del self[key]
+ return value
+
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
def freeze(self):
self._frozen = True
- for v in self.values():
+ for v in dict.values(self):
if isinstance(v, PowerDict):
v.freeze()
if isinstance(v, (list, tuple)):
@@ -96,7 +123,7 @@ def freeze(self):
def melt(self):
self._frozen = False
- for v in self.values():
+ for v in dict.values(self):
if isinstance(v, PowerDict):
v.melt()
if isinstance(v, (list, tuple)):
@@ -106,20 +133,22 @@ def melt(self):
@property
def is_frozen(self):
- return getattr(self, '_frozen', False)
+ return getattr(self, "_frozen", False)
def __deepcopy__(self, memo):
if self._frozen:
- raise Warning('PowerDict is frozen and cannot be copy.')
+            raise ValueError("PowerDict is frozen. Deepcopy is not allowed.")
return self.__class__(self)
def to_dict(self):
out = {}
- for k, v in self.items():
+ for k, v in dict.items(self):
if isinstance(v, PowerDict):
out[k] = v.to_dict()
elif isinstance(v, list):
- out[k] = [x.to_dict() if isinstance(x, PowerDict) else x for x in v]
+ out[k] = [
+ x.to_dict() if isinstance(x, PowerDict) else x for x in v
+ ]
else:
out[k] = v
@@ -135,7 +164,7 @@ def to_yaml(self, path):
def to_txt(self, path):
d = self.to_dict()
- with open(path, 'w') as f:
+ with open(path, "w") as f:
pprint(d, f)
def to_pickle(self, path):
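
Behavioural sketch of the revised PowerDict (attribute access via __getattr__, ValueError on frozen mutation, dict-like pop default):

from capybara.utils import PowerDict

cfg = PowerDict({"model": {"name": "capybara"}, "batch": 4})
print(cfg.model.name)              # nested mappings are wrapped, attribute access mirrors item access

cfg.freeze()
try:
    cfg.update(batch=8)            # frozen dicts now raise ValueError instead of Warning
except ValueError as err:
    print(err)

cfg.melt()
print(cfg.pop("missing", None))    # pop now honours an explicit default and raises KeyError otherwise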
diff --git a/capybara/utils/system_info.py b/capybara/utils/system_info.py
index 3840f8e..9e88002 100644
--- a/capybara/utils/system_info.py
+++ b/capybara/utils/system_info.py
@@ -1,13 +1,18 @@
import platform
import socket
import subprocess
+from importlib import import_module
+from typing import Any, cast
-import psutil
+import psutil # type: ignore
import requests
__all__ = [
- "get_package_versions", "get_gpu_cuda_versions", "get_system_info",
- "get_cpu_info", "get_external_ip"
+ "get_cpu_info",
+ "get_external_ip",
+ "get_gpu_cuda_versions",
+ "get_package_versions",
+ "get_system_info",
]
@@ -22,28 +27,32 @@ def get_package_versions():
# PyTorch
try:
- import torch
+ import torch # type: ignore
+
versions_info["PyTorch Version"] = torch.__version__
except Exception as e:
versions_info["PyTorch Error"] = str(e)
# PyTorch Lightning
try:
- import pytorch_lightning as pl
- versions_info["PyTorch Lightning Version"] = pl.__version__
+ import pytorch_lightning as pl # type: ignore
+
+ versions_info["PyTorch Lightning Version"] = str(
+ getattr(pl, "__version__", "unknown")
+ )
except Exception as e:
versions_info["PyTorch Lightning Error"] = str(e)
# TensorFlow
try:
- import tensorflow as tf
+ tf = cast(Any, import_module("tensorflow"))
versions_info["TensorFlow Version"] = tf.__version__
except Exception as e:
versions_info["TensorFlow Error"] = str(e)
# Keras
try:
- import keras
+ keras = cast(Any, import_module("keras"))
versions_info["Keras Version"] = keras.__version__
except Exception as e:
versions_info["Keras Error"] = str(e)
@@ -51,27 +60,31 @@ def get_package_versions():
# NumPy
try:
import numpy as np
+
versions_info["NumPy Version"] = np.__version__
except Exception as e:
versions_info["NumPy Error"] = str(e)
# Pandas
try:
- import pandas as pd
+ import pandas as pd # type: ignore
+
versions_info["Pandas Version"] = pd.__version__
except Exception as e:
versions_info["Pandas Error"] = str(e)
# Scikit-learn
try:
- import sklearn
+ import sklearn # type: ignore
+
versions_info["Scikit-learn Version"] = sklearn.__version__
except Exception as e:
versions_info["Scikit-learn Error"] = str(e)
# OpenCV
try:
- import cv2
+ import cv2 # type: ignore
+
versions_info["OpenCV Version"] = cv2.__version__
except Exception as e:
versions_info["OpenCV Error"] = str(e)
@@ -93,15 +106,16 @@ def get_gpu_cuda_versions():
# Attempt to retrieve CUDA version using PyTorch
try:
- import torch
- cuda_version = torch.version.cuda
+ import torch # type: ignore
+
+ cuda_version = getattr(getattr(torch, "version", None), "cuda", None)
except ImportError:
pass
# If not retrieved via PyTorch, try using TensorFlow
if not cuda_version:
try:
- import tensorflow as tf
+ tf = cast(Any, import_module("tensorflow"))
cuda_version = tf.version.COMPILER_VERSION
except ImportError:
pass
@@ -109,25 +123,33 @@ def get_gpu_cuda_versions():
# If still not retrieved, try using CuPy
if not cuda_version:
try:
- import cupy
+ cupy = cast(Any, import_module("cupy"))
cuda_version = cupy.cuda.runtime.runtimeGetVersion()
except ImportError:
- cuda_version = "Error: None of PyTorch, TensorFlow, or CuPy are installed."
+ cuda_version = (
+ "Error: None of PyTorch, TensorFlow, or CuPy are installed."
+ )
# Try to get Nvidia driver version using nvidia-smi command
try:
- smi_output = subprocess.check_output([
- 'nvidia-smi',
- '--query-gpu=driver_version',
- '--format=csv,noheader,nounits'
- ]).decode('utf-8').strip()
- nvidia_driver_version = smi_output.split('\n')[0]
+ smi_output = (
+ subprocess.check_output(
+ [
+ "nvidia-smi",
+ "--query-gpu=driver_version",
+ "--format=csv,noheader,nounits",
+ ]
+ )
+ .decode("utf-8")
+ .strip()
+ )
+ nvidia_driver_version = smi_output.split("\n")[0]
except Exception as e:
nvidia_driver_version = f"Error getting NVIDIA driver version: {e}"
return {
"CUDA Version": cuda_version,
- "NVIDIA Driver Version": nvidia_driver_version
+ "NVIDIA Driver Version": nvidia_driver_version,
}
@@ -147,15 +169,21 @@ def get_cpu_info():
elif platform.system() == "Linux":
# For Linux
command = "cat /proc/cpuinfo | grep 'model name' | uniq"
- return subprocess.check_output(command, shell=True).strip().decode().split(":")[1].strip()
+ return (
+ subprocess.check_output(command, shell=True)
+ .strip()
+ .decode()
+ .split(":")[1]
+ .strip()
+ )
else:
return "N/A"
def get_external_ip():
try:
- response = requests.get('https://httpbin.org/ip')
- return response.json()['origin']
+ response = requests.get("https://httpbin.org/ip")
+ return response.json()["origin"]
except Exception as e:
return f"Error obtaining IP: {e}"
@@ -171,19 +199,31 @@ def get_system_info():
"OS Version": platform.platform(),
"CPU Model": get_cpu_info(),
"Physical CPU Cores": psutil.cpu_count(logical=False),
- "Logical CPU Cores (incl. hyper-threading)": psutil.cpu_count(logical=True),
- "Total RAM (GB)": round(psutil.virtual_memory().total / (1024 ** 3), 2),
- "Available RAM (GB)": round(psutil.virtual_memory().available / (1024 ** 3), 2),
- "Disk Total (GB)": round(psutil.disk_usage('/').total / (1024 ** 3), 2),
- "Disk Used (GB)": round(psutil.disk_usage('/').used / (1024 ** 3), 2),
- "Disk Free (GB)": round(psutil.disk_usage('/').free / (1024 ** 3), 2)
+ "Logical CPU Cores (incl. hyper-threading)": psutil.cpu_count(
+ logical=True
+ ),
+ "Total RAM (GB)": round(psutil.virtual_memory().total / (1024**3), 2),
+ "Available RAM (GB)": round(
+ psutil.virtual_memory().available / (1024**3), 2
+ ),
+ "Disk Total (GB)": round(psutil.disk_usage("/").total / (1024**3), 2),
+ "Disk Used (GB)": round(psutil.disk_usage("/").used / (1024**3), 2),
+ "Disk Free (GB)": round(psutil.disk_usage("/").free / (1024**3), 2),
}
# Try to fetch GPU information using nvidia-smi command
try:
- gpu_info = subprocess.check_output(
- ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader,nounits']
- ).decode('utf-8').strip()
+ gpu_info = (
+ subprocess.check_output(
+ [
+ "nvidia-smi",
+ "--query-gpu=name",
+ "--format=csv,noheader,nounits",
+ ]
+ )
+ .decode("utf-8")
+ .strip()
+ )
info["GPU Info"] = gpu_info
except Exception:
info["GPU Info"] = "N/A or Error"
@@ -191,22 +231,26 @@ def get_system_info():
# Get network information
addrs = psutil.net_if_addrs()
info["IPV4 Address"] = [
- addr.address for addr in addrs.get('enp5s0', []) if addr.family == socket.AF_INET
+ addr.address
+ for addr in addrs.get("enp5s0", [])
+ if addr.family == socket.AF_INET
]
info["IPV4 Address (External)"] = get_external_ip()
# Determine platform and choose correct address family for MAC
- if hasattr(socket, 'AF_LINK'):
- AF_LINK = socket.AF_LINK
- elif hasattr(psutil, 'AF_LINK'):
- AF_LINK = psutil.AF_LINK
- else:
- raise Exception(
- "Cannot determine the correct AF_LINK value for this platform.")
+ af_link = getattr(socket, "AF_LINK", None)
+ if af_link is None:
+ af_link = getattr(psutil, "AF_LINK", None)
+ if af_link is None:
+ raise RuntimeError(
+ "Cannot determine the correct AF_LINK value for this platform."
+ )
info["MAC Address"] = [
- addr.address for addr in addrs.get('enp5s0', []) if addr.family == AF_LINK
+ addr.address
+ for addr in addrs.get("enp5s0", [])
+ if addr.family == af_link
]
return info
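
Quick smoke-test sketch for the reworked helpers (output values depend on the host; optional frameworks are only probed if installed):

from capybara.utils.system_info import get_cpu_info, get_gpu_cuda_versions, get_system_info

print(get_cpu_info())
print(get_gpu_cuda_versions())
print(get_system_info()["Total RAM (GB)"])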
diff --git a/capybara/utils/time.py b/capybara/utils/time.py
index 31209cf..2937264 100644
--- a/capybara/utils/time.py
+++ b/capybara/utils/time.py
@@ -1,27 +1,26 @@
import time
from datetime import datetime
from time import struct_time
-from typing import Union
import numpy as np
from .utils import colorstr
__all__ = [
- 'Timer',
- 'now',
- 'timestamp2datetime',
- 'timestamp2time',
- 'timestamp2str',
- 'time2datetime',
- 'time2timestamp',
- 'time2str',
- 'datetime2time',
- 'datetime2timestamp',
- 'datetime2str',
- 'str2time',
- 'str2datetime',
- 'str2timestamp',
+ "Timer",
+ "datetime2str",
+ "datetime2time",
+ "datetime2timestamp",
+ "now",
+ "str2datetime",
+ "str2time",
+ "str2timestamp",
+ "time2datetime",
+ "time2str",
+ "time2timestamp",
+ "timestamp2datetime",
+ "timestamp2str",
+ "timestamp2time",
]
__doc__ = """
@@ -32,18 +31,18 @@
==========|==========================================================
Directive | Meaning
==========|==========================================================
- %a | Locale’s abbreviated weekday name.
- %A | Locale’s full weekday name.
- %b | Locale’s abbreviated month name.
- %B | Locale’s full month name.
- %c | Locale’s appropriate date and time representation.
+ %a | Locale's abbreviated weekday name.
+ %A | Locale's full weekday name.
+ %b | Locale's abbreviated month name.
+ %B | Locale's full month name.
+ %c | Locale's appropriate date and time representation.
%d | Day of the month as a decimal number [01,31].
%H | Hour (24-hour clock) as a decimal number [00,23].
%I | Hour (12-hour clock) as a decimal number [01,12].
%j | Day of the year as a decimal number [001,366].
%m | Month as a decimal number [01,12].
%M | Minute as a decimal number [00,59].
- %p | Locale’s equivalent of either AM or PM.
+ %p | Locale's equivalent of either AM or PM.
%S | Second as a decimal number [00,61].
%U | Week number of the year (Sunday as the first day of the week)
| as a decimal number [00,53]. All days in a new year preceding
@@ -52,8 +51,8 @@
%W | Week number of the year (Monday as the first day of the week)
| as a decimal number [00,53]. All days in a new year preceding
| the first Monday are considered to be in week 0.
- %x | Locale’s appropriate date representation.
- %X | Locale’s appropriate time representation.
+ %x | Locale's appropriate date representation.
+ %X | Locale's appropriate time representation.
%y | Year without century as a decimal number [00,99].
%Y | Year with century as a decimal number.
%z | Time zone offset indicating a positive or negative time difference
@@ -95,27 +94,33 @@ def testing_function(*args, **kwargs):
do something...
"""
- def __init__(self, precision: int = 5, desc: str = None, verbose: bool = False):
+ def __init__(
+ self,
+ precision: int = 5,
+ desc: str | None = None,
+ verbose: bool = False,
+ ):
self.precision = precision
self.desc = desc
self.verbose = verbose
self.__record = []
def tic(self):
- """ start timer """
+ """start timer"""
if self.desc is not None and self.verbose:
- print(colorstr(self.desc, 'yellow'))
+ print(colorstr(self.desc, "yellow"))
self.time = time.perf_counter()
def toc(self, verbose=False):
- """ get time lag from start """
- if getattr(self, 'time', None) is None:
+ """get time lag from start"""
+ if getattr(self, "time", None) is None:
raise ValueError(
- f'The timer has not been started. Tic the timer first.')
+ "The timer has not been started. Tic the timer first."
+ )
total = round(time.perf_counter() - self.time, self.precision)
if verbose or self.verbose:
- print(colorstr(f'Cost: {total} sec', 'white'))
+ print(colorstr(f"Cost: {total} sec", "white"))
self.__record.append(total)
return total
@@ -126,10 +131,12 @@ def warp(*args, **kwargs):
result = fcn(*args, **kwargs)
self.toc()
return result
+
return warp
def __enter__(self):
self.tic()
+ return self
def __exit__(self, type, value, traceback):
self.dt = self.toc(True)
@@ -137,28 +144,28 @@ def __exit__(self, type, value, traceback):
def clear_record(self):
self.__record = []
- @ property
+ @property
def mean(self):
if len(self.__record):
return np.array(self.__record).mean().round(self.precision)
- @ property
+ @property
def max(self):
if len(self.__record):
return np.array(self.__record).max().round(self.precision)
- @ property
+ @property
def min(self):
if len(self.__record):
return np.array(self.__record).min().round(self.precision)
- @ property
+ @property
def std(self):
if len(self.__record):
return np.array(self.__record).std().round(self.precision)
-def now(typ: str = 'timestamp', fmt: str = None):
+def now(typ: str = "timestamp", fmt: str | None = None):
"""
Get now time. Specify the output type of time, or give the
formatted rule to get the time string, eg: now(fmt='%Y-%m-%d').
@@ -171,82 +178,82 @@ def now(typ: str = 'timestamp', fmt: str = None):
Raises:
ValueError: Unsupported type error.
"""
- if typ == 'timestamp':
+ if typ == "timestamp":
t = time.time()
- elif typ == 'datetime':
+ elif typ == "datetime":
t = datetime.now()
- elif typ == 'time':
+ elif typ == "time":
t = time.gmtime(time.time())
else:
- raise ValueError(f'Unsupported input {typ} type of time.')
+ raise ValueError(f"Unsupported input {typ} type of time.")
- if fmt != None:
+ if fmt is not None:
t = timestamp2str(time.time(), fmt=fmt)
return t
-def timestamp2datetime(ts: Union[int, float]):
+def timestamp2datetime(ts: int | float):
return datetime.fromtimestamp(ts)
-def timestamp2time(ts: Union[int, float]):
+def timestamp2time(ts: int | float):
return time.localtime(ts)
-def timestamp2str(ts: Union[int, float], fmt: str):
+def timestamp2str(ts: int | float, fmt: str):
return time2str(timestamp2time(ts), fmt)
def time2datetime(t: struct_time):
if not isinstance(t, struct_time):
- raise TypeError(f'Input type: {type(t)} error.')
+ raise TypeError(f"Input type: {type(t)} error.")
return datetime(*t[0:6])
def time2timestamp(t: struct_time):
if not isinstance(t, struct_time):
- raise TypeError(f'Input type: {type(t)} error.')
+ raise TypeError(f"Input type: {type(t)} error.")
return time.mktime(t)
def time2str(t: struct_time, fmt: str):
if not isinstance(t, struct_time):
- raise TypeError(f'Input type: {type(t)} error.')
+ raise TypeError(f"Input type: {type(t)} error.")
return time.strftime(fmt, t)
def datetime2time(dt: datetime):
if not isinstance(dt, datetime):
- raise TypeError(f'Input type: {type(dt)} error.')
+ raise TypeError(f"Input type: {type(dt)} error.")
return dt.timetuple()
def datetime2timestamp(dt: datetime):
if not isinstance(dt, datetime):
- raise TypeError(f'Input type: {type(dt)} error.')
+ raise TypeError(f"Input type: {type(dt)} error.")
return dt.timestamp()
def datetime2str(dt: datetime, fmt: str):
if not isinstance(dt, datetime):
- raise TypeError(f'Input type: {type(dt)} error.')
+ raise TypeError(f"Input type: {type(dt)} error.")
return dt.strftime(fmt)
def str2time(s: str, fmt: str):
if not isinstance(s, str):
- raise TypeError(f'Input type: {type(s)} error.')
+ raise TypeError(f"Input type: {type(s)} error.")
return time.strptime(s, fmt)
def str2datetime(s: str, fmt: str):
if not isinstance(s, str):
- raise TypeError(f'Input type: {type(s)} error.')
+ raise TypeError(f"Input type: {type(s)} error.")
return datetime.strptime(s, fmt)
def str2timestamp(s: str, fmt: str):
if not isinstance(s, str):
- raise TypeError(f'Input type: {type(s)} error.')
+ raise TypeError(f"Input type: {type(s)} error.")
return time2timestamp(str2time(s, fmt))
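
Sketch of the Timer context manager after this change (__enter__ now returns the instance, so the elapsed time is reachable on exit):

from capybara.utils import Timer, now

with Timer(desc="toy workload", verbose=True) as t:
    _ = sum(range(1_000_000))
print(t.dt, t.mean)                # dt is set in __exit__, mean aggregates recorded laps

print(now(fmt="%Y-%m-%d"))         # formatted helper built on timestamp2str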
diff --git a/capybara/utils/utils.py b/capybara/utils/utils.py
index e2b3984..b2782a0 100644
--- a/capybara/utils/utils.py
+++ b/capybara/utils/utils.py
@@ -1,7 +1,8 @@
-import os
import re
+from collections.abc import Generator, Iterable
+from pathlib import Path
from pprint import pprint
-from typing import Any, Generator, Iterable, List, Union
+from typing import Any, cast
import requests
from bs4 import BeautifulSoup
@@ -10,15 +11,16 @@
from ..enums import COLORSTR, FORMATSTR
__all__ = [
- 'make_batch', 'colorstr', 'pprint',
- 'download_from_google',
+ "colorstr",
+ "download_from_google",
+ "make_batch",
+ "pprint",
]
def make_batch(
- data: Union[Iterable, Generator],
- batch_size: int
-) -> Generator[List, None, None]:
+ data: Iterable | Generator, batch_size: int
+) -> Generator[list, None, None]:
"""
This function is used to make data to batched data.
@@ -41,8 +43,8 @@ def make_batch(
def colorstr(
obj: Any,
- color: Union[COLORSTR, int, str] = COLORSTR.BLUE,
- fmt: Union[FORMATSTR, int, str] = FORMATSTR.BOLD
+ color: COLORSTR | int | str = COLORSTR.BLUE,
+ fmt: FORMATSTR | int | str = FORMATSTR.BOLD,
) -> str:
"""
This function is make colorful string for python.
@@ -66,11 +68,13 @@ def colorstr(
fmt = fmt.upper()
color_code = COLORSTR.obj_to_enum(color).value
format_code = FORMATSTR.obj_to_enum(fmt).value
- color_string = f'\033[{format_code};{color_code}m{obj}\033[0m'
+ color_string = f"\033[{format_code};{color_code}m{obj}\033[0m"
return color_string
-def download_from_google(file_id: str, file_name: str, target: str = "."):
+def download_from_google(
+ file_id: str, file_name: str, target: str | Path = "."
+) -> Path:
"""
Downloads a file from Google Drive, handling potential confirmation tokens for large files.
@@ -103,16 +107,13 @@ def download_from_google(file_id: str, file_name: str, target: str = "."):
target="./downloads"
)
"""
- # 第一次嘗試:docs.google.com/uc?export=download&id=檔案ID
+    # First attempt: docs.google.com/uc?export=download&id=<file id>
base_url = "https://docs.google.com/uc"
session = requests.Session()
- params = {
- "export": "download",
- "id": file_id
- }
+ params = {"export": "download", "id": file_id}
response = session.get(base_url, params=params, stream=True)
- # 如果已經出現 Content-Disposition,代表直接拿到檔案
+    # If Content-Disposition is already present, the file itself was returned directly
if "content-disposition" not in response.headers:
# 先嘗試從 cookies 拿 token
token = None
@@ -121,37 +122,44 @@ def download_from_google(file_id: str, file_name: str, target: str = "."):
token = v
break
- # 如果 cookies 沒有,就從 HTML 解析
+        # If the cookies do not yield a token, fall back to parsing the HTML
if not token:
soup = BeautifulSoup(response.text, "html.parser")
- # 常見情況:HTML 裡面有一個 form#download-form
+            # Common case: the HTML contains a form#download-form
download_form = soup.find("form", {"id": "download-form"})
- if download_form and download_form.get("action"):
- # 將 action 裡的網址抓出來,可能是 drive.usercontent.google.com/download
- download_url = download_form["action"]
+ download_form_tag = cast(Any, download_form)
+ if download_form_tag and download_form_tag.get("action"):
+                # Extract the URL from the form action; it may point to drive.usercontent.google.com/download
+ download_url = str(download_form_tag["action"])
# 收集所有 hidden 欄位
- hidden_inputs = download_form.find_all(
- "input", {"type": "hidden"})
+ hidden_inputs = download_form_tag.find_all(
+ "input", {"type": "hidden"}
+ )
form_params = {}
for inp in hidden_inputs:
- if inp.get("name") and inp.get("value") is not None:
- form_params[inp["name"]] = inp["value"]
+ inp_tag = cast(Any, inp)
+ name = inp_tag.get("name")
+ value = inp_tag.get("value")
+ if name and value is not None:
+ form_params[str(name)] = str(value)
# 用這些參數去重新 GET
- # 注意:原本 action 可能只是相對路徑,這裡直接用完整網址
+                # Note: the original action may be a relative path, so use the full URL here
response = session.get(
- download_url, params=form_params, stream=True)
+ download_url, params=form_params, stream=True
+ )
else:
# 或者有些情況是直接在 HTML 裡 search confirm=xxx
- match = re.search(r'confirm=([0-9A-Za-z-_]+)', response.text)
+ match = re.search(r"confirm=([0-9A-Za-z-_]+)", response.text)
if match:
token = match.group(1)
# 帶上 confirm token 再重新請求 docs.google.com
params["confirm"] = token
- response = session.get(
- base_url, params=params, stream=True)
+ response = session.get(base_url, params=params, stream=True)
else:
- raise Exception("無法在回應中找到下載連結或確認參數,下載失敗。")
+ raise Exception(
+ "無法在回應中找到下載連結或確認參數, 下載失敗。"
+ )
else:
# 直接帶上 cookies 抓到的 token 再打一次
@@ -159,25 +167,30 @@ def download_from_google(file_id: str, file_name: str, target: str = "."):
response = session.get(base_url, params=params, stream=True)
# 確保下載目錄存在
- os.makedirs(target, exist_ok=True)
- file_path = os.path.join(target, file_name)
+ target_path = Path(target)
+ target_path.mkdir(parents=True, exist_ok=True)
+ file_path = target_path / file_name
- # 開始把檔案 chunk 寫到本地,附帶進度條
+    # Stream the file to disk in chunks, with a progress bar
try:
- total_size = int(response.headers.get('content-length', 0))
- with open(file_path, "wb") as f, tqdm(
- desc=file_name,
- total=total_size,
- unit="B",
- unit_scale=True,
- unit_divisor=1024,
- ) as bar:
+ total_size = int(response.headers.get("content-length", 0))
+ with (
+ open(file_path, "wb") as f,
+ tqdm(
+ desc=file_name,
+ total=total_size,
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar,
+ ):
for chunk in response.iter_content(chunk_size=32768):
if chunk:
f.write(chunk)
bar.update(len(chunk))
print(f"File successfully downloaded to: {file_path}")
+ return file_path
except Exception as e:
- raise Exception(f"File download failed: {e}")
+ raise RuntimeError(f"File download failed: {e}") from e
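
Usage sketch for the updated helpers (the Drive file id below is a placeholder; download_from_google now returns the written Path):

from capybara.utils import colorstr, download_from_google, make_batch

for batch in make_batch(range(10), batch_size=4):
    print(colorstr(batch, "yellow"))

saved = download_from_google(
    file_id="<drive-file-id>",          # placeholder id
    file_name="weights.onnx",
    target="./downloads",
)
print(saved)                            # e.g. downloads/weights.onnx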
diff --git a/capybara/vision/__init__.py b/capybara/vision/__init__.py
index 79ee22d..be3a5ef 100644
--- a/capybara/vision/__init__.py
+++ b/capybara/vision/__init__.py
@@ -1,7 +1,17 @@
-from .functionals import *
-from .geometric import *
-from .improc import *
-from .ipcam import *
-from .morphology import *
-from .videotools import *
-from .visualization import *
+from __future__ import annotations
+
+from . import (
+ functionals,
+ geometric,
+ improc,
+ morphology,
+ videotools,
+)
+
+__all__ = [
+ "functionals",
+ "geometric",
+ "improc",
+ "morphology",
+ "videotools",
+]
diff --git a/capybara/vision/functionals.py b/capybara/vision/functionals.py
index cc40437..fbdb81e 100644
--- a/capybara/vision/functionals.py
+++ b/capybara/vision/functionals.py
@@ -1,5 +1,5 @@
from bisect import bisect_left
-from typing import List, Optional, Tuple, Union
+from typing import Literal, overload
import cv2
import numpy as np
@@ -9,26 +9,31 @@
from .geometric import imresize
__all__ = [
- 'meanblur', 'gaussianblur', 'medianblur', 'imcvtcolor', 'imadjust', 'pad',
- 'imcropbox', 'imcropboxes', 'imbinarize', 'centercrop', 'imresize_and_pad_if_need'
+ "centercrop",
+ "gaussianblur",
+ "imadjust",
+ "imbinarize",
+ "imcropbox",
+ "imcropboxes",
+ "imcvtcolor",
+ "imresize_and_pad_if_need",
+ "meanblur",
+ "medianblur",
+ "pad",
]
-_Ksize = Union[int, Tuple[int, int], np.ndarray]
+_Ksize = int | tuple[int, int] | np.ndarray
-def _check_ksize(ksize: _Ksize) -> Tuple[int, int]:
+def _check_ksize(ksize: _Ksize) -> tuple[int, int]:
if isinstance(ksize, int):
- ksize = (ksize, ksize)
- elif isinstance(ksize, tuple) and len(ksize) == 2 \
- and all(isinstance(val, int) for val in ksize):
- ksize = tuple(ksize)
- elif isinstance(ksize, np.ndarray) and ksize.ndim == 0:
- ksize = (int(ksize), int(ksize))
- else:
- raise TypeError(f'The input ksize = {ksize} is invalid.')
-
- ksize = tuple(int(val) for val in ksize)
- return ksize
+ return ksize, ksize
+ if isinstance(ksize, tuple) and len(ksize) == 2:
+ return int(ksize[0]), int(ksize[1])
+ if isinstance(ksize, np.ndarray) and ksize.ndim == 0:
+ value = int(ksize)
+ return value, value
+ raise TypeError(f"The input ksize = {ksize} is invalid.")
def meanblur(img: np.ndarray, ksize: _Ksize = 3, **kwargs) -> np.ndarray:
@@ -52,7 +57,9 @@ def meanblur(img: np.ndarray, ksize: _Ksize = 3, **kwargs) -> np.ndarray:
return cv2.blur(img, ksize=ksize, **kwargs)
-def gaussianblur(img: np.ndarray, ksize: _Ksize = 3, sigmaX: int = 0, **kwargs) -> np.ndarray:
+def gaussianblur(
+ img: np.ndarray, ksize: _Ksize = 3, sigma_x: int = 0, **kwargs
+) -> np.ndarray:
"""
Apply Gaussian blur to the input image.
@@ -65,7 +72,7 @@ def gaussianblur(img: np.ndarray, ksize: _Ksize = 3, sigmaX: int = 0, **kwargs)
size will be used. If a tuple (k_height, k_width) is provided, a
rectangular kernel of the specified size will be used.
Defaults to 3.
- sigmaX (int, optional):
+ sigma_x (int, optional):
The standard deviation in the X direction for Gaussian kernel.
Defaults to 0.
@@ -73,7 +80,8 @@ def gaussianblur(img: np.ndarray, ksize: _Ksize = 3, sigmaX: int = 0, **kwargs)
np.ndarray: The blurred image.
"""
ksize = _check_ksize(ksize)
- return cv2.GaussianBlur(img, ksize=ksize, sigmaX=sigmaX, **kwargs)
+ sigma_x = int(kwargs.pop("sigmaX", sigma_x))
+ return cv2.GaussianBlur(img, ksize=ksize, sigmaX=sigma_x, **kwargs)
def medianblur(img: np.ndarray, ksize: int = 3, **kwargs) -> np.ndarray:
@@ -94,7 +102,7 @@ def medianblur(img: np.ndarray, ksize: int = 3, **kwargs) -> np.ndarray:
return cv2.medianBlur(img, ksize=ksize, **kwargs)
-def imcvtcolor(img: np.ndarray, cvt_mode: Union[int, str]) -> np.ndarray:
+def imcvtcolor(img: np.ndarray, cvt_mode: int | str) -> np.ndarray:
"""
Convert the color space of the input image.
@@ -113,18 +121,24 @@ def imcvtcolor(img: np.ndarray, cvt_mode: Union[int, str]) -> np.ndarray:
Raises:
ValueError: If the input cvt_mode is invalid or not supported.
"""
- code = getattr(cv2, f'COLOR_{cvt_mode}', None)
- if code is None:
- raise ValueError(f'Input cvt_mode: "{cvt_mode}" is invaild.')
+ if isinstance(cvt_mode, (int, np.integer)):
+ code = int(cvt_mode)
+ else:
+ mode = str(cvt_mode).upper()
+ if mode.startswith("COLOR_"):
+ mode = mode[len("COLOR_") :]
+ code = getattr(cv2, f"COLOR_{mode}", None)
+ if code is None:
+            raise ValueError(f'Input cvt_mode: "{cvt_mode}" is invalid.')
img = cv2.cvtColor(img.copy(), code)
return img
def imadjust(
img: np.ndarray,
- rng_out: Tuple[int, int] = (0, 255),
+ rng_out: tuple[int, int] = (0, 255),
gamma: float = 1.0,
- color_base: str = 'BGR'
+ color_base: str = "BGR",
) -> np.ndarray:
"""
Adjust the intensity of an image.
@@ -161,7 +175,7 @@ def imadjust(
"""
is_trans_hsv = False
if img.ndim == 3:
- img_hsv = imcvtcolor(img, f'{color_base}2HSV')
+ img_hsv = imcvtcolor(img, f"{color_base}2HSV")
v = img_hsv[..., 2]
is_trans_hsv = True
else:
@@ -171,7 +185,7 @@ def imadjust(
total = v.size
low_bound, upp_bound = total * 0.01, total * 0.99
- hist, _ = np.histogram(v.ravel(), 256, [0, 256])
+ hist, _ = np.histogram(v.ravel(), 256, (0, 256))
cdf = hist.cumsum()
rng_in = [bisect_left(cdf, low_bound), bisect_left(cdf, upp_bound)]
if (rng_in[0] == rng_in[1]) or (rng_in[1] == 0):
@@ -182,21 +196,21 @@ def imadjust(
dist_out = rng_out[1] - rng_out[0]
dst = np.clip((np.clip(v, rng_in[0], None) - rng_in[0]) / dist_in, 0, 1)
- dst = (dst ** gamma) * dist_out + rng_out[0]
- dst = np.clip(dst, rng_out[0], rng_out[1]).astype('uint8')
+ dst = (dst**gamma) * dist_out + rng_out[0]
+ dst = np.clip(dst, rng_out[0], rng_out[1]).astype("uint8")
if is_trans_hsv:
img_hsv[..., 2] = dst
- dst = imcvtcolor(img_hsv, f'HSV2{color_base}')
+ dst = imcvtcolor(img_hsv, f"HSV2{color_base}")
return dst
def pad(
img: np.ndarray,
- pad_size: Union[int, Tuple[int, int], Tuple[int, int, int, int]],
- pad_value: Optional[Union[int, Tuple[int, int, int]]] = 0,
- pad_mode: Union[str, int, BORDER] = BORDER.CONSTANT
+ pad_size: int | tuple[int, int] | tuple[int, int, int, int],
+ pad_value: int | tuple[int, ...] | None = 0,
+ pad_mode: str | int | BORDER = BORDER.CONSTANT,
) -> np.ndarray:
"""
Pad the input image with specified padding size and mode.
@@ -227,29 +241,62 @@ def pad(
np.ndarray: The padded image.
"""
if isinstance(pad_size, int):
- left = right = top = bottom = pad_size
- elif len(pad_size) == 2:
- top = bottom = pad_size[0]
- left = right = pad_size[1]
- elif len(pad_size) == 4:
- top, bottom, left, right = pad_size
+ left = right = top = bottom = int(pad_size)
+ elif isinstance(pad_size, tuple) and len(pad_size) == 2:
+ top = bottom = int(pad_size[0])
+ left = right = int(pad_size[1])
+ elif isinstance(pad_size, tuple) and len(pad_size) == 4:
+ top, bottom, left, right = (int(v) for v in pad_size)
else:
raise ValueError(
- f'pad_size is not an int, a tuple with 2 ints, or a tuple with 4 ints.')
+            "pad_size must be an int, a tuple of 2 ints, or a tuple of 4 ints."
+ )
pad_mode = BORDER.obj_to_enum(pad_mode)
+
+ if pad_value is None:
+ pad_value = 0
+
if pad_mode == BORDER.CONSTANT:
- if img.ndim == 3 and isinstance(pad_value, int):
- pad_value = (pad_value, ) * img.shape[-1]
- cond1 = img.ndim == 3 and len(pad_value) == img.shape[-1]
- cond2 = img.ndim == 2 and isinstance(pad_value, int)
- if not (cond1 or cond2):
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = int(img.shape[-1])
+ else:
+ raise ValueError("img must be a 2D or 3D numpy image.")
+
+ if isinstance(pad_value, int):
+ if channels == 1:
+ pad_value = int(pad_value)
+ else:
+ pad_value = tuple(int(pad_value) for _ in range(channels))
+ elif isinstance(pad_value, tuple):
+ if channels == 1:
+ if len(pad_value) == 1:
+ pad_value = int(pad_value[0])
+ else:
+ raise ValueError(
+ "pad_value must be an int when padding a grayscale image."
+ )
+ elif len(pad_value) == channels:
+ pad_value = tuple(int(v) for v in pad_value)
+ else:
+ raise ValueError(
+ f"channel of image is {channels} but pad_value is {pad_value}."
+ )
+ else:
raise ValueError(
- f'channel of image is {img.shape[-1]} but length of fill is {len(pad_value)}.')
+ "pad_value must be an int or a tuple matching channel count."
+ )
img = cv2.copyMakeBorder(
- src=img, top=top, bottom=bottom, left=left, right=right,
- borderType=pad_mode, value=pad_value
+ src=img,
+ top=top,
+ bottom=bottom,
+ left=left,
+ right=right,
+ borderType=pad_mode.value,
+ value=pad_value,
)
return img
@@ -257,7 +304,7 @@ def pad(
def imcropbox(
img: np.ndarray,
- box: Union[Box, np.ndarray],
+ box: Box | np.ndarray,
use_pad: bool = False,
) -> np.ndarray:
"""
@@ -302,7 +349,8 @@ def imcropbox(
x1, y1, x2, y2 = box.astype(int)
else:
raise TypeError(
- f'Input box is not of type Box or NumPy array with 4 elements.')
+ "Input box is not of type Box or NumPy array with 4 elements."
+ )
im_h, im_w = img.shape[:2]
crop_x1 = max(0, x1)
@@ -325,9 +373,9 @@ def imcropbox(
def imcropboxes(
img: np.ndarray,
- boxes: Union[Boxes, np.ndarray],
+ boxes: Boxes | np.ndarray,
use_pad: bool = False,
-) -> List[np.ndarray]:
+) -> list[np.ndarray]:
"""
Crop the input image using multiple boxes.
"""
@@ -335,9 +383,7 @@ def imcropboxes(
def imbinarize(
- img: np.ndarray,
- threth: int = cv2.THRESH_BINARY,
- color_base: str = 'BGR'
+ img: np.ndarray, threth: int = cv2.THRESH_BINARY, color_base: str = "BGR"
) -> np.ndarray:
"""
Function for image binarize.
@@ -367,7 +413,7 @@ def imbinarize(
Binary image.
"""
if img.ndim == 3:
- img = imcvtcolor(img, f'{color_base}2GRAY')
+ img = imcvtcolor(img, f"{color_base}2GRAY")
_, dst = cv2.threshold(img, 0, 255, type=threth + cv2.THRESH_OTSU)
return dst
@@ -383,22 +429,49 @@ def centercrop(img: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: The cropped image.
"""
- box = Box([0, 0, img.shape[1], img.shape[0]]).square()
+ box = Box((0, 0, int(img.shape[1]), int(img.shape[0]))).square()
return imcropbox(img, box)
+@overload
+def imresize_and_pad_if_need(
+ img: np.ndarray,
+ max_h: int,
+ max_w: int,
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ pad_value: int | tuple[int, int, int] | None = 0,
+ pad_mode: str | int | BORDER = BORDER.CONSTANT,
+ return_scale: Literal[False] = False,
+) -> np.ndarray: ...
+
+
+@overload
def imresize_and_pad_if_need(
img: np.ndarray,
max_h: int,
max_w: int,
- interpolation: Union[str, int, INTER] = INTER.BILINEAR,
- pad_value: Optional[Union[int, Tuple[int, int, int]]] = 0,
- pad_mode: Union[str, int, BORDER] = BORDER.CONSTANT,
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ pad_value: int | tuple[int, int, int] | None = 0,
+ pad_mode: str | int | BORDER = BORDER.CONSTANT,
+ return_scale: Literal[True] = True,
+) -> tuple[np.ndarray, float]: ...
+
+
+def imresize_and_pad_if_need(
+ img: np.ndarray,
+ max_h: int,
+ max_w: int,
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ pad_value: int | tuple[int, int, int] | None = 0,
+ pad_mode: str | int | BORDER = BORDER.CONSTANT,
return_scale: bool = False,
-):
+) -> np.ndarray | tuple[np.ndarray, float]:
raw_h, raw_w = img.shape[:2]
scale = min(max_h / raw_h, max_w / raw_w)
- dst_h, dst_w = min(int(raw_h * scale), max_h), min(int(raw_w * scale), max_w)
+ dst_h, dst_w = (
+ min(int(raw_h * scale), max_h),
+ min(int(raw_w * scale), max_w),
+ )
img = imresize(
img,
(dst_h, dst_w),
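
Illustrative calls against the revised functional API (sigma_x replaces sigmaX, which is still accepted via the back-compat shim; return_scale=True selects the tuple overload):

import numpy as np

from capybara.vision.functionals import gaussianblur, imresize_and_pad_if_need, pad

img = np.zeros((480, 640, 3), dtype=np.uint8)

blurred = gaussianblur(img, ksize=5, sigma_x=1)
padded = pad(img, pad_size=(10, 20), pad_value=(0, 0, 0))
resized, scale = imresize_and_pad_if_need(img, max_h=256, max_w=256, return_scale=True)
print(blurred.shape, padded.shape, resized.shape, scale)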
diff --git a/capybara/vision/geometric.py b/capybara/vision/geometric.py
index 0cb072b..88f9bd5 100644
--- a/capybara/vision/geometric.py
+++ b/capybara/vision/geometric.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import Literal, overload
import cv2
import numpy as np
@@ -7,17 +7,38 @@
from ..structures import Polygon, Polygons, order_points_clockwise
__all__ = [
- 'imresize', 'imrotate90', 'imrotate', 'imwarp_quadrangle',
- 'imwarp_quadrangles'
+ "imresize",
+ "imrotate",
+ "imrotate90",
+ "imwarp_quadrangle",
+ "imwarp_quadrangles",
]
+@overload
def imresize(
img: np.ndarray,
- size: Tuple[int, int],
- interpolation: Union[str, int, INTER] = INTER.BILINEAR,
+ size: tuple[int | None, int | None],
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ return_scale: Literal[False] = False,
+) -> np.ndarray: ...
+
+
+@overload
+def imresize(
+ img: np.ndarray,
+ size: tuple[int | None, int | None],
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ return_scale: Literal[True] = True,
+) -> tuple[np.ndarray, float, float]: ...
+
+
+def imresize(
+ img: np.ndarray,
+ size: tuple[int | None, int | None],
+ interpolation: str | int | INTER = INTER.BILINEAR,
return_scale: bool = False,
-) -> Union[np.ndarray, Tuple[np.ndarray, float, float]]:
+) -> np.ndarray | tuple[np.ndarray, float, float]:
"""
This function is used to resize image.
@@ -50,10 +71,13 @@ def imresize(
scale = h / raw_h
w = int(raw_w * scale + 0.5) # round to nearest integer
+ if h is None or w is None:
+ raise ValueError("`size` must provide at least one dimension.")
+
resized_img = cv2.resize(img, (w, h), interpolation=interpolation.value)
if return_scale:
- if 'scale' not in locals(): # calculate scale if not already done
+ if "scale" not in locals(): # calculate scale if not already done
w_scale = w / raw_w
h_scale = h / raw_h
else:
@@ -81,13 +105,13 @@ def imrotate(
img: np.ndarray,
angle: float,
scale: float = 1,
- interpolation: Union[str, int, INTER] = INTER.BILINEAR,
- bordertype: Union[str, int, BORDER] = BORDER.CONSTANT,
- bordervalue: Union[int, Tuple[int, int, int]] = None,
+ interpolation: str | int | INTER = INTER.BILINEAR,
+ bordertype: str | int | BORDER = BORDER.CONSTANT,
+ bordervalue: int | tuple[int, ...] | None = None,
expand: bool = True,
- center: Tuple[int, int] = None,
+ center: tuple[int, int] | None = None,
) -> np.ndarray:
- '''
+ """
Rotate the image by angle.
Args:
@@ -113,49 +137,76 @@ def imrotate(
Returns:
rotated img: rotated img.
- '''
+ """
bordertype = BORDER.obj_to_enum(bordertype)
- bordervalue = (bordervalue,) * \
- 3 if isinstance(bordervalue, int) else bordervalue
interpolation = INTER.obj_to_enum(interpolation)
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = int(img.shape[-1])
+ else:
+ raise ValueError("img must be a 2D or 3D numpy image.")
+
+ if bordervalue is None:
+ bordervalue = 0
+ elif isinstance(bordervalue, int):
+ bordervalue = (
+ int(bordervalue)
+ if channels == 1
+ else tuple(int(bordervalue) for _ in range(channels))
+ )
+ elif isinstance(bordervalue, tuple):
+ if channels == 1 and len(bordervalue) == 1:
+ bordervalue = int(bordervalue[0])
+ elif len(bordervalue) == channels:
+ bordervalue = tuple(int(v) for v in bordervalue)
+ else:
+ raise ValueError(
+ f"channel of image is {channels} but bordervalue is {bordervalue}."
+ )
+
h, w = img.shape[:2]
center = center or (w / 2, h / 2)
- M = cv2.getRotationMatrix2D(center, angle=angle, scale=scale)
+ matrix = cv2.getRotationMatrix2D(center, angle=angle, scale=scale)
if expand:
- cos = np.abs(M[0, 0])
- sin = np.abs(M[0, 1])
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
# compute the new bounding dimensions of the image
- nW = int((h * sin) + (w * cos)) + 1
- nH = int((h * cos) + (w * sin)) + 1
+ new_w = int((h * sin) + (w * cos)) + 1
+ new_h = int((h * cos) + (w * sin)) + 1
# adjust the rotation matrix to take into account translation
- M[0, 2] += (nW / 2) - center[0]
- M[1, 2] += (nH / 2) - center[1]
+ matrix[0, 2] += (new_w / 2) - center[0]
+ matrix[1, 2] += (new_h / 2) - center[1]
# perform the actual rotation and return the image
dst = cv2.warpAffine(
- img, M, (nW, nH),
- flags=interpolation,
- borderMode=bordertype,
- borderValue=bordervalue
+ img,
+ matrix,
+ (new_w, new_h),
+ flags=interpolation.value,
+ borderMode=bordertype.value,
+ borderValue=bordervalue,
)
else:
dst = cv2.warpAffine(
- img, M, (w, h),
- flags=interpolation,
- borderMode=bordertype,
- borderValue=bordervalue
+ img,
+ matrix,
+ (w, h),
+ flags=interpolation.value,
+ borderMode=bordertype.value,
+ borderValue=bordervalue,
)
- return dst.astype('uint8')
+ return dst
def imwarp_quadrangle(
img: np.ndarray,
- polygon: Union[Polygon, np.ndarray],
- dst_size: Tuple[int, int] = None,
+ polygon: Polygon | np.ndarray,
+ dst_size: tuple[int, int] | None = None,
do_order_points: bool = True,
) -> np.ndarray:
"""
@@ -188,12 +239,12 @@ def imwarp_quadrangle(
polygon = Polygon(polygon)
if not isinstance(polygon, Polygon):
- raise TypeError(
- f'Input type of polygon {type(polygon)} not supported.')
+ raise TypeError(f"Input type of polygon {type(polygon)} not supported.")
if len(polygon) != 4:
raise ValueError(
- f'Input polygon, which is not contain 4 points is invalid.')
+            "Input polygon must contain exactly 4 points."
+ )
if dst_size is None:
width, height = polygon.min_box_wh
@@ -209,10 +260,9 @@ def imwarp_quadrangle(
if do_order_points:
src_pts = order_points_clockwise(src_pts)
- dst_pts = np.array([[0, 0],
- [width, 0],
- [width, height],
- [0, height]], dtype="float32")
+ dst_pts = np.array(
+ [[0, 0], [width, 0], [width, height], [0, height]], dtype="float32"
+ )
matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
return cv2.warpPerspective(img, matrix, (width, height))
@@ -221,9 +271,9 @@ def imwarp_quadrangle(
def imwarp_quadrangles(
img: np.ndarray,
polygons: Polygons,
- dst_size: Tuple[int, int] = None,
+ dst_size: tuple[int, int] | None = None,
do_order_points: bool = True,
-) -> List[np.ndarray]:
+) -> list[np.ndarray]:
"""
Apply a 4-point perspective transform to an image using a given polygons.
@@ -246,11 +296,11 @@ def imwarp_quadrangles(
"""
if not isinstance(polygons, Polygons):
raise TypeError(
- f'Input type of polygons {type(polygons)} not supported.')
+ f"Input type of polygons {type(polygons)} not supported."
+ )
return [
imwarp_quadrangle(
- img, poly,
- dst_size=dst_size,
- do_order_points=do_order_points
- ) for poly in polygons
+ img, poly, dst_size=dst_size, do_order_points=do_order_points
+ )
+ for poly in polygons
]
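
Sketch of the geometric helpers after the refactor (size is (height, width); a None width is derived from the given height, and the forced uint8 cast in imrotate is gone):

import numpy as np

from capybara.vision.geometric import imresize, imrotate

img = np.zeros((100, 200, 3), dtype=np.uint8)

half, w_scale, h_scale = imresize(img, (50, None), return_scale=True)
rotated = imrotate(img, angle=30, bordervalue=(255, 255, 255))
print(half.shape, w_scale, h_scale, rotated.shape)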
diff --git a/capybara/vision/improc.py b/capybara/vision/improc.py
index 39e59db..51e66c1 100644
--- a/capybara/vision/improc.py
+++ b/capybara/vision/improc.py
@@ -1,6 +1,9 @@
+import os
+import tempfile
import warnings
+from contextlib import suppress
from pathlib import Path
-from typing import Any, List, Union
+from typing import Any, cast
import cv2
import numpy as np
@@ -15,27 +18,27 @@
from .geometric import imrotate90
__all__ = [
- "imread",
- "imwrite",
- "imencode",
- "imdecode",
- "img_to_b64",
- "img_to_b64str",
"b64_to_img",
- "b64str_to_img",
"b64_to_npy",
+ "b64str_to_img",
"b64str_to_npy",
+ "get_orientation_code",
+ "imdecode",
+ "imencode",
+ "img_to_b64",
+ "img_to_b64str",
+ "imread",
+ "imwrite",
+ "is_numpy_img",
+ "jpgdecode",
+ "jpgencode",
+ "jpgread",
"npy_to_b64",
"npy_to_b64str",
"npyread",
"pdf2imgs",
- "jpgencode",
- "jpgdecode",
- "jpgread",
- "pngencode",
"pngdecode",
- "is_numpy_img",
- "get_orientation_code",
+ "pngencode",
]
jpeg = TurboJPEG()
@@ -45,47 +48,52 @@ def is_numpy_img(x: Any) -> bool:
"""
x == ndarray (H x W x C)
"""
- return isinstance(x, np.ndarray) and (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]))
+ return isinstance(x, np.ndarray) and (
+ x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])
+ )
-def get_orientation_code(stream: Union[str, Path, bytes]):
- code = None
+def get_orientation_code(stream: str | Path | bytes):
try:
exif_dict = piexif.load(stream)
- if piexif.ImageIFD.Orientation in exif_dict["0th"]:
- orientation = exif_dict["0th"][piexif.ImageIFD.Orientation]
- if orientation == 3:
- code = ROTATE.ROTATE_180
- elif orientation == 6:
- code = ROTATE.ROTATE_90
- elif orientation == 8:
- code = ROTATE.ROTATE_270
- finally:
- return code
-
-
-def jpgencode(img: np.ndarray, quality: int = 90) -> Union[bytes, None]:
+ except Exception:
+ return None
+
+ orientation = exif_dict.get("0th", {}).get(piexif.ImageIFD.Orientation)
+ if orientation == 3:
+ return ROTATE.ROTATE_180
+ if orientation == 6:
+ return ROTATE.ROTATE_90
+ if orientation == 8:
+ return ROTATE.ROTATE_270
+ return None
+
+
+def jpgencode(img: np.ndarray, quality: int = 90) -> bytes | None:
byte_ = None
if is_numpy_img(img):
- try:
- byte_ = jpeg.encode(img, quality=quality)
- except Exception as _:
- pass
+ with suppress(Exception):
+ encoded = jpeg.encode(img, quality=quality)
+ if isinstance(encoded, tuple):
+ encoded = encoded[0]
+ byte_ = cast(bytes, encoded)
return byte_
-def jpgdecode(byte_: bytes) -> Union[np.ndarray, None]:
+def jpgdecode(byte_: bytes) -> np.ndarray | None:
try:
bgr_array = jpeg.decode(byte_)
code = get_orientation_code(byte_)
- bgr_array = imrotate90(bgr_array, code) if code is not None else bgr_array
+ bgr_array = (
+ imrotate90(bgr_array, code) if code is not None else bgr_array
+ )
except Exception as _:
bgr_array = None
return bgr_array
-def jpgread(img_file: Union[str, Path]) -> Union[np.ndarray, None]:
+def jpgread(img_file: str | Path) -> np.ndarray | None:
with open(str(img_file), "rb") as f:
binary_img = f.read()
bgr_array = jpgdecode(binary_img)
@@ -93,17 +101,19 @@ def jpgread(img_file: Union[str, Path]) -> Union[np.ndarray, None]:
return bgr_array
-def pngencode(img: np.ndarray, compression: int = 1) -> Union[bytes, None]:
+def pngencode(img: np.ndarray, compression: int = 1) -> bytes | None:
byte_ = None
if is_numpy_img(img):
- try:
- byte_ = cv2.imencode(".png", img, params=[int(cv2.IMWRITE_PNG_COMPRESSION), compression])[1].tobytes()
- except Exception as _:
- pass
+ with suppress(Exception):
+ byte_ = cv2.imencode(
+ ".png",
+ img,
+ params=[int(cv2.IMWRITE_PNG_COMPRESSION), compression],
+ )[1].tobytes()
return byte_
-def pngdecode(byte_: bytes) -> Union[np.ndarray, None]:
+def pngdecode(byte_: bytes) -> np.ndarray | None:
try:
enc = np.frombuffer(byte_, "uint8")
img = cv2.imdecode(enc, cv2.IMREAD_COLOR)
@@ -112,14 +122,25 @@ def pngdecode(byte_: bytes) -> Union[np.ndarray, None]:
return img
-def imencode(img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG) -> Union[bytes, None]:
- IMGTYP = IMGTYP.obj_to_enum(IMGTYP)
- encode_fn = jpgencode if IMGTYP == IMGTYP.JPEG else pngencode
- byte_ = encode_fn(img)
- return byte_
-
-
-def imdecode(byte_: bytes) -> Union[np.ndarray, None]:
+def imencode(
+ img: np.ndarray,
+ imgtyp: str | int | IMGTYP = IMGTYP.JPEG,
+ **kwargs: object,
+) -> bytes | None:
+ if "IMGTYP" in kwargs:
+ if imgtyp != IMGTYP.JPEG:
+ raise TypeError("imgtyp and IMGTYP were both provided.")
+ imgtyp = cast(str | int | IMGTYP, kwargs.pop("IMGTYP"))
+ if kwargs:
+ unexpected = ", ".join(sorted(kwargs))
+ raise TypeError(f"Unexpected keyword arguments: {unexpected}")
+
+ imgtyp_enum = IMGTYP.obj_to_enum(imgtyp)
+ encode_fn = jpgencode if imgtyp_enum == IMGTYP.JPEG else pngencode
+ return encode_fn(img)
+
+
+def imdecode(byte_: bytes) -> np.ndarray | None:
try:
img = jpgdecode(byte_)
img = pngdecode(byte_) if img is None else img
@@ -128,11 +149,26 @@ def imdecode(byte_: bytes) -> Union[np.ndarray, None]:
return img
-def img_to_b64(img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG) -> Union[bytes, None]:
- IMGTYP = IMGTYP.obj_to_enum(IMGTYP)
- encode_fn = jpgencode if IMGTYP == IMGTYP.JPEG else pngencode
+def img_to_b64(
+ img: np.ndarray,
+ imgtyp: str | int | IMGTYP = IMGTYP.JPEG,
+ **kwargs: object,
+) -> bytes | None:
+ if "IMGTYP" in kwargs:
+ if imgtyp != IMGTYP.JPEG:
+ raise TypeError("imgtyp and IMGTYP were both provided.")
+ imgtyp = cast(str | int | IMGTYP, kwargs.pop("IMGTYP"))
+ if kwargs:
+ unexpected = ", ".join(sorted(kwargs))
+ raise TypeError(f"Unexpected keyword arguments: {unexpected}")
+
+ imgtyp_enum = IMGTYP.obj_to_enum(imgtyp)
+ encode_fn = jpgencode if imgtyp_enum == IMGTYP.JPEG else pngencode
try:
- b64 = pybase64.b64encode(encode_fn(img))
+ encoded = encode_fn(img)
+ if encoded is None:
+ return None
+ b64 = pybase64.b64encode(encoded)
except Exception as _:
b64 = None
return b64
@@ -142,18 +178,23 @@ def npy_to_b64(x: np.ndarray, dtype="float32") -> bytes:
return pybase64.b64encode(x.astype(dtype).tobytes())
-def npy_to_b64str(x: np.ndarray, dtype="float32", string_encode: str = "utf-8") -> str:
+def npy_to_b64str(
+ x: np.ndarray, dtype="float32", string_encode: str = "utf-8"
+) -> str:
return pybase64.b64encode(x.astype(dtype).tobytes()).decode(string_encode)
def img_to_b64str(
- img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG, string_encode: str = "utf-8"
-) -> Union[str, None]:
- b64 = img_to_b64(img, IMGTYP)
+ img: np.ndarray,
+ imgtyp: str | int | IMGTYP = IMGTYP.JPEG,
+ string_encode: str = "utf-8",
+ **kwargs: object,
+) -> str | None:
+ b64 = img_to_b64(img, imgtyp, **kwargs)
return b64.decode(string_encode) if isinstance(b64, bytes) else None
-def b64_to_img(b64: bytes) -> Union[np.ndarray, None]:
+def b64_to_img(b64: bytes) -> np.ndarray | None:
try:
img = imdecode(pybase64.b64decode(b64))
except Exception as _:
@@ -161,9 +202,11 @@ def b64_to_img(b64: bytes) -> Union[np.ndarray, None]:
return img
-def b64str_to_img(b64str: Union[str, None], string_encode: str = "utf-8") -> Union[np.ndarray, None]:
+def b64str_to_img(
+ b64str: str | None, string_encode: str = "utf-8"
+) -> np.ndarray | None:
if b64str is None:
- warnings.warn("b64str is None.")
+ warnings.warn("b64str is None.", stacklevel=2)
return None
if not isinstance(b64str, str):
@@ -176,11 +219,15 @@ def b64_to_npy(x: bytes, dtype="float32") -> np.ndarray:
return np.frombuffer(pybase64.b64decode(x), dtype=dtype)
-def b64str_to_npy(x: bytes, dtype="float32", string_encode: str = "utf-8") -> np.ndarray:
- return np.frombuffer(pybase64.b64decode(x.encode(string_encode)), dtype=dtype)
+def b64str_to_npy(
+ x: str, dtype="float32", string_encode: str = "utf-8"
+) -> np.ndarray:
+ return np.frombuffer(
+ pybase64.b64decode(x.encode(string_encode)), dtype=dtype
+ )
-def npyread(path: Union[str, Path]) -> Union[np.ndarray, None]:
+def npyread(path: str | Path) -> np.ndarray | None:
try:
with open(str(path), "rb") as f:
img = np.load(f)
@@ -189,7 +236,9 @@ def npyread(path: Union[str, Path]) -> Union[np.ndarray, None]:
return img
-def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = False) -> Union[np.ndarray, None]:
+def imread(
+ path: str | Path, color_base: str = "BGR", verbose: bool = False
+) -> np.ndarray | None:
"""
This function reads an image from a given file path and converts its color
base if necessary.
@@ -215,8 +264,12 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals
if not Path(path).exists():
raise FileExistsError(f"{path} can not found.")
+ color_base = color_base.upper()
+
if Path(path).suffix.lower() == ".heic":
- heif_file = pillow_heif.open_heif(str(path), convert_hdr_to_8bit=True, bgr_mode=True)
+ heif_file = pillow_heif.open_heif(
+ str(path), convert_hdr_to_8bit=True, bgr_mode=True
+ )
img = np.asarray(heif_file)
else:
img = jpgread(path)
@@ -224,7 +277,7 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals
if img is None:
if verbose:
- warnings.warn("Got a None type image.")
+ warnings.warn("Got a None type image.", stacklevel=2)
return
if color_base != "BGR":
@@ -235,7 +288,7 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals
def imwrite(
img: np.ndarray,
- path: Union[str, Path] = None,
+ path: str | Path | None = None,
color_base: str = "BGR",
suffix: str = ".jpg",
) -> bool:
@@ -260,10 +313,15 @@ def imwrite(
color_base = color_base.upper()
if color_base != "BGR":
img = imcvtcolor(img, cvt_mode=f"{color_base}2BGR")
- return cv2.imwrite(str(path) if path else f"tmp{suffix}", img)
+ if path is None:
+ fd, target = tempfile.mkstemp(prefix="capybara_", suffix=suffix)
+ os.close(fd)
+ else:
+ target = str(path)
+ return bool(cv2.imwrite(target, img))
-def pdf2imgs(stream: Union[str, Path, bytes]) -> Union[List[np.ndarray], None]:
+def pdf2imgs(stream: str | Path | bytes) -> list[np.ndarray] | None:
"""
Function for converting a PDF document to numpy images.
@@ -278,6 +336,8 @@ def pdf2imgs(stream: Union[str, Path, bytes]) -> Union[List[np.ndarray], None]:
pil_imgs = convert_from_bytes(stream)
else:
pil_imgs = convert_from_path(stream)
- return [imcvtcolor(np.array(img), cvt_mode="RGB2BGR") for img in pil_imgs]
+ return [
+ imcvtcolor(np.array(img), cvt_mode="RGB2BGR") for img in pil_imgs
+ ]
except Exception as _:
return
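
A minimal sketch of the keyword shim above: `imgtyp` is the preferred spelling, while the legacy uppercase `IMGTYP` keyword is still routed through, and passing both at once raises. The import paths and the `IMGTYP.PNG` member are assumptions for illustration.

import numpy as np

from capybara.enums import IMGTYP
from capybara.vision.improc import imdecode, imencode

img = np.zeros((32, 32, 3), dtype=np.uint8)

new_style = imencode(img, imgtyp=IMGTYP.PNG)   # preferred keyword
old_style = imencode(img, IMGTYP=IMGTYP.PNG)   # legacy keyword still accepted
assert imdecode(new_style).shape == (32, 32, 3)

try:
    imencode(img, imgtyp=IMGTYP.PNG, IMGTYP=IMGTYP.JPEG)  # both at once -> error
except TypeError:
    pass
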
diff --git a/capybara/vision/ipcam/__init__.py b/capybara/vision/ipcam/__init__.py
index 5bda3fc..e27b3bc 100644
--- a/capybara/vision/ipcam/__init__.py
+++ b/capybara/vision/ipcam/__init__.py
@@ -1,2 +1,5 @@
-from .app import *
-from .camera import *
+from __future__ import annotations
+
+from . import app, camera
+
+__all__ = ["app", "camera"]
diff --git a/capybara/vision/ipcam/app.py b/capybara/vision/ipcam/app.py
index b8f631a..ffbaff5 100644
--- a/capybara/vision/ipcam/app.py
+++ b/capybara/vision/ipcam/app.py
@@ -1,50 +1,60 @@
-from flask import Flask, Response, render_template_string
+from flask import Flask, Response, render_template_string # type: ignore
from ...utils import get_curdir
from ..improc import jpgencode
from .camera import IpcamCapture
-__all__ = ['WebDemo']
+__all__ = ["WebDemo"]
-def gen(cap, pipelines=[]):
+def gen(cap, pipelines=None):
+ if pipelines is None:
+ pipelines = []
while True:
frame = cap.get_frame()
for f in pipelines:
frame = f(frame)
frame_bytes = jpgencode(frame)
+ if frame_bytes is None:
+ continue
yield (
- b'--frame\r\n'
- b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n\r\n'
+ b"--frame\r\n"
+ b"Content-Type: image/jpeg\r\n\r\n" + frame_bytes + b"\r\n\r\n"
)
class WebDemo:
-
def __init__(
self,
camera_ip: str,
- color_base: str = 'BGR',
- route: str = '/',
- pipelines: list = [],
+ color_base: str = "BGR",
+ route: str = "/",
+ pipelines: list | None = None,
):
+ if pipelines is None:
+ pipelines = []
app = Flask(__name__)
- @ app.route(route)
+ @app.route(route)
def _index():
- with open(str(get_curdir(__file__) / 'video_streaming.html'), 'r', encoding='utf-8') as file:
+ with open(
+ str(get_curdir(__file__) / "video_streaming.html"),
+ encoding="utf-8",
+ ) as file:
html_content = file.read()
return render_template_string(html_content)
- @ app.route('/video_feed')
+ @app.route("/video_feed")
def _video_feed():
return Response(
- gen(cap=IpcamCapture(camera_ip, color_base), pipelines=pipelines),
- mimetype='multipart/x-mixed-replace; boundary=frame'
+ gen(
+ cap=IpcamCapture(camera_ip, color_base), pipelines=pipelines
+ ),
+ mimetype="multipart/x-mixed-replace; boundary=frame",
)
self.app = app
- def run(self, host='0.0.0.0', port=5001, debug=False, threaded=True):
+ def run(self, host="0.0.0.0", port=5001, debug=False, threaded=True):
self.app.run(host=host, port=port, debug=debug, threaded=threaded)
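
A rough usage sketch for the class above, assuming the optional flask dependency is installed; the camera URL is a placeholder and `passthrough` is a hypothetical per-frame hook:

from capybara.vision.ipcam.app import WebDemo


def passthrough(frame):
    # Any callable that takes a frame and returns a frame can go in `pipelines`.
    return frame


demo = WebDemo(
    camera_ip="rtsp://192.0.2.10/stream",  # placeholder URL
    pipelines=[passthrough],               # or omit to stream frames unchanged
)
# demo.run(port=5001)  # blocking; left commented so the sketch stays inert
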
diff --git a/capybara/vision/ipcam/camera.py b/capybara/vision/ipcam/camera.py
index 4868325..1d1ff60 100644
--- a/capybara/vision/ipcam/camera.py
+++ b/capybara/vision/ipcam/camera.py
@@ -5,12 +5,11 @@
from ..functionals import imcvtcolor
-__all__ = ['IpcamCapture']
+__all__ = ["IpcamCapture"]
class IpcamCapture:
-
- def __init__(self, url=0, color_base='BGR'):
+ def __init__(self, url: int | str = 0, color_base: str = "BGR") -> None:
"""
Initializes the IpcamCapture class.
@@ -43,7 +42,7 @@ def __init__(self, url=0, color_base='BGR'):
self._lock = Lock()
if self._h == 0 or self._w == 0:
- raise ValueError(f'The image size is not supported.')
+ raise ValueError("The image size is not supported.")
Thread(target=self._queryframe, daemon=True).start()
@@ -52,8 +51,8 @@ def _queryframe(self):
ret, frame = self._capture.read()
if not ret:
break # Stop the loop if the video stream has ended or is unreadable
- if self.color_base != 'BGR':
- frame = imcvtcolor(frame, cvt_mode=f'BGR2{self.color_base}')
+ if self.color_base != "BGR":
+ frame = imcvtcolor(frame, cvt_mode=f"BGR2{self.color_base}")
with self._lock:
self._frame = frame
@@ -66,4 +65,7 @@ def get_frame(self):
return frame
def __iter__(self):
- yield self.get_frame()
+ return self
+
+ def __next__(self):
+ return self.get_frame()
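
With `__iter__` returning the object and `__next__` delegating to `get_frame`, the capture behaves as an endless iterator, so consumers should bound it themselves. A hedged sketch, where the URL is a placeholder for a reachable stream and frames may be None until the first read arrives:

from itertools import islice

from capybara.vision.ipcam.camera import IpcamCapture

cap = IpcamCapture(url="rtsp://192.0.2.10/stream", color_base="RGB")
for frame in islice(cap, 10):  # take ten frames instead of looping forever
    if frame is not None:
        print(frame.shape)
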
diff --git a/capybara/vision/morphology.py b/capybara/vision/morphology.py
index f704235..3bd7b21 100644
--- a/capybara/vision/morphology.py
+++ b/capybara/vision/morphology.py
@@ -1,20 +1,23 @@
-from typing import Tuple, Union
-
import cv2
import numpy as np
from ..enums import MORPH
__all__ = [
- 'imerode', 'imdilate', 'imopen', 'imclose',
- 'imgradient', 'imtophat', 'imblackhat',
+ "imblackhat",
+ "imclose",
+ "imdilate",
+ "imerode",
+ "imgradient",
+ "imopen",
+ "imtophat",
]
def imerode(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Erosion:
@@ -33,9 +36,9 @@ def imerode(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
@@ -44,8 +47,8 @@ def imerode(
def imdilate(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Dilation:
@@ -64,9 +67,9 @@ def imdilate(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
@@ -75,8 +78,8 @@ def imdilate(
def imopen(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Opening:
@@ -93,19 +96,21 @@ def imopen(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
- return cv2.morphologyEx(img, cv2.MORPH_OPEN, cv2.getStructuringElement(kstruct, ksize))
+ return cv2.morphologyEx(
+ img, cv2.MORPH_OPEN, cv2.getStructuringElement(kstruct, ksize)
+ )
def imclose(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Closing:
@@ -123,19 +128,21 @@ def imclose(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
- return cv2.morphologyEx(img, cv2.MORPH_CLOSE, cv2.getStructuringElement(kstruct, ksize))
+ return cv2.morphologyEx(
+ img, cv2.MORPH_CLOSE, cv2.getStructuringElement(kstruct, ksize)
+ )
def imgradient(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Morphological Gradient:
@@ -151,19 +158,21 @@ def imgradient(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
- return cv2.morphologyEx(img, cv2.MORPH_GRADIENT, cv2.getStructuringElement(kstruct, ksize))
+ return cv2.morphologyEx(
+ img, cv2.MORPH_GRADIENT, cv2.getStructuringElement(kstruct, ksize)
+ )
def imtophat(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
) -> np.ndarray:
"""
Top Hat:
@@ -179,19 +188,21 @@ def imtophat(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
- return cv2.morphologyEx(img, cv2.MORPH_TOPHAT, cv2.getStructuringElement(kstruct, ksize))
+ return cv2.morphologyEx(
+ img, cv2.MORPH_TOPHAT, cv2.getStructuringElement(kstruct, ksize)
+ )
def imblackhat(
img: np.ndarray,
- ksize: Union[int, Tuple[int, int]] = (3, 3),
- kstruct: Union[str, int, MORPH] = MORPH.RECT
+ ksize: int | tuple[int, int] = (3, 3),
+ kstruct: str | int | MORPH = MORPH.RECT,
):
"""
Black Hat:
@@ -207,10 +218,12 @@ def imblackhat(
MORPH.ELLIPSE}, Defaults to MORPH.RECT.
"""
if isinstance(ksize, int):
- ksize = (ksize, ) * 2
+ ksize = (ksize, ksize)
elif not isinstance(ksize, tuple) or len(ksize) != 2:
- raise TypeError(f'Got inappropriate type or shape of size. {ksize}.')
+ raise TypeError(f"Got inappropriate type or shape of size. {ksize}.")
kstruct = MORPH.obj_to_enum(kstruct)
- return cv2.morphologyEx(img, cv2.MORPH_BLACKHAT, cv2.getStructuringElement(kstruct, ksize))
+ return cv2.morphologyEx(
+ img, cv2.MORPH_BLACKHAT, cv2.getStructuringElement(kstruct, ksize)
+ )
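
The integer and tuple spellings of `ksize` resolve to the same structuring element in every function above; a small self-check, assuming a binary test image:

import numpy as np

from capybara.enums import MORPH
from capybara.vision.morphology import imdilate, imerode

img = np.zeros((9, 9), dtype=np.uint8)
img[4, 4] = 255

assert np.array_equal(imdilate(img, ksize=3), imdilate(img, ksize=(3, 3)))
# Dilate-then-erode with the same kernel is what imclose wraps internally.
closed = imerode(imdilate(img, 3, MORPH.ELLIPSE), 3, MORPH.ELLIPSE)
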
diff --git a/capybara/vision/videotools/__init__.py b/capybara/vision/videotools/__init__.py
index 0872b94..42e9184 100644
--- a/capybara/vision/videotools/__init__.py
+++ b/capybara/vision/videotools/__init__.py
@@ -1,2 +1,6 @@
-from .video2frames import *
-from .video2frames_v2 import *
+from __future__ import annotations
+
+from .video2frames import video2frames
+from .video2frames_v2 import video2frames_v2
+
+__all__ = ["video2frames", "video2frames_v2"]
diff --git a/capybara/vision/videotools/video2frames.py b/capybara/vision/videotools/video2frames.py
index 9251464..cec8429 100644
--- a/capybara/vision/videotools/video2frames.py
+++ b/capybara/vision/videotools/video2frames.py
@@ -1,12 +1,13 @@
+import math
from pathlib import Path
-from typing import Any, List
+from typing import Any
import cv2
import numpy as np
-__all__ = ['video2frames', 'is_video_file']
+__all__ = ["is_video_file", "video2frames"]
-VIDEO_SUFFIX = ['.MOV', '.MP4', '.AVI', '.WEBM', '.3GP', '.MKV']
+VIDEO_SUFFIX = [".MOV", ".MP4", ".AVI", ".WEBM", ".3GP", ".MKV"]
def is_video_file(x: Any) -> bool:
@@ -17,9 +18,9 @@ def is_video_file(x: Any) -> bool:
def video2frames(
- video_path: str,
- frame_per_sec: int = None,
-) -> List[np.ndarray]:
+ video_path: str | Path,
+ frame_per_sec: int | None = None,
+) -> list[np.ndarray]:
"""
Extracts the frames from a video using ray
Inputs:
@@ -31,7 +32,7 @@ def video2frames(
frames (List[np.ndarray]): A list of frames.
"""
if not is_video_file(video_path):
- raise TypeError(f'The video_path {video_path} is inappropriate.')
+ raise TypeError(f"The video_path {video_path} is inappropriate.")
# get total_frames frames of video
cap = cv2.VideoCapture(str(video_path))
@@ -40,11 +41,19 @@ def video2frames(
return []
# Get the original FPS of the video
- original_fps = cap.get(cv2.CAP_PROP_FPS)
+ original_fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
# Calculate the interval for frame extraction
- interval = 1 if frame_per_sec is None \
- else int(original_fps / frame_per_sec)
+ if frame_per_sec is None:
+ interval = 1
+ else:
+ frame_per_sec_i = int(frame_per_sec)
+ if frame_per_sec_i <= 0:
+ raise ValueError("frame_per_sec must be > 0.")
+ if not math.isfinite(original_fps) or original_fps <= 0:
+ interval = 1
+ else:
+ interval = max(1, int(original_fps / frame_per_sec_i))
frames = []
index = 0
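
The guarded interval above reduces to a small piece of arithmetic: invalid or missing FPS metadata falls back to taking every frame, and the stride never drops below one. A standalone sketch of just that calculation (`_sampling_interval` is a hypothetical name, not part of the module):

import math


def _sampling_interval(original_fps: float, frame_per_sec: int | None) -> int:
    if frame_per_sec is None:
        return 1
    if frame_per_sec <= 0:
        raise ValueError("frame_per_sec must be > 0.")
    if not math.isfinite(original_fps) or original_fps <= 0:
        return 1
    return max(1, int(original_fps / frame_per_sec))


assert _sampling_interval(30.0, 10) == 3
assert _sampling_interval(0.0, 10) == 1   # broken FPS metadata -> every frame
assert _sampling_interval(24.0, 60) == 1  # stride never drops below one
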
diff --git a/capybara/vision/videotools/video2frames_v2.py b/capybara/vision/videotools/video2frames_v2.py
index d50ae84..07a8c1f 100644
--- a/capybara/vision/videotools/video2frames_v2.py
+++ b/capybara/vision/videotools/video2frames_v2.py
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain
-from typing import Any, List
+from typing import Any
import cv2
import numpy as np
@@ -15,7 +15,9 @@ def is_numpy_img(x: Any) -> bool:
"""
x == ndarray (H x W x C)
"""
- return isinstance(x, np.ndarray) and (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]))
+ return isinstance(x, np.ndarray) and (
+ x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])
+ )
def flatten_list(xs: list) -> list:
@@ -38,18 +40,26 @@ def flatten_list(xs: list) -> list:
def get_step_inds(start: int, end: int, num: int):
if num > (end - start):
raise ValueError("num is larger than the number of total frames.")
- return np.around(np.linspace(start=start, stop=end, num=num, endpoint=False)).astype(int).tolist()
+ return (
+ np.around(np.linspace(start=start, stop=end, num=num, endpoint=False))
+ .astype(int)
+ .tolist()
+ )
def _extract_frames(
- inds: List[int], video_path: str, max_size: int = 1920, color_base: str = "BGR", global_ind: int = 0
+ inds: list[int],
+ video_path: str | Any,
+ max_size: int = 1920,
+ color_base: str = "BGR",
+ global_ind: int = 0,
):
# check video path
if not is_video_file(video_path):
raise TypeError(f"The video_path {video_path} is inappropriate.")
# open cap
- cap = cv2.VideoCapture(video_path)
+ cap = cv2.VideoCapture(str(video_path))
# if start or end isn't specified lets assume 0
start = inds[0]
end = inds[-1] + 1
@@ -62,7 +72,9 @@ def _process_frame(frame):
if scale < 1:
frame = imresize(frame, (dst_h, dst_w))
elif scale > 1:
- frame = imresize(frame, (dst_h, dst_w), interpolation=cv2.INTER_AREA)
+ frame = imresize(
+ frame, (dst_h, dst_w), interpolation=cv2.INTER_LINEAR
+ )
if color_base.upper() != "BGR":
frame = imcvtcolor(frame, cvt_mode=f"BGR2{color_base}")
@@ -87,14 +99,14 @@ def _pickup_frame():
def video2frames_v2(
- video_path: str,
- frame_per_sec: int = None,
+ video_path: str | Any,
+ frame_per_sec: int | None = None,
start_sec: float = 0,
- end_sec: float = None,
+ end_sec: float | None = None,
n_threads: int = 8,
max_size: int = 1920,
color_base: str = "BGR",
-) -> List[np.ndarray]:
+) -> list[np.ndarray]:
"""
Extracts the frames from a video using ray
Inputs:
@@ -129,7 +141,13 @@ def video2frames_v2(
if total_frames == 0 or fps == 0:
return []
- frame_per_sec = fps if frame_per_sec is None else frame_per_sec
+ n_threads = int(n_threads)
+ if n_threads < 1:
+ raise ValueError("n_threads must be >= 1.")
+
+ frame_per_sec = fps if frame_per_sec is None else int(frame_per_sec)
+ if frame_per_sec <= 0:
+ raise ValueError("frame_per_sec must be > 0.")
total_sec = total_frames / fps
# get frame inds
end_sec = total_sec if end_sec is None or end_sec > total_sec else end_sec
@@ -140,19 +158,29 @@ def video2frames_v2(
start_frame = round(start_sec * fps)
end_frame = round(end_sec * fps)
num = round(total_sec * frame_per_sec)
+ if num <= 0:
+ return []
frame_inds = get_step_inds(start_frame, end_frame, num)
+ if not frame_inds: # pragma: no cover
+ return []
out_frames = []
- with ThreadPoolExecutor(max_workers=n_threads) as executor:
+ worker_count = min(n_threads, len(frame_inds))
+ chunk_size = max(1, (len(frame_inds) + worker_count - 1) // worker_count)
+ frame_inds_list = [
+ frame_inds[i : i + chunk_size]
+ for i in range(0, len(frame_inds), chunk_size)
+ ]
+
+ with ThreadPoolExecutor(max_workers=worker_count) as executor:
## -----start process---- ##
- # split the frames into chunk lists
- divide_size = round(len(frame_inds) / n_threads) + 1
- frame_inds_list = [frame_inds[i * divide_size : (i + 1) * divide_size] for i in range(n_threads)]
future_to_frames = {
- executor.submit(_extract_frames, inds, video_path, max_size, color_base, i): inds
+ executor.submit(
+ _extract_frames, inds, video_path, max_size, color_base, i
+ ): inds
for i, inds in enumerate(frame_inds_list)
}
- out_frames = [[] for _ in range(n_threads)]
+ out_frames = [[] for _ in range(len(frame_inds_list))]
for future in as_completed(future_to_frames):
frames = future_to_frames[future]
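
The replacement chunking above can be checked in isolation: every index lands in exactly one chunk and no empty chunk is ever submitted, which the old `round(len / n_threads) + 1` split did not guarantee. A short sketch:

frame_inds = list(range(10))
n_threads = 8

worker_count = min(n_threads, len(frame_inds))
chunk_size = max(1, (len(frame_inds) + worker_count - 1) // worker_count)
chunks = [
    frame_inds[i : i + chunk_size]
    for i in range(0, len(frame_inds), chunk_size)
]

assert sum(chunks, []) == frame_inds  # nothing lost, nothing duplicated
assert all(chunks)                    # no empty chunk handed to a worker
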
diff --git a/capybara/vision/visualization/__init__.py b/capybara/vision/visualization/__init__.py
index c6df1f7..72918ad 100644
--- a/capybara/vision/visualization/__init__.py
+++ b/capybara/vision/visualization/__init__.py
@@ -1 +1,8 @@
-from .draw import *
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+__all__ = ["draw", "utils"]
+
+if TYPE_CHECKING: # pragma: no cover
+ from . import draw, utils
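
Since the submodules are imported only for type checkers, runtime code should import them explicitly instead of relying on attribute access on the package. A short sketch, assuming the visualization extra (matplotlib/pillow) is installed:

from capybara.vision.visualization import draw, utils

colors = draw.generate_colors(3, scheme="triadic")
assert len(colors) == 3 and all(len(c) == 3 for c in colors)
assert utils.prepare_color(128) == (128, 128, 128)
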
diff --git a/capybara/vision/visualization/draw.py b/capybara/vision/visualization/draw.py
index aa5b256..8840fe2 100644
--- a/capybara/vision/visualization/draw.py
+++ b/capybara/vision/visualization/draw.py
@@ -2,31 +2,58 @@
import functools
import hashlib
from pathlib import Path
-from typing import List, Tuple, Union
import cv2
import matplotlib
+import matplotlib.colors as mpl_colors
import numpy as np
from PIL import Image, ImageDraw, ImageFont
-from ...structures import Polygons
from ...structures.boxes import _Box, _Boxes
from ...structures.keypoints import _Keypoints, _KeypointsList
from ...structures.polygons import _Polygon, _Polygons
-from ...utils import download_from_google, get_curdir
+from ...utils import get_curdir
from ..geometric import imresize
-from .utils import (_Color, _Colors, _Point, _Points, _Scale, _Scales,
- _Thickness, _Thicknesses, prepare_box, prepare_boxes,
- prepare_color, prepare_colors, prepare_img,
- prepare_keypoints, prepare_keypoints_list, prepare_point,
- prepare_points, prepare_polygon, prepare_polygons,
- prepare_scale, prepare_scales, prepare_thickness,
- prepare_thicknesses)
+from .utils import (
+ _Color,
+ _Colors,
+ _Point,
+ _Points,
+ _Scale,
+ _Scales,
+ _Thickness,
+ _Thicknesses,
+ prepare_box,
+ prepare_boxes,
+ prepare_color,
+ prepare_colors,
+ prepare_img,
+ prepare_keypoints,
+ prepare_keypoints_list,
+ prepare_point,
+ prepare_points,
+ prepare_polygon,
+ prepare_polygons,
+ prepare_scale,
+ prepare_scales,
+ prepare_thickness,
+ prepare_thicknesses,
+)
__all__ = [
- 'draw_box', 'draw_boxes', 'draw_polygon', 'draw_polygons', 'draw_text',
- 'generate_colors', 'draw_mask', 'draw_point', 'draw_points', 'draw_keypoints',
- 'draw_keypoints_list', 'draw_detection', 'draw_detections',
+ "draw_box",
+ "draw_boxes",
+ "draw_detection",
+ "draw_detections",
+ "draw_keypoints",
+ "draw_keypoints_list",
+ "draw_mask",
+ "draw_point",
+ "draw_points",
+ "draw_polygon",
+ "draw_polygons",
+ "draw_text",
+ "generate_colors",
]
DIR = get_curdir(__file__)
@@ -34,16 +61,29 @@
DEFAULT_FONT_PATH = DIR / "NotoSansMonoCJKtc-VF.ttf"
-if not (font_path := DIR / DEFAULT_FONT_PATH).exists():
- file_id = "1m6jvsBGKgQsxzpIoe4iEp_EqFxYXe7T1"
- download_from_google(file_id, font_path.name, DIR)
+def _load_font(
+ font_path: str | Path | None,
+ *,
+ size: int,
+) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
+ candidates: list[Path] = []
+ if font_path is not None:
+ candidates.append(Path(font_path))
+ candidates.append(DEFAULT_FONT_PATH)
+
+ for candidate in candidates:
+ try:
+ return ImageFont.truetype(str(candidate), size=int(size))
+ except Exception:
+ continue
+ return ImageFont.load_default()
def draw_box(
img: np.ndarray,
box: _Box,
color: _Color = (0, 255, 0),
- thickness: _Thickness = 2
+ thickness: _Thickness = 2,
) -> np.ndarray:
"""
Draws a bounding box on the image.
@@ -70,14 +110,16 @@ def draw_box(
h, w = img.shape[:2]
box = box.denormalize(w, h)
x1, y1, x2, y2 = box.numpy().astype(int).tolist()
- return cv2.rectangle(img, (x1, y1), (x2, y2), color=color, thickness=thickness)
+ return cv2.rectangle(
+ img, (x1, y1), (x2, y2), color=color, thickness=thickness
+ )
def draw_boxes(
img: np.ndarray,
boxes: _Boxes,
colors: _Colors = (0, 255, 0),
- thicknesses: _Thicknesses = 2
+ thicknesses: _Thicknesses = 2,
) -> np.ndarray:
"""
Draws multiple bounding boxes on the image.
@@ -101,7 +143,7 @@ def draw_boxes(
boxes = prepare_boxes(boxes)
colors = prepare_colors(colors, len(boxes))
thicknesses = prepare_thicknesses(thicknesses, len(boxes))
- for box, c, t in zip(boxes, colors, thicknesses):
+ for box, c, t in zip(boxes, colors, thicknesses, strict=True):
draw_box(img, box, color=c, thickness=t)
return img
@@ -112,7 +154,7 @@ def draw_polygon(
color: _Color = (0, 255, 0),
thickness: _Thickness = 2,
fillup=False,
- **kwargs
+ **kwargs,
):
"""
Draw a polygon on the input image.
@@ -148,8 +190,14 @@ def draw_polygon(
if fillup:
img = cv2.fillPoly(img, [polygon], color=color, **kwargs)
else:
- img = cv2.polylines(img, [polygon], isClosed=True,
- color=color, thickness=thickness, **kwargs)
+ img = cv2.polylines(
+ img,
+ [polygon],
+ isClosed=True,
+ color=color,
+ thickness=thickness,
+ **kwargs,
+ )
return img
@@ -160,7 +208,7 @@ def draw_polygons(
colors: _Colors = (0, 255, 0),
thicknesses: _Thicknesses = 2,
fillup=False,
- **kwargs
+ **kwargs,
):
"""
Draw polygons on the input image.
@@ -196,9 +244,10 @@ def draw_polygons(
polygons = prepare_polygons(polygons)
colors = prepare_colors(colors, len(polygons))
thicknesses = prepare_thicknesses(thicknesses, len(polygons))
- for polygon, c, t in zip(polygons, colors, thicknesses):
- draw_polygon(img, polygon, color=c, thickness=t,
- fillup=fillup, **kwargs)
+ for polygon, c, t in zip(polygons, colors, thicknesses, strict=True):
+ draw_polygon(
+ img, polygon, color=c, thickness=t, fillup=fillup, **kwargs
+ )
return img
@@ -208,8 +257,8 @@ def draw_text(
location: _Point,
color: _Color = (0, 0, 0),
text_size: int = 12,
- font_path: Union[str, Path] = None,
- **kwargs
+ font_path: str | Path | None = None,
+ **kwargs,
) -> np.ndarray:
"""
Draw specified text on the given image at the provided location.
@@ -236,24 +285,24 @@ def draw_text(
np.ndarray: Image with the text drawn on it.
"""
img = prepare_img(img)
- img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
- draw = ImageDraw.Draw(img)
- font_path = DEFAULT_FONT_PATH if font_path is None else font_path
- font = ImageFont.truetype(str(font_path), size=text_size)
+ color = prepare_color(color)
+ pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
- _, top, _, bottom = font.getbbox(text)
- _, offset = font.getmask2(text)
- text_height = bottom - top
+ draw = ImageDraw.Draw(pil_img)
+ font = _load_font(font_path, size=text_size)
- offset_y = int(0.5 * (font.size - text_height) - offset[1])
+ offset_y = 0
+ try:
+ _left, top, _right, _bottom = font.getbbox(text)
+ offset_y = -int(top)
+ except Exception:
+ offset_y = 0
loc = prepare_point(location)
loc = (loc[0], loc[1] + offset_y)
- kwargs.update({'fill': (color[2], color[1], color[0])})
+ kwargs.update({"fill": (color[2], color[1], color[0])})
draw.text(loc, text, font=font, **kwargs)
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
- return img
+ out = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
+ return out
def draw_line(
@@ -262,11 +311,11 @@ def draw_line(
pt2: _Point,
color: _Color = (0, 255, 0),
thickness: _Thickness = 1,
- style: str = 'dotted',
+ style: str = "dotted",
gap: int = 20,
inplace: bool = False,
):
- '''
+ """
Draw a line on the image.
Args:
@@ -294,24 +343,39 @@ def draw_line(
Returns:
np.ndarray: Image with the drawn line.
- '''
+ """
img = img.copy() if not inplace else img
img = prepare_img(img)
pt1 = prepare_point(pt1)
pt2 = prepare_point(pt2)
- dist = ((pt1[0] - pt2[0])**2 + (pt1[1] - pt2[1])**2)**.5
+ color = prepare_color(color)
+ thickness = prepare_thickness(thickness)
+ gap = int(gap)
+ if gap <= 0:
+ raise ValueError("gap must be > 0.")
+ dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5
+ if dist == 0:
+ cv2.circle(
+ img,
+ pt1,
+ radius=max(1, abs(thickness)),
+ color=color,
+ thickness=-1,
+ lineType=cv2.LINE_AA,
+ )
+ return img
pts = []
for i in np.arange(0, dist, gap):
r = i / dist
- x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5)
- y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5)
+ x = int((pt1[0] * (1 - r) + pt2[0] * r) + 0.5)
+ y = int((pt1[1] * (1 - r) + pt2[1] * r) + 0.5)
p = (x, y)
pts.append(p)
- if style == 'dotted':
+ if style == "dotted":
for p in pts:
- cv2.circle(img, p, thickness, color, -1)
- elif style == 'line':
+ cv2.circle(img, p, radius=thickness, color=color, thickness=-1)
+ elif style == "line":
s = pts[0]
e = pts[0]
i = 0
@@ -319,7 +383,7 @@ def draw_line(
s = e
e = p
if i % 2 == 1:
- cv2.line(img, s, e, color, thickness)
+ cv2.line(img, s, e, color=color, thickness=thickness)
i += 1
else:
raise ValueError(f"Unknown style: {style}")
@@ -333,7 +397,7 @@ def draw_point(
color: _Color = (0, 255, 0),
thickness: _Thickness = -1,
) -> np.ndarray:
- '''
+ """
Draw a point on the image.
Args:
@@ -350,15 +414,22 @@ def draw_point(
Returns:
np.ndarray: Image with the drawn point.
- '''
+ """
is_gray_img = img.ndim == 2
img = prepare_img(img)
point = prepare_point(point)
color = prepare_color(color)
+ thickness = prepare_thickness(thickness)
h, w = img.shape[:2]
size = 1 + (np.sqrt(h * w) * 0.002 * scale).round().astype(int).item()
- img = cv2.circle(img, point, radius=size, color=color,
- lineType=cv2.LINE_AA, thickness=thickness)
+ img = cv2.circle(
+ img,
+ point,
+ radius=size,
+ color=color,
+ lineType=cv2.LINE_AA,
+ thickness=thickness,
+ )
img = img[..., 0] if is_gray_img else img
return img
@@ -366,11 +437,11 @@ def draw_point(
def draw_points(
img: np.ndarray,
points: _Points,
- scales: _Scales = 1.,
+ scales: _Scales = 1.0,
colors: _Colors = (0, 255, 0),
thicknesses: _Thicknesses = -1,
) -> np.ndarray:
- '''
+ """
Draw multiple points on the image.
Args:
@@ -387,14 +458,14 @@ def draw_points(
Returns:
np.ndarray: Image with the drawn points.
- '''
+ """
img = prepare_img(img).copy()
points = prepare_points(points)
colors = prepare_colors(colors, len(points))
thicknesses = prepare_thicknesses(thicknesses, len(points))
scales = prepare_scales(scales, len(points))
- for p, s, c, t in zip(points, scales, colors, thicknesses):
+ for p, s, c, t in zip(points, scales, colors, thicknesses, strict=True):
img = draw_point(img, p, s, c, t)
return img
@@ -403,10 +474,10 @@ def draw_points(
def draw_keypoints(
img: np.ndarray,
keypoints: _Keypoints,
- scale: _Scale = 1.,
- thickness: _Thickness = -1
+ scale: _Scale = 1.0,
+ thickness: _Thickness = -1,
) -> np.ndarray:
- '''
+ """
Draw keypoints on the image.
Args:
@@ -421,7 +492,7 @@ def draw_keypoints(
Returns:
np.ndarray: Image with the drawn keypoints.
- '''
+ """
img = prepare_img(img)
keypoints = prepare_keypoints(keypoints)
@@ -433,7 +504,7 @@ def draw_keypoints(
scale = prepare_scale(scale)
thickness = prepare_thickness(thickness)
points = keypoints.numpy()[..., :2]
- for p, c in zip(points, colors):
+ for p, c in zip(points, colors, strict=True):
img = draw_point(img, p, scale, c, thickness)
return img
@@ -441,10 +512,10 @@ def draw_keypoints(
def draw_keypoints_list(
img: np.ndarray,
keypoints_list: _KeypointsList,
- scales: _Scales = 1.,
- thicknesses: _Thicknesses = -1
+ scales: _Scales = 1.0,
+ thicknesses: _Thicknesses = -1,
) -> np.ndarray:
- '''
+ """
Draw keypoints list on the image.
Args:
@@ -459,45 +530,60 @@ def draw_keypoints_list(
Returns:
np.ndarray: Image with the drawn keypoints list.
- '''
+ """
img = prepare_img(img)
keypoints_list = prepare_keypoints_list(keypoints_list)
scales = prepare_scales(scales, len(keypoints_list))
thicknesses = prepare_thicknesses(thicknesses, len(keypoints_list))
- for ps, s, t in zip(keypoints_list, scales, thicknesses):
+ for ps, s, t in zip(keypoints_list, scales, thicknesses, strict=True):
img = draw_keypoints(img, ps, s, t)
return img
-def generate_colors_from_cmap(n: int, scheme: str) -> List[tuple]:
- cm = matplotlib.cm.get_cmap(scheme)
- return [cm(i/n)[:-1] for i in range(n)]
+def generate_colors_from_cmap(
+ n: int, scheme: str
+) -> list[tuple[float, float, float]]:
+ cm = matplotlib.colormaps.get_cmap(scheme)
+ rgb_colors = []
+ for i in range(n):
+ rgba = cm(i / n)
+ rgb_colors.append((float(rgba[0]), float(rgba[1]), float(rgba[2])))
+ return rgb_colors
-def generate_triadic_colors(n: int) -> List[tuple]:
+def generate_triadic_colors(n: int) -> list[tuple[float, float, float]]:
base_hue = np.random.rand()
- return [matplotlib.colors.hsv_to_rgb(((base_hue + i / 3.0) % 1, 1, 1)) for i in range(n)]
+ return [
+ tuple(mpl_colors.hsv_to_rgb(((base_hue + i / 3.0) % 1, 1, 1)))
+ for i in range(n)
+ ]
-def generate_analogous_colors(n: int) -> List[tuple]:
+def generate_analogous_colors(n: int) -> list[tuple[float, float, float]]:
base_hue = np.random.rand()
step = 0.05
- return [matplotlib.colors.hsv_to_rgb(((base_hue + i * step) % 1, 1, 1)) for i in range(n)]
+ return [
+ tuple(mpl_colors.hsv_to_rgb(((base_hue + i * step) % 1, 1, 1)))
+ for i in range(n)
+ ]
-def generate_square_colors(n: int) -> List[tuple]:
+def generate_square_colors(n: int) -> list[tuple[float, float, float]]:
base_hue = np.random.rand()
- return [matplotlib.colors.hsv_to_rgb(((base_hue + i / 4.0) % 1, 1, 1)) for i in range(n)]
+ return [
+ tuple(mpl_colors.hsv_to_rgb(((base_hue + i / 4.0) % 1, 1, 1)))
+ for i in range(n)
+ ]
-def generate_colors(n: int, scheme: str = 'hsv') -> List[tuple]:
+def generate_colors(n: int, scheme: str = "hsv") -> list[tuple[int, int, int]]:
"""
Generates n different colors based on the chosen color scheme.
"""
color_generators = {
- 'triadic': generate_triadic_colors,
- 'analogous': generate_analogous_colors,
- 'square': generate_square_colors
+ "triadic": generate_triadic_colors,
+ "analogous": generate_analogous_colors,
+ "square": generate_square_colors,
}
if scheme in color_generators:
@@ -507,19 +593,27 @@ def generate_colors(n: int, scheme: str = 'hsv') -> List[tuple]:
colors = generate_colors_from_cmap(n, scheme)
except ValueError:
print(
- f"Color scheme '{scheme}' not recognized. Returning empty list.")
+ f"Color scheme '{scheme}' not recognized. Returning empty list."
+ )
colors = []
- return [tuple(int(c * 255) for c in color) for color in colors]
+ return [
+ (
+ int(color[0] * 255),
+ int(color[1] * 255),
+ int(color[2] * 255),
+ )
+ for color in colors
+ ]
def draw_mask(
img: np.ndarray,
mask: np.ndarray,
colormap: int = cv2.COLORMAP_JET,
- weight: Tuple[float, float] = (0.5, 0.5),
+ weight: tuple[float, float] = (0.5, 0.5),
gamma: float = 0,
- min_max_normalize: bool = False
+ min_max_normalize: bool = False,
) -> np.ndarray:
"""
Draw the mask on the image.
@@ -543,15 +637,16 @@ def draw_mask(
"""
# Ensure the input image has 3 channels
- if img.ndim == 2:
- img = np.stack([img] * 3, axis=-1)
- else:
- img = img.copy() # Avoid modifying the original image
+ img = np.stack([img] * 3, axis=-1) if img.ndim == 2 else img.copy()
# Normalize mask if required
if min_max_normalize:
mask = mask.astype(np.float32)
- mask = (mask - mask.min()) / (mask.max() - mask.min())
+ denom = float(mask.max() - mask.min())
+ if denom > 0:
+ mask = (mask - mask.min()) / denom
+ else:
+ mask = np.zeros_like(mask, dtype=np.float32)
mask = (mask * 255).astype(np.uint8)
else:
mask = mask.astype(np.uint8) # Ensure mask is uint8 for color mapping
@@ -579,7 +674,7 @@ def _vdc(n: int, base: int = 2) -> float:
return vdc
-@functools.lru_cache(maxsize=None)
+@functools.cache
def distinct_color(idx: int) -> _Color:
"""Generate a perceptually distinct BGR color for class *idx*.
@@ -588,9 +683,9 @@ def distinct_color(idx: int) -> _Color:
between close indices.
2. Saturation / Value: cycle every 20 / 10 ids to avoid hue-only clashes.
"""
- hue = _vdc(idx + 1) # (0,1) ─ 長距離跳耀
+ hue = _vdc(idx + 1) # (0,1) ─ 長距離跳耀
sat_cycle = (0.65, 0.80, 0.50) # ┐週期性變化
- val_cycle = (1.00, 0.90, 0.80) # ┘增亮/加深
+ val_cycle = (1.00, 0.90, 0.80) # ┘增亮/加深
s = sat_cycle[(idx // 20) % len(sat_cycle)]
v = val_cycle[(idx // 10) % len(val_cycle)]
r, g, b = colorsys.hsv_to_rgb(hue, s, v)
@@ -609,11 +704,11 @@ def draw_detection(
img: np.ndarray,
box: _Box,
label: str,
- score: Union[float, None] = None,
+ score: float | None = None,
color: _Color | None = None,
thickness: _Thickness | None = None,
text_color: _Color = (255, 255, 255),
- font_path: Union[str, Path] | None = None,
+ font_path: str | Path | None = None,
text_size: int | None = None,
box_alpha: float = 1.0,
text_bg_alpha: float = 0.6,
@@ -652,36 +747,49 @@ def draw_detection(
draw_color = distinct_color(idx)
else:
draw_color = color
+ draw_color = prepare_color(draw_color)
if thickness is None:
# proportional to image diagonal
- diag = (canvas.shape[0]**2 + canvas.shape[1]**2)**0.5
+ diag = (canvas.shape[0] ** 2 + canvas.shape[1] ** 2) ** 0.5
line_thickness = max(1, int(diag * 0.002 + 0.5))
else:
line_thickness = thickness
+ line_thickness = prepare_thickness(line_thickness)
# 3. Draw box (with optional transparency)
if box_alpha >= 1.0:
- cv2.rectangle(canvas, (x1, y1), (x2, y2),
- draw_color, line_thickness, cv2.LINE_AA)
+ cv2.rectangle(
+ canvas,
+ (x1, y1),
+ (x2, y2),
+ color=draw_color,
+ thickness=line_thickness,
+ lineType=cv2.LINE_AA,
+ )
else:
overlay = canvas.copy()
- cv2.rectangle(overlay, (x1, y1), (x2, y2),
- draw_color, line_thickness, cv2.LINE_AA)
+ cv2.rectangle(
+ overlay,
+ (x1, y1),
+ (x2, y2),
+ color=draw_color,
+ thickness=line_thickness,
+ lineType=cv2.LINE_AA,
+ )
canvas = cv2.addWeighted(overlay, box_alpha, canvas, 1 - box_alpha, 0)
# 4. Prepare label text
- text = f"{label} {score*100:.1f}%" if score is not None else label
+ text = f"{label} {score * 100:.1f}%" if score is not None else label
# auto font size (~10% of box height, min 12)
if text_size is None:
text_size = max(12, int((y2 - y1) * 0.10))
# 5. Measure text size with PIL
- font_file = DEFAULT_FONT_PATH if font_path is None else Path(font_path)
pil_img = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_img, "RGBA")
- font = ImageFont.truetype(str(font_file), size=text_size)
+ font = _load_font(font_path, size=text_size)
bbox = draw.textbbox((0, 0), text, font=font)
text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
@@ -701,14 +809,14 @@ def draw_detection(
y0, y1_ = sorted([bg_y0, bg_y1])
# 7. Draw semi-transparent background
+ text_color = prepare_color(text_color)
draw.rectangle(
[(x0, y0), (x1_, y1_)],
fill=(*draw_color[::-1], int(text_bg_alpha * 255)),
)
# 8. Draw text (PIL expecting RGB)
- draw.text((x0 + pad, y0 + pad), text,
- font=font, fill=text_color[::-1])
+ draw.text((x0 + pad, y0 + pad), text, font=font, fill=text_color[::-1])
# 9. Convert back to BGR OpenCV
annotated = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
@@ -718,13 +826,13 @@ def draw_detection(
def draw_detections(
img: np.ndarray,
boxes: _Boxes,
- labels: List[str],
- scores: List[float] = None,
- colors: _Colors = None,
- thicknesses: _Thicknesses = None,
+ labels: list[str],
+ scores: list[float] | None = None,
+ colors: _Colors | None = None,
+ thicknesses: _Thicknesses | None = None,
text_colors: _Colors = (255, 255, 255),
- font_path: Union[str, Path] = None,
- text_sizes: List[int] = None,
+ font_path: str | Path | None = None,
+ text_sizes: list[int] | None = None,
box_alpha: float = 1.0,
text_bg_alpha: float = 0.6,
) -> np.ndarray:
@@ -756,18 +864,17 @@ def draw_detections(
if scores is not None and len(scores) != len(labels):
raise ValueError("Number of scores must match number of labels")
- if colors is not None:
- colors = prepare_colors(colors, len(boxes))
- else:
- colors = [None] * len(boxes)
-
- if thicknesses is not None:
- thicknesses = prepare_thicknesses(thicknesses, len(boxes))
- else:
- thicknesses = [None] * len(boxes)
-
- if text_colors is not None:
- text_colors = prepare_colors(text_colors, len(boxes))
+ colors_list = (
+ prepare_colors(colors, len(boxes))
+ if colors is not None
+ else [None] * len(boxes)
+ )
+ thicknesses_list = (
+ prepare_thicknesses(thicknesses, len(boxes))
+ if thicknesses is not None
+ else [None] * len(boxes)
+ )
+ text_colors_list = prepare_colors(text_colors, len(boxes))
if text_sizes is not None:
text_sizes = [int(size) for size in text_sizes]
@@ -775,10 +882,9 @@ def draw_detections(
for i, box in enumerate(boxes):
label = labels[i]
score = scores[i] if scores is not None else None
- color = colors[i] if colors is not None else None
- thickness = thicknesses[i] if thicknesses is not None else None
- text_color = text_colors[i] if text_colors is not None else (
- 255, 255, 255)
+ color = colors_list[i]
+ thickness = thicknesses_list[i]
+ text_color = text_colors_list[i]
text_size = text_sizes[i] if text_sizes is not None else None
canvas = draw_detection(
@@ -792,7 +898,7 @@ def draw_detections(
font_path=font_path,
text_size=text_size,
box_alpha=box_alpha,
- text_bg_alpha=text_bg_alpha
+ text_bg_alpha=text_bg_alpha,
)
return canvas
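
A rough end-to-end sketch of the helpers above; the boxes, labels, and scores are arbitrary, and when the bundled TTF is missing `_load_font` silently falls back to PIL's default face. Assumes the visualization extra is installed:

import numpy as np

from capybara.vision.visualization.draw import draw_detections

canvas = np.full((240, 320, 3), 255, dtype=np.uint8)
annotated = draw_detections(
    canvas,
    boxes=[[20, 30, 120, 150], [160, 40, 300, 200]],
    labels=["cat", "dog"],
    scores=[0.91, 0.47],
    box_alpha=0.8,
)
assert annotated.shape == canvas.shape
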
diff --git a/capybara/vision/visualization/utils.py b/capybara/vision/visualization/utils.py
index 378acb4..81beda7 100644
--- a/capybara/vision/visualization/utils.py
+++ b/capybara/vision/visualization/utils.py
@@ -1,51 +1,61 @@
-from typing import Any, List, Optional, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, TypeAlias, cast
import numpy as np
-from ...structures import (Box, Boxes, BoxMode, Keypoints, KeypointsList,
- Polygon, Polygons)
+from ...structures import (
+ Box,
+ Boxes,
+ BoxMode,
+ Keypoints,
+ KeypointsList,
+ Polygon,
+ Polygons,
+)
from ...structures.boxes import _Box, _Boxes, _Number
from ...structures.keypoints import _Keypoints, _KeypointsList
from ...structures.polygons import _Polygon, _Polygons
-_Color = Union[int, List[int], Tuple[int, int, int], np.ndarray]
-_Colors = Union[_Color, List[_Color], np.ndarray]
-_Point = Union[List[int], Tuple[int, int], Tuple[int, int], np.ndarray]
-_Points = Union[List[_Point], np.ndarray]
-_Thickness = Union[_Number, np.ndarray]
-_Thicknesses = Union[List[_Thickness], _Thickness, np.ndarray]
-_Scale = Union[_Number, np.ndarray]
-_Scales = Union[List[_Scale], _Number, np.ndarray]
+_Color: TypeAlias = int | Sequence[int] | tuple[int, int, int] | np.ndarray
+_Colors: TypeAlias = _Color | Sequence[_Color] | np.ndarray
+_Point: TypeAlias = Sequence[int] | tuple[int, int] | np.ndarray
+_Points: TypeAlias = Sequence[_Point] | np.ndarray
+_Thickness: TypeAlias = _Number
+_Thicknesses: TypeAlias = _Thickness | Sequence[_Thickness] | np.ndarray
+_Scale: TypeAlias = _Number
+_Scales: TypeAlias = _Scale | Sequence[_Scale] | np.ndarray
__all__ = [
- 'is_numpy_img',
- 'prepare_color',
- 'prepare_colors',
- 'prepare_img',
- 'prepare_box',
- 'prepare_boxes',
- 'prepare_keypoints',
- 'prepare_keypoints_list',
- 'prepare_polygon',
- 'prepare_polygons',
- 'prepare_point',
- 'prepare_points',
- 'prepare_thickness',
- 'prepare_thicknesses',
- 'prepare_scale',
- 'prepare_scales',
+ "is_numpy_img",
+ "prepare_box",
+ "prepare_boxes",
+ "prepare_color",
+ "prepare_colors",
+ "prepare_img",
+ "prepare_keypoints",
+ "prepare_keypoints_list",
+ "prepare_point",
+ "prepare_points",
+ "prepare_polygon",
+ "prepare_polygons",
+ "prepare_scale",
+ "prepare_scales",
+ "prepare_thickness",
+ "prepare_thicknesses",
]
def is_numpy_img(x: Any) -> bool:
if not isinstance(x, np.ndarray):
return False
- return (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]))
+ return x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])
-def prepare_color(color: _Color, ind: int = None) -> Tuple[int, int, int]:
- '''
+def prepare_color(
+ color: _Color, ind: int | None = None
+) -> tuple[int, int, int]:
+ """
This function prepares the color input for opencv.
Args:
@@ -57,23 +67,32 @@ def prepare_color(color: _Color, ind: int = None) -> Tuple[int, int, int]:
Returns:
Tuple[int, int, int]: a tuple of 3 integers.
- '''
+ """
cond1 = isinstance(color, int)
- cond2 = isinstance(color, (tuple, list)) and len(color) == 3 and \
- isinstance(color[0], int) and \
- isinstance(color[1], int) and \
- isinstance(color[2], int)
- cond3 = isinstance(color, np.ndarray) and color.ndim == 1 and len(color) == 3
+ cond2 = (
+ isinstance(color, (tuple, list))
+ and len(color) == 3
+ and isinstance(color[0], int)
+ and isinstance(color[1], int)
+ and isinstance(color[2], int)
+ )
+ cond3 = (
+ isinstance(color, np.ndarray) and color.ndim == 1 and len(color) == 3
+ )
if not (cond1 or cond2 or cond3):
- i = '' if ind is None else f's[{ind}]'
- raise TypeError(f'The input color{i} = {color} is invalid. Should be {_Color}')
- c = (color, ) * 3 if cond1 else color
+ i = "" if ind is None else f"s[{ind}]"
+ raise TypeError(
+ f"The input color{i} = {color} is invalid. Should be {_Color}"
+ )
+ c = (color,) * 3 if cond1 else color
c = tuple(np.array(c, dtype=int).tolist())
return c
-def prepare_colors(colors: _Colors, length: Union[int] = None) -> List[Tuple[int, int, int]]:
- '''
+def prepare_colors(
+ colors: _Colors, length: int | None = None
+) -> list[tuple[int, int, int]]:
+ """
This function prepares the colors input for opencv.
Args:
@@ -82,21 +101,23 @@ def prepare_colors(colors: _Colors, length: Union[int] = None) -> List[Tuple[int
Returns:
List[Tuple[int, int, int]]: a list of tuples of 3 integers.
- '''
- if isinstance(colors, (list, np.ndarray)) and not isinstance(colors[0], _Number):
+ """
+ try:
+ c = prepare_color(cast(Any, colors), 0)
+ except TypeError:
+ if not isinstance(colors, (list, tuple, np.ndarray)):
+ raise
if length is not None and len(colors) != length:
- raise ValueError(f'The length of colors = {len(colors)} is not equal to the length = {length}.')
- cs = []
- for i, color in enumerate(colors):
- cs.append(prepare_color(color, i))
- else:
- c = prepare_color(colors, 0)
- cs = [c] * length
- return cs
+ raise ValueError(
+ f"The length of colors = {len(colors)} is not equal to the length = {length}."
+ ) from None
+ return [prepare_color(color, i) for i, color in enumerate(colors)]
+ repeat = 1 if length is None else length
+ return [c] * repeat
-def prepare_img(img: np.ndarray, ind: Optional[int] = None) -> np.ndarray:
- '''
+def prepare_img(img: np.ndarray, ind: int | None = None) -> np.ndarray:
+ """
This function prepares the image input for opencv.
Args:
@@ -108,25 +129,25 @@ def prepare_img(img: np.ndarray, ind: Optional[int] = None) -> np.ndarray:
Returns:
np.ndarray: a valid numpy image for opencv.
- '''
+ """
if is_numpy_img(img):
if img.ndim == 2:
img = img[..., None].repeat(3, axis=-1)
elif img.ndim == 3 and img.shape[-1] == 1:
img = img.repeat(3, axis=-1)
else:
- i = '' if ind is None else f's[{ind}]'
- raise ValueError(f'The input image{i} is not invalid numpy image.')
+ i = "" if ind is None else f"s[{ind}]"
+        raise ValueError(f"The input image{i} is not a valid numpy image.")
return img
def prepare_box(
box: _Box,
- ind: Optional[int] = None,
- src_mode: Union[str, BoxMode] = BoxMode.XYXY,
- dst_mode: Union[str, BoxMode] = BoxMode.XYXY,
+ ind: int | None = None,
+ src_mode: str | BoxMode = BoxMode.XYXY,
+ dst_mode: str | BoxMode = BoxMode.XYXY,
) -> Box:
- '''
+ """
This function prepares the box input to XYXY format.
Args:
@@ -137,23 +158,27 @@ def prepare_box(
Returns:
Box: a valid Box instance.
- '''
+ """
try:
is_normalized = box.is_normalized if isinstance(box, Box) else False
src_mode = box.box_mode if isinstance(box, Box) else src_mode
- box = Box(box, box_mode=src_mode, is_normalized=is_normalized).convert(dst_mode)
- except:
- i = "" if ind is None else f'es[{ind}]'
- raise ValueError(f"The input box{i} is invalid value = {box}. Should be {_Box}")
+ box = Box(box, box_mode=src_mode, is_normalized=is_normalized).convert(
+ dst_mode
+ )
+ except Exception as exc:
+ i = "" if ind is None else f"es[{ind}]"
+ raise ValueError(
+            f"The input box{i} has an invalid value: {box}. Should be {_Box}."
+ ) from exc
return box
def prepare_boxes(
boxes: _Boxes,
- src_mode: Union[str, BoxMode] = BoxMode.XYXY,
- dst_mode: Union[str, BoxMode] = BoxMode.XYXY,
+ src_mode: str | BoxMode = BoxMode.XYXY,
+ dst_mode: str | BoxMode = BoxMode.XYXY,
) -> Boxes:
- '''
+ """
This function prepares the boxes input to XYXY format.
Args:
@@ -163,16 +188,23 @@ def prepare_boxes(
Returns:
Boxes: a valid Boxes instance.
- '''
+ """
if isinstance(boxes, Boxes):
boxes = boxes.convert(dst_mode)
else:
- boxes = Boxes([prepare_box(box, i, src_mode, dst_mode) for i, box in enumerate(boxes)])
+ boxes = Boxes(
+ [
+ prepare_box(box, i, src_mode, dst_mode)
+ for i, box in enumerate(boxes)
+ ]
+ )
return boxes
-def prepare_keypoints(keypoints: _Keypoints, ind: Optional[int] = None) -> Keypoints:
- '''
+def prepare_keypoints(
+ keypoints: _Keypoints, ind: int | None = None
+) -> Keypoints:
+ """
This function prepares the keypoints input.
Args:
@@ -181,17 +213,21 @@ def prepare_keypoints(keypoints: _Keypoints, ind: Optional[int] = None) -> Keypo
Returns:
Keypoints: a valid Keypoints instance.
- '''
+ """
+ if isinstance(keypoints, Keypoints):
+ return keypoints
try:
keypoints = Keypoints(keypoints)
- except:
- i = "" if ind is None else f'_list[{ind}]'
- raise TypeError(f"The input keypoints{i} is invalid value = {keypoints}. Should be {_Keypoints}")
+ except Exception as exc:
+ i = "" if ind is None else f"_list[{ind}]"
+ raise TypeError(
+            f"The input keypoints{i} has an invalid value: {keypoints}. Should be {_Keypoints}."
+ ) from exc
return keypoints
def prepare_keypoints_list(keypoints_list: _KeypointsList) -> KeypointsList:
- '''
+ """
This function prepares the keypoints list input.
Args:
@@ -199,78 +235,100 @@ def prepare_keypoints_list(keypoints_list: _KeypointsList) -> KeypointsList:
Returns:
KeypointsList: a valid KeypointsList instance.
- '''
- if not isinstance(keypoints_list, KeypointsList):
- keypoints_list = KeypointsList([prepare_keypoints(keypoints, i) for i, keypoints in enumerate(keypoints_list)])
+ """
+ if isinstance(keypoints_list, KeypointsList):
+ return keypoints_list
+ keypoints_list = KeypointsList(
+ [
+ prepare_keypoints(keypoints, i)
+ for i, keypoints in enumerate(keypoints_list)
+ ]
+ )
return keypoints_list
-def prepare_polygon(polygon: _Polygon, ind: Union[int] = None) -> Polygon:
+def prepare_polygon(polygon: _Polygon, ind: int | None = None) -> Polygon:
+ if isinstance(polygon, Polygon):
+ return polygon
try:
polygon = Polygon(polygon)
- except:
- i = "" if ind is None else f's[{ind}]'
- raise TypeError(f"The input polygon{i} is invalid value = {polygon}. Should be {_Polygon}")
+ except Exception as exc:
+ i = "" if ind is None else f"s[{ind}]"
+ raise TypeError(
+            f"The input polygon{i} has an invalid value: {polygon}. Should be {_Polygon}."
+ ) from exc
return polygon
def prepare_polygons(polygons: _Polygons) -> Polygons:
- if not isinstance(polygons, Polygons):
- polygons = Polygons([prepare_polygon(polygon, i) for i, polygon in enumerate(polygons)])
+ if isinstance(polygons, Polygons):
+ return polygons
+ polygons = Polygons(
+ [prepare_polygon(polygon, i) for i, polygon in enumerate(polygons)]
+ )
return polygons
-def prepare_point(point: _Point, ind: Optional[int] = None) -> tuple:
- cond1 = isinstance(point, (tuple, list)) and len(point) == 2 and \
- isinstance(point[0], _Number) and \
- isinstance(point[1], _Number)
- cond2 = isinstance(point, np.ndarray) and point.ndim == 1 and len(point) == 2
+def prepare_point(point: _Point, ind: int | None = None) -> tuple:
+ cond1 = (
+ isinstance(point, (tuple, list))
+ and len(point) == 2
+ and isinstance(point[0], _Number)
+ and isinstance(point[1], _Number)
+ )
+ cond2 = (
+ isinstance(point, np.ndarray) and point.ndim == 1 and len(point) == 2
+ )
if not (cond1 or cond2):
- i = '' if ind is None else f's[{ind}]'
- raise TypeError(f'The input point{i} is invalid.')
+ i = "" if ind is None else f"s[{ind}]"
+ raise TypeError(f"The input point{i} is invalid.")
return tuple(np.array(point, dtype=int).tolist())
-def prepare_points(points: _Points) -> List[_Point]:
+def prepare_points(points: _Points) -> list[_Point]:
ps = []
for i, point in enumerate(points):
ps.append(prepare_point(point, i))
return ps
-def prepare_thickness(thickness: _Thickness, ind: int = None) -> int:
+def prepare_thickness(thickness: _Thickness, ind: int | None = None) -> int:
if not isinstance(thickness, _Number) or thickness < -1:
- i = '' if ind is None else f's[{ind}]'
- raise ValueError(f'The thickness[{i}] = {thickness} is not correct. \n')
- thickness = np.array(thickness, dtype='int').tolist()
- return thickness
+ i = "" if ind is None else f"s[{ind}]"
+        raise ValueError(f"The thickness{i} = {thickness} is not correct.")
+ value = np.array(thickness, dtype="int").tolist()
+ return int(value)
-def prepare_thicknesses(thicknesses: _Thicknesses, length: Optional[int] = None) -> List[int]:
- if isinstance(thicknesses, (list, np.ndarray)):
+def prepare_thicknesses(
+ thicknesses: _Thicknesses, length: int | None = None
+) -> list[int]:
+ if isinstance(thicknesses, _Number):
+ thickness = prepare_thickness(thicknesses, 0)
+ repeat = 1 if length is None else length
+ cs = [thickness] * repeat
+ else:
cs = []
for i, thickness in enumerate(thicknesses):
cs.append(prepare_thickness(thickness, i))
- else:
- thickness = prepare_thickness(thicknesses, 0)
- cs = [thickness] * length
return cs
-def prepare_scale(scale: _Scale, ind: int = None) -> float:
+def prepare_scale(scale: _Scale, ind: int | None = None) -> float:
if not isinstance(scale, _Number) or scale < -1:
- i = '' if ind is None else f's[{ind}]'
- raise ValueError(f'The scale[{i}] = {scale} is not correct. \n')
- scale = np.array(scale, dtype=float).tolist()
- return scale
+ i = "" if ind is None else f"s[{ind}]"
+        raise ValueError(f"The scale{i} = {scale} is not correct.")
+ value = np.array(scale, dtype=float).tolist()
+ return float(value)
-def prepare_scales(scales: _Scales, length: Optional[int] = None) -> List[float]:
- if isinstance(scales, (list, np.ndarray)):
+def prepare_scales(scales: _Scales, length: int | None = None) -> list[float]:
+ if isinstance(scales, _Number):
+ scale = prepare_scale(scales, 0)
+ repeat = 1 if length is None else length
+ cs = [scale] * repeat
+ else:
cs = []
for i, scale in enumerate(scales):
cs.append(prepare_scale(scale, i))
- else:
- scale = prepare_scale(scales, 0)
- cs = [scale] * length
return cs
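
The prepare_* helpers now share one broadcast rule: a scalar-like value is repeated to the requested length, while a sequence must already match it. A small self-check:

from capybara.vision.visualization.utils import (
    prepare_colors,
    prepare_scales,
    prepare_thicknesses,
)

assert prepare_colors((0, 255, 0), length=3) == [(0, 255, 0)] * 3
assert prepare_colors([(255, 0, 0), (0, 0, 255)], length=2) == [
    (255, 0, 0),
    (0, 0, 255),
]
assert prepare_thicknesses(2, length=4) == [2, 2, 2, 2]
assert prepare_scales(1.5, length=2) == [1.5, 1.5]
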
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 0ed3011..cea6e05 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -21,12 +21,12 @@ RUN ln -s /usr/bin/python3 /usr/bin/python && \
python -m pip install --no-cache-dir -U pip setuptools wheel
COPY . /usr/local/Capybara
-RUN cd /usr/local/Capybara && python setup.py bdist_wheel && \
- python -m pip install dist/*.whl && rm -rf /usr/local/Capybara
+RUN python -m pip install --no-cache-dir /usr/local/Capybara && \
+ rm -rf /usr/local/Capybara
# Preload data
RUN python -c "import capybara"
WORKDIR /code
-CMD ["bash"]
\ No newline at end of file
+CMD ["bash"]
diff --git a/pyproject.toml b/pyproject.toml
index 58ac44e..b196535 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
[build-system]
-requires = ["setuptools>=61", "wheel"]
+requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"
[project]
@@ -16,16 +16,16 @@ classifiers = [
"Intended Audience :: Science/Research",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules"
]
dependencies = [
"dacite",
- "psutil",
"requests",
- "onnx",
- "colored",
"numpy",
"pdf2image",
"ujson",
@@ -34,20 +34,52 @@ dependencies = [
"pybase64",
"PyTurboJPEG",
"dill",
- "networkx",
"natsort",
- "flask",
"shapely",
"piexif",
- "matplotlib",
"opencv-python>=4.12.0.88",
- "onnxslim",
"beautifulsoup4",
- "onnxruntime==1.22.0; platform_system == 'Darwin'",
- "onnxruntime_gpu==1.22.0; platform_system == 'Linux'",
"pillow-heif"
]
+[project.optional-dependencies]
+# Inference backends (opt-in; keep `pip install .` minimal)
+onnxruntime = [
+ "onnxruntime>=1.22.0,<2",
+ "onnx>=1.18.0",
+ "onnxslim>=0.1.0",
+]
+onnxruntime-gpu = [
+ "onnxruntime-gpu>=1.22.0,<2",
+ "onnx>=1.18.0",
+ "onnxslim>=0.1.0",
+]
+openvino = [
+ "openvino>=2024.0.0",
+]
+torchscript = [
+ "torch>=2.0",
+]
+all = [
+ "onnxruntime>=1.22.0,<2",
+ "onnx>=1.18.0",
+ "onnxslim>=0.1.0",
+ "openvino>=2024.0.0",
+ "torch>=2.0",
+]
+
+# Feature extras (non-core runtime)
+ipcam = [
+ "flask>=2.0",
+]
+system = [
+ "psutil",
+]
+visualization = [
+ "matplotlib",
+ "pillow",
+]
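+# Extras are opt-in at install time, e.g. `pip install ".[onnxruntime]"` or
+# `pip install ".[onnxruntime-gpu,ipcam]"` from a source checkout.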
+
[project.urls]
Homepage = "https://docsaid.org/en/docs/capybara/"
Repository = "https://github.com/DocsaidLab/Capybara"
@@ -58,7 +90,68 @@ include-package-data = true
[tool.setuptools.packages.find]
include = ["capybara*"]
-exclude = ["demo", "tests", "docs", ".github", "docker", "wheelhouse"]
+exclude = [
+ "demo",
+ "tests",
+ "docs",
+ ".github",
+ "docker",
+ "wheelhouse"
+]
[tool.setuptools.dynamic]
-version = { attr = "capybara.__version__" }
\ No newline at end of file
+version = { attr = "capybara.__version__" }
+
+[tool.pytest.ini_options]
+pythonpath = [
+ ".",
+]
+testpaths = [
+ "tests",
+]
+addopts = "--ignore=tmp"
+
+[tool.ruff]
+target-version = "py310"
+line-length = 80
+extend-exclude = [
+ ".git",
+ ".mypy_cache",
+ ".pytest_cache",
+ "__pycache__",
+ ".venv",
+ "venv",
+ "build",
+ "dist",
+ ".tox",
+ ".eggs",
+ "*.egg-info",
+]
+
+[tool.ruff.lint]
+select = [
+ "E",
+ "F",
+ "W",
+ "B",
+ "UP",
+ "N",
+ "I",
+ "C4",
+ "SIM",
+ "RUF",
+]
+ignore = [
+ "E203",
+ "E501",
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["capybara", "tests"]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = true
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..cf282ef
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,7 @@
+{
+ "include": ["capybara", "tests"],
+ "exclude": ["tmp", "capybara/cpuinfo.py"],
+ "pythonVersion": "3.10",
+ "typeCheckingMode": "basic",
+ "reportMissingTypeStubs": false
+}
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 6b40b52..0000000
--- a/setup.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from setuptools import setup
-
-if __name__ == '__main__':
- setup()
diff --git a/tests/onnxengine/test_engine_stubbed.py b/tests/onnxengine/test_engine_stubbed.py
new file mode 100644
index 0000000..fae9953
--- /dev/null
+++ b/tests/onnxengine/test_engine_stubbed.py
@@ -0,0 +1,339 @@
+from __future__ import annotations
+
+import importlib
+import sys
+import types
+from enum import Enum
+from typing import Any
+
+import numpy as np
+import pytest
+
+
+@pytest.fixture()
+def onnx_engine_module(monkeypatch):
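+ # Inject a fake `onnxruntime` module into sys.modules so the engine can be
+ # exercised without the optional onnxruntime dependency installed.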
+ class GraphOptimizationLevel(Enum):
+ ORT_DISABLE_ALL = "disable"
+ ORT_ENABLE_BASIC = "basic"
+ ORT_ENABLE_EXTENDED = "extended"
+ ORT_ENABLE_ALL = "all"
+
+ class ExecutionMode(Enum):
+ ORT_SEQUENTIAL = "seq"
+ ORT_PARALLEL = "par"
+
+ modelmeta_holder: dict[str, dict[str, object] | None] = {"map": None}
+
+ class SessionOptions:
+ def __init__(self):
+ self.enable_profiling = False
+ self.graph_optimization_level = None
+ self.execution_mode = None
+ self.intra_op_num_threads = None
+ self.inter_op_num_threads = None
+ self.log_severity_level = None
+ self.config_entries: dict[str, str] = {}
+
+ def add_session_config_entry(self, key: str, value: str):
+ self.config_entries[key] = value
+
+ class RunOptions:
+ def __init__(self):
+ self.entries: dict[str, str] = {}
+
+ def add_run_config_entry(self, key: str, value: str):
+ self.entries[key] = value
+
+ class FakeIOBinding:
+ def __init__(self, session):
+ self.session = session
+ self.output_storage: dict[str, np.ndarray] = {}
+ self.bound_outputs: list[str] = []
+ self.bound_inputs: dict[str, np.ndarray] = {}
+
+ def clear_binding_inputs(self):
+ self.bound_inputs.clear()
+
+ def clear_binding_outputs(self):
+ self.bound_outputs.clear()
+ self.output_storage = {}
+
+ def bind_cpu_input(self, name: str, array: np.ndarray):
+ self.bound_inputs[name] = array
+
+ def bind_output(self, name: str):
+ self.bound_outputs.append(name)
+
+ def copy_outputs_to_cpu(self) -> list[np.ndarray]:
+ return [self.output_storage[name] for name in self.bound_outputs]
+
+ class FakeSession:
+ def __init__(
+ self, model_path, sess_options, providers, provider_options
+ ):
+ self.model_path = model_path
+ self.sess_options = sess_options
+ self._providers = providers
+ self._provider_options = provider_options
+ self._inputs = [
+ types.SimpleNamespace(
+ name="input", type="tensor(float)", shape=[1, "dyn", 4]
+ )
+ ]
+ self._outputs = [
+ types.SimpleNamespace(
+ name="output", type="tensor(float)", shape=[1, 4]
+ )
+ ]
+ self._binding = FakeIOBinding(self)
+ self.last_run_options = None
+
+ def get_providers(self):
+ return self._providers
+
+ def get_provider_options(self):
+ return self._provider_options
+
+ def get_modelmeta(self):
+ meta_map = modelmeta_holder["map"]
+ if meta_map is None:
+ raise RuntimeError("no metadata configured")
+ return types.SimpleNamespace(custom_metadata_map=meta_map)
+
+ def get_outputs(self):
+ return self._outputs
+
+ def get_inputs(self):
+ return self._inputs
+
+ def io_binding(self):
+ return self._binding
+
+ def run(self, output_names, feed, run_options=None):
+ self.last_run_options = run_options
+ return [np.ones((1, 4), dtype=np.float32) for _ in output_names]
+
+ def run_with_iobinding(self, binding, run_options=None):
+ self.last_run_options = run_options
+ for idx, name in enumerate(binding.bound_outputs):
+ binding.output_storage[name] = np.full(
+ (1, 4), float(idx), dtype=np.float32
+ )
+
+ fake_ort: Any = types.ModuleType("onnxruntime")
+ fake_ort.GraphOptimizationLevel = GraphOptimizationLevel
+ fake_ort.ExecutionMode = ExecutionMode
+ fake_ort.SessionOptions = SessionOptions
+ fake_ort.RunOptions = RunOptions
+ fake_ort.InferenceSession = FakeSession
+ fake_ort.get_available_providers = lambda: ["CPUExecutionProvider"]
+
+ monkeypatch.setitem(sys.modules, "onnxruntime", fake_ort)
+ module = importlib.reload(
+ importlib.import_module("capybara.onnxengine.engine")
+ )
+ module_any: Any = module
+ module_any._test_modelmeta_holder = modelmeta_holder
+ yield module
+
+
+def test_onnx_engine_with_stubbed_runtime(onnx_engine_module):
+ engine_config_cls = onnx_engine_module.EngineConfig
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+
+ config = engine_config_cls(
+ graph_optimization="basic",
+ execution_mode="parallel",
+ intra_op_num_threads=2,
+ inter_op_num_threads=1,
+ log_severity_level=0,
+ session_config_entries={"session.use": "1"},
+ provider_options={
+ "CUDAExecutionProvider": {"arena_extend_strategy": "manual"}
+ },
+ fallback_to_cpu=True,
+ enable_io_binding=True,
+ run_config_entries={"run.tag": "demo"},
+ enable_profiling=True,
+ )
+
+ engine = onnx_engine_cls(
+ model_path="model.onnx",
+ backend="cuda",
+ session_option={"custom": "value"},
+ provider_option={"do_copy_in_default_stream": False},
+ config=config,
+ )
+
+ feed = {"input": np.ones((1, 4), dtype=np.float64)}
+ outputs = engine(**feed)
+ assert set(outputs) == {"output"}
+ assert outputs["output"].dtype == np.float32
+
+ summary = engine.summary()
+ assert summary["model"] == "model.onnx"
+ assert summary["inputs"][0]["shape"][1] is None
+
+ stats = engine.benchmark(feed, repeat=3, warmup=1)
+ assert stats["repeat"] == 3
+ assert "mean" in stats["latency_ms"]
+
+
+def test_onnx_engine_benchmark_validates_repeat_and_warmup(onnx_engine_module):
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+ feed = {"input": np.ones((1, 4), dtype=np.float32)}
+
+ with pytest.raises(ValueError, match="repeat must be >= 1"):
+ engine.benchmark(feed, repeat=0)
+
+ with pytest.raises(ValueError, match="warmup must be >= 0"):
+ engine.benchmark(feed, warmup=-1)
+
+
+def test_onnx_engine_accepts_wrapped_feed_dict(onnx_engine_module):
+ """__call__ supports passing a single mapping payload as a kwarg."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+
+ feed = {"input": np.ones((1, 4), dtype=np.float32)}
+ outputs = engine(payload=feed)
+ assert outputs["output"].shape == (1, 4)
+
+
+def test_onnx_engine_run_method_uses_mapping(onnx_engine_module):
+ """run(feed) is a stable public API for inference."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+
+ outputs = engine.run({"input": np.ones((1, 4), dtype=np.float32)})
+ assert outputs["output"].dtype == np.float32
+
+
+def test_onnx_engine_iobinding_without_run_options(onnx_engine_module):
+ """When run_config_entries is empty, IO binding runs without RunOptions."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine_config_cls = onnx_engine_module.EngineConfig
+
+ config = engine_config_cls(enable_io_binding=True, run_config_entries=None)
+ engine = onnx_engine_cls(
+ model_path="model.onnx", backend="cpu", config=config
+ )
+
+ out = engine.run({"input": np.ones((1, 4), dtype=np.float32)})
+ assert engine._session.last_run_options is None
+ assert np.all(out["output"] == 0.0)
+
+
+def test_onnx_engine_session_option_overrides_attribute(onnx_engine_module):
+ """Known SessionOptions attributes are set directly instead of config entries."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+
+ engine = onnx_engine_cls(
+ model_path="model.onnx",
+ backend="cpu",
+ session_option={"enable_profiling": True},
+ )
+
+ assert engine._session.sess_options.enable_profiling is True
+
+
+def test_onnx_engine_accepts_enum_config_values(onnx_engine_module):
+ """Enum values pass through without string normalization."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine_config_cls = onnx_engine_module.EngineConfig
+
+ config = engine_config_cls(
+ graph_optimization=onnx_engine_module.ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
+ execution_mode=onnx_engine_module.ort.ExecutionMode.ORT_PARALLEL,
+ )
+
+ engine = onnx_engine_cls(
+ model_path="model.onnx", backend="cpu", config=config
+ )
+ assert (
+ engine._session.sess_options.graph_optimization_level
+ == onnx_engine_module.ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
+ )
+ assert (
+ engine._session.sess_options.execution_mode
+ == onnx_engine_module.ort.ExecutionMode.ORT_PARALLEL
+ )
+
+
+def test_onnx_engine_extract_metadata_parses_json(onnx_engine_module):
+ """Custom metadata JSON payloads are parsed for easier downstream use."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ onnx_engine_module._test_modelmeta_holder["map"] = {
+ "author": '{"name": "bob"}',
+ "note": "plain",
+ "count": 3,
+ }
+
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+
+ assert engine.metadata == {
+ "author": {"name": "bob"},
+ "note": "plain",
+ "count": 3,
+ }
+
+
+def test_onnx_engine_run_uses_run_options_without_iobinding(onnx_engine_module):
+ """run_config_entries should build RunOptions even when IO binding is disabled."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine_config_cls = onnx_engine_module.EngineConfig
+
+ config = engine_config_cls(
+ enable_io_binding=False, run_config_entries={"k": "v"}
+ )
+ engine = onnx_engine_cls(
+ model_path="model.onnx",
+ backend="cpu",
+ config=config,
+ )
+
+ out = engine.run({"input": np.ones((1, 4), dtype=np.float32)})
+ assert out["output"].shape == (1, 4)
+ assert engine._session.last_run_options is not None
+
+
+def test_onnx_engine_converts_outputs_via_toarray(onnx_engine_module):
+ """Some onnxruntime outputs expose a toarray() helper instead of ndarray."""
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine_config_cls = onnx_engine_module.EngineConfig
+
+ config = engine_config_cls(enable_io_binding=False)
+ engine = onnx_engine_cls(
+ model_path="model.onnx", backend="cpu", config=config
+ )
+
+ class _FakeOrtValue:
+ def __init__(self, value: float) -> None:
+ self.value = float(value)
+
+ def toarray(self) -> np.ndarray:
+ return np.full((1, 4), self.value, dtype=np.float32)
+
+ engine._session.run = lambda *_args, **_kwargs: [_FakeOrtValue(7.0)] # type: ignore[method-assign]
+
+ out = engine.run({"input": np.ones((1, 4), dtype=np.float32)})
+ assert np.all(out["output"] == 7.0)
+
+
+def test_onnx_engine_missing_required_input_raises_keyerror(onnx_engine_module):
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+
+ with pytest.raises(KeyError, match="Missing required input"):
+ engine.run({})
+
+
+def test_onnx_engine_extract_metadata_returns_none_for_non_mapping(
+ onnx_engine_module,
+):
+ onnx_engine_cls = onnx_engine_module.ONNXEngine
+ onnx_engine_module._test_modelmeta_holder["map"] = 123
+
+ engine = onnx_engine_cls(model_path="model.onnx", backend="cpu")
+ assert engine.metadata is None
diff --git a/tests/onnxengine/test_utils_and_metadata.py b/tests/onnxengine/test_utils_and_metadata.py
new file mode 100644
index 0000000..f684d1a
--- /dev/null
+++ b/tests/onnxengine/test_utils_and_metadata.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import types
+
+import onnx
+import pytest
+from onnx import TensorProto, helper
+
+from capybara.onnxengine import metadata, utils
+
+
+@pytest.fixture()
+def simple_onnx(tmp_path):
+ input_info = helper.make_tensor_value_info(
+ "input", TensorProto.FLOAT, [1, 3]
+ )
+ output_info = helper.make_tensor_value_info(
+ "output", TensorProto.FLOAT, [1, 3]
+ )
+ node = helper.make_node("Identity", ["input"], ["output"])
+ graph = helper.make_graph([node], "test-graph", [input_info], [output_info])
+ model = helper.make_model(graph)
+ path = tmp_path / "simple.onnx"
+ onnx.save(model, path)
+ return path
+
+
+def test_get_input_and_output_infos(simple_onnx):
+ inputs = utils.get_onnx_input_infos(simple_onnx)
+ outputs = utils.get_onnx_output_infos(simple_onnx)
+ assert inputs["input"]["shape"] == [1, 3]
+ assert outputs["output"]["shape"] == [1, 3]
+ assert str(inputs["input"]["dtype"]) == "float32"
+
+
+def test_get_input_and_output_infos_accept_model_proto(simple_onnx):
+ model = onnx.load(simple_onnx)
+ inputs = utils.get_onnx_input_infos(model)
+ outputs = utils.get_onnx_output_infos(model)
+ assert inputs["input"]["shape"] == [1, 3]
+ assert outputs["output"]["shape"] == [1, 3]
+
+
+def test_make_onnx_dynamic_axes_overrides_dims(
+ simple_onnx, monkeypatch, tmp_path
+):
+ monkeypatch.setattr(
+ utils,
+ "onnxslim",
+ types.SimpleNamespace(simplify=lambda model: (model, True)),
+ )
+ out_path = tmp_path / "dynamic.onnx"
+ utils.make_onnx_dynamic_axes(
+ model_fpath=simple_onnx,
+ output_fpath=out_path,
+ input_dims={"input": {0: "batch"}},
+ output_dims={"output": {0: "batch"}},
+ opset_version=18,
+ )
+ model = onnx.load(out_path)
+ assert (
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param == "batch"
+ )
+ assert (
+ model.graph.output[0].type.tensor_type.shape.dim[0].dim_param == "batch"
+ )
+
+
+def test_make_onnx_dynamic_axes_adds_default_opset_when_missing(
+ monkeypatch, tmp_path
+):
+ input_info = helper.make_tensor_value_info(
+ "input", TensorProto.FLOAT, [1, 3]
+ )
+ output_info = helper.make_tensor_value_info(
+ "output", TensorProto.FLOAT, [1, 3]
+ )
+ node = helper.make_node("Identity", ["input"], ["output"])
+ graph = helper.make_graph(
+ [node], "no-default-opset", [input_info], [output_info]
+ )
+ model = helper.make_model(
+ graph,
+ opset_imports=[helper.make_opsetid(domain="ai.onnx.ml", version=1)],
+ )
+
+ src = tmp_path / "src.onnx"
+ out_path = tmp_path / "dynamic.onnx"
+ onnx.save(model, src)
+
+ monkeypatch.setattr(
+ utils,
+ "onnxslim",
+ types.SimpleNamespace(simplify=lambda model: (model, True)),
+ )
+ utils.make_onnx_dynamic_axes(
+ model_fpath=src,
+ output_fpath=out_path,
+ input_dims={"input": {0: "batch"}},
+ output_dims={"output": {0: "batch"}},
+ opset_version=18,
+ )
+ updated = onnx.load(out_path)
+ assert any(
+ opset.domain == "" and opset.version == 18
+ for opset in updated.opset_import
+ )
+
+
+def test_make_onnx_dynamic_axes_uses_current_opset_when_version_is_none(
+ monkeypatch, tmp_path
+):
+ input_info = helper.make_tensor_value_info(
+ "input", TensorProto.FLOAT, [1, 3]
+ )
+ output_info = helper.make_tensor_value_info(
+ "output", TensorProto.FLOAT, [1, 3]
+ )
+ node = helper.make_node("Identity", ["input"], ["output"])
+ graph = helper.make_graph(
+ [node], "no-default-opset", [input_info], [output_info]
+ )
+ model = helper.make_model(
+ graph,
+ opset_imports=[helper.make_opsetid(domain="ai.onnx.ml", version=1)],
+ )
+
+ src = tmp_path / "src.onnx"
+ out_path = tmp_path / "dynamic.onnx"
+ onnx.save(model, src)
+
+ monkeypatch.setattr(onnx.defs, "onnx_opset_version", lambda: 17)
+ monkeypatch.setattr(
+ utils,
+ "onnxslim",
+ types.SimpleNamespace(simplify=lambda model: (model, True)),
+ )
+ utils.make_onnx_dynamic_axes(
+ model_fpath=src,
+ output_fpath=out_path,
+ input_dims={"input": {0: "batch"}},
+ output_dims={"output": {0: "batch"}},
+ opset_version=None,
+ )
+ updated = onnx.load(out_path)
+ assert any(
+ opset.domain == "" and opset.version == 17
+ for opset in updated.opset_import
+ )
+
+
+def test_make_onnx_dynamic_axes_accepts_simplify_returning_model(
+ simple_onnx, monkeypatch, tmp_path
+):
+ monkeypatch.setattr(
+ utils,
+ "onnxslim",
+ types.SimpleNamespace(simplify=lambda model: model),
+ )
+ out_path = tmp_path / "dynamic.onnx"
+ utils.make_onnx_dynamic_axes(
+ model_fpath=simple_onnx,
+ output_fpath=out_path,
+ input_dims={"input": {0: "batch"}},
+ output_dims={"output": {0: "batch"}},
+ opset_version=18,
+ )
+ assert (tmp_path / "dynamic.onnx").exists()
+
+
+def test_make_onnx_dynamic_axes_rejects_reshape_nodes(tmp_path):
+ input_info = helper.make_tensor_value_info(
+ "input", TensorProto.FLOAT, [1, 3]
+ )
+ shape_init = helper.make_tensor(
+ "shape", TensorProto.INT64, dims=[2], vals=[1, 3]
+ )
+ output_info = helper.make_tensor_value_info(
+ "output", TensorProto.FLOAT, [1, 3]
+ )
+ node = helper.make_node("Reshape", ["input", "shape"], ["output"])
+ graph = helper.make_graph(
+ [node],
+ "reshape-graph",
+ [input_info],
+ [output_info],
+ initializer=[shape_init],
+ )
+ model = helper.make_model(graph)
+ src = tmp_path / "reshape.onnx"
+ onnx.save(model, src)
+
+ with pytest.raises(ValueError, match="Reshape cannot be trasformed"):
+ utils.make_onnx_dynamic_axes(
+ model_fpath=src,
+ output_fpath=tmp_path / "out.onnx",
+ input_dims={"input": {0: "batch"}},
+ output_dims={"output": {0: "batch"}},
+ opset_version=18,
+ )
+
+
+@pytest.fixture(autouse=True)
+def fake_ort(monkeypatch):
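+ # Replace the runtime session used by the metadata helpers with a stub that
+ # reads metadata straight from the ONNX file, so onnxruntime is not required.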
+ class FakeSession:
+ def __init__(self, path, providers=None):
+ self._path = path
+
+ def get_modelmeta(self):
+ model = onnx.load(self._path)
+ mapping = {p.key: p.value for p in model.metadata_props}
+ return types.SimpleNamespace(custom_metadata_map=mapping)
+
+ monkeypatch.setattr(
+ metadata, "ort", types.SimpleNamespace(InferenceSession=FakeSession)
+ )
+
+
+def test_metadata_roundtrip(simple_onnx, tmp_path):
+ out_path = tmp_path / "meta.onnx"
+ metadata.write_metadata_into_onnx(
+ simple_onnx, out_path, author={"name": "angizero"}
+ )
+ parsed = metadata.parse_metadata_from_onnx(out_path)
+ assert parsed["author"]["name"] == "angizero"
+
+ raw = metadata.get_onnx_metadata(out_path)
+ assert "author" in raw
+
+
+def test_parse_metadata_preserves_non_string_values(monkeypatch, simple_onnx):
+ class FakeSession:
+ def __init__(self, path, providers=None):
+ self._path = path
+
+ def get_modelmeta(self):
+ return types.SimpleNamespace(
+ custom_metadata_map={"raw": 123, "json": "1"}
+ )
+
+ monkeypatch.setattr(
+ metadata, "ort", types.SimpleNamespace(InferenceSession=FakeSession)
+ )
+ parsed = metadata.parse_metadata_from_onnx(simple_onnx)
+ assert parsed["raw"] == 123
+ assert parsed["json"] == 1
diff --git a/tests/onnxruntime/test_engine.py b/tests/onnxruntime/test_engine.py
deleted file mode 100644
index 67e114a..0000000
--- a/tests/onnxruntime/test_engine.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-import pytest
-
-from capybara import Backend, ONNXEngine, get_curdir, get_recommended_backend
-
-
-def test_ONNXEngine_CPU():
- model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx"
- engine = ONNXEngine(model_path, backend="cpu")
- for i in range(5):
- xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")}
- outs = engine(**xs)
- if i:
- assert not np.allclose(outs["output"], prev_outs["output"])
- prev_outs = outs
-
-
-@pytest.mark.skipif(get_recommended_backend() != Backend.cuda, reason="Linux with GPU only")
-def test_ONNXEngine_CUDA():
- model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx"
- engine = ONNXEngine(model_path, backend=get_recommended_backend())
- for i in range(5):
- xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")}
- outs = engine(**xs)
- if i:
- assert not np.allclose(outs["output"], prev_outs["output"])
- prev_outs = outs
-
-
-@pytest.mark.skipif(get_recommended_backend() != "Darwin", reason="Mac only")
-def test_ONNXEngine_COREML():
- model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx"
- engine = ONNXEngine(model_path, backend="coreml")
- for i in range(5):
- xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")}
- outs = engine(**xs)
- if i:
- assert not np.allclose(outs["output"], prev_outs["output"])
- prev_outs = outs
diff --git a/tests/onnxruntime/test_engine_io_binding.py b/tests/onnxruntime/test_engine_io_binding.py
deleted file mode 100644
index 796d8a3..0000000
--- a/tests/onnxruntime/test_engine_io_binding.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import numpy as np
-import pytest
-
-from capybara import Backend, ONNXEngineIOBinding, get_curdir, get_recommended_backend
-
-
-@pytest.mark.skipif(get_recommended_backend() != Backend.cuda, reason="Linux with GPU only")
-def test_ONNXEngineIOBinding_CUDAonly():
- model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx"
- input_initializer = {"input": np.random.randn(32, 3, 448, 448).astype("float32")}
- engine = ONNXEngineIOBinding(model_path, input_initializer)
- for i in range(5):
- xs = {"input": np.random.randn(32, 3, 448, 448).astype("float32")}
- outs = engine(**xs)
- if i:
- assert not np.allclose(outs["output"], prev_outs["output"])
- prev_outs = outs
diff --git a/tests/onnxruntime/test_tools.py b/tests/onnxruntime/test_tools.py
deleted file mode 100644
index 23f6016..0000000
--- a/tests/onnxruntime/test_tools.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-
-from capybara import (
- ONNXEngine,
- get_onnx_input_infos,
- get_onnx_output_infos,
- make_onnx_dynamic_axes,
-)
-
-
-def test_get_onnx_input_infos():
- model_path = "tests/resources/model_shape=224x224.onnx"
- input_infos = get_onnx_input_infos(model_path)
- assert input_infos == {"input": {"shape": [1, 3, 224, 224], "dtype": "float32"}}
-
-
-def test_get_onnx_output_infos():
- model_path = "tests/resources/model_shape=224x224.onnx"
- output_infos = get_onnx_output_infos(model_path)
- assert output_infos == {"output": {"shape": [1, 64, 56, 56], "dtype": "float32"}}
-
-
-def test_make_onnx_dynamic_axes():
- model_path = "tests/resources/model_shape=224x224.onnx"
- input_infos = get_onnx_input_infos(model_path)
- output_infos = get_onnx_output_infos(model_path)
- input_dims = {k: {0: "b", 2: "h", 3: "w"} for k in input_infos.keys()}
- output_dims = {k: {0: "b", 2: "h", 3: "w"} for k in output_infos.keys()}
- new_model_path = "/tmp/model_dynamic-axes.onnx"
- make_onnx_dynamic_axes(
- model_path,
- new_model_path,
- input_dims=input_dims,
- output_dims=output_dims,
- )
- xs = {"input": np.random.randn(32, 3, 320, 320).astype("float32")}
- engine = ONNXEngine(new_model_path, session_option={"log_severity_level": 1}, backend="cpu")
- outs = engine(**xs)
- assert outs["output"].shape == (32, 64, 80, 80)
diff --git a/tests/openvinoengine/test_openvino_engine_stubbed.py b/tests/openvinoengine/test_openvino_engine_stubbed.py
new file mode 100644
index 0000000..2ed1d5b
--- /dev/null
+++ b/tests/openvinoengine/test_openvino_engine_stubbed.py
@@ -0,0 +1,601 @@
+from __future__ import annotations
+
+import importlib
+import queue
+import sys
+import threading
+import types
+import warnings
+from typing import Any
+
+import numpy as np
+import pytest
+
+
+@pytest.fixture()
+def openvino_engine_module(monkeypatch):
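+ # Stub out `openvino.runtime` (Core, Type, AsyncInferQueue) so the engine
+ # tests run without OpenVINO installed.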
+ class FakeType:
+ f32 = object()
+ i64 = object()
+ boolean = object()
+
+ class FakeDim:
+ def __init__(self, value):
+ self.value = value
+ self.is_static = value is not None
+
+ def get_length(self):
+ if self.value is None:
+ raise ValueError("dynamic dim")
+ return self.value
+
+ class FakePort:
+ def __init__(self, name, element_type, shape):
+ self._name = name
+ self._type = element_type
+ self._shape = shape
+
+ def get_any_name(self):
+ return self._name
+
+ def get_element_type(self):
+ return self._type
+
+ def get_partial_shape(self):
+ dims = []
+ for value in self._shape:
+ if isinstance(value, int):
+ dims.append(FakeDim(value))
+ else:
+ dims.append(FakeDim(None))
+ return dims
+
+ class FakeTensor:
+ def __init__(self, value):
+ self.data = np.full((1, 2), value, dtype=np.float32)
+
+ class FakeInferRequest:
+ def __init__(self, outputs):
+ self._outputs = outputs
+
+ def infer(self, feed):
+ self._last_feed = feed
+
+ def get_tensor(self, port):
+ idx = self._outputs.index(port)
+ return FakeTensor(idx + 1)
+
+ class FakeCompiledModel:
+ def __init__(self, input_shapes=None):
+ shape = (1, None, 3)
+ if input_shapes:
+ shape = input_shapes.get("input", shape)
+ self.inputs = [
+ FakePort("input", FakeType.f32, shape),
+ ]
+ self.outputs = [
+ FakePort("output", FakeType.f32, (1, 2)),
+ ]
+
+ def create_infer_request(self):
+ return FakeInferRequest(self.outputs)
+
+ class FakeAsyncInferQueue:
+ def __init__(self, compiled_model, jobs):
+ self._compiled_model = compiled_model
+ self._jobs = int(jobs)
+ self._callback = None
+
+ def set_callback(self, callback):
+ self._callback = callback
+
+ def start_async(self, feed, userdata):
+ req = self._compiled_model.create_infer_request()
+ req.infer(feed)
+ if self._callback is not None:
+ self._callback(req, userdata)
+
+ def wait_all(self):
+ return None
+
+ class FakeCore:
+ def __init__(self):
+ self._properties = {}
+ self._last_reshape = None
+
+ def set_property(self, props):
+ self._properties.update(props)
+
+ def read_model(self, path):
+ class FakeModel:
+ def __init__(self, p):
+ self.path = p
+ self.reshape_map = {}
+
+ def reshape(self, mapping):
+ self.reshape_map = mapping
+
+ return FakeModel(path)
+
+ def compile_model(self, model, device, properties):
+ self._properties["device"] = device
+ self._properties.update(properties)
+ self._last_reshape = getattr(model, "reshape_map", {})
+ return FakeCompiledModel(self._last_reshape)
+
+ fake_runtime: Any = types.ModuleType("openvino.runtime")
+ fake_runtime.Type = FakeType
+ fake_runtime.Core = FakeCore
+ fake_runtime.AsyncInferQueue = FakeAsyncInferQueue
+ fake_pkg: Any = types.ModuleType("openvino")
+ fake_pkg.runtime = fake_runtime
+
+ monkeypatch.setitem(sys.modules, "openvino", fake_pkg)
+ monkeypatch.setitem(sys.modules, "openvino.runtime", fake_runtime)
+ module = importlib.reload(
+ importlib.import_module("capybara.openvinoengine.engine")
+ )
+ yield module
+
+
+def test_openvino_engine_runs_with_stub(openvino_engine_module, tmp_path):
+ config_cls = openvino_engine_module.OpenVINOConfig
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ device_enum = openvino_engine_module.OpenVINODevice
+
+ cfg = config_cls(
+ compile_properties={"PERF_HINT": "THROUGHPUT"},
+ core_properties={"LOG_LEVEL": "ERROR"},
+ cache_dir=tmp_path / "cache",
+ num_streams=2,
+ num_threads=4,
+ )
+
+ engine = engine_cls(
+ model_path="model.xml",
+ device=device_enum.cpu,
+ config=cfg,
+ )
+
+ feed = {"input": np.ones((1, 3), dtype=np.float32)}
+ outputs = engine(**feed)
+ assert outputs["output"].shape == (1, 2)
+
+ summary = engine.summary()
+ assert summary["device"] == "CPU"
+ assert summary["inputs"][0]["shape"][1] is None
+
+
+def test_openvino_engine_accepts_input_shapes(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ device_enum = openvino_engine_module.OpenVINODevice
+
+ engine = engine_cls(
+ model_path="model.xml",
+ device=device_enum.npu,
+ input_shapes={"input": (2, 3, 5)},
+ )
+
+ summary = engine.summary()
+ assert summary["inputs"][0]["shape"] == [2, 3, 5]
+
+
+def test_openvino_engine_async_queue(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ with engine.create_async_queue(num_requests=2) as q:
+ fut = q.submit({"input": np.ones((1, 3), dtype=np.float32)})
+ outputs = fut.result(timeout=1)
+ assert outputs["output"].shape == (1, 2)
+
+
+def test_openvino_engine_async_queue_auto_requests(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ with engine.create_async_queue(num_requests=0) as q:
+ fut = q.submit({"input": np.ones((1, 3), dtype=np.float32)})
+ outputs = fut.result(timeout=1)
+ assert outputs["output"].shape == (1, 2)
+
+
+def test_openvino_engine_benchmark_async(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ stats = engine.benchmark_async(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ repeat=10,
+ warmup=1,
+ num_requests=2,
+ )
+ assert stats["num_requests"] == 2
+
+
+def test_openvino_engine_async_queue_auto_request_id(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ completion = queue.Queue()
+ with engine.create_async_queue(num_requests=2) as q:
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ completion_queue=completion,
+ )
+ outputs = fut.result(timeout=1)
+ req_id, event_outputs = completion.get(timeout=1)
+
+ assert getattr(fut, "request_id", None) is not None
+ assert req_id == fut.request_id
+ assert event_outputs is not outputs
+
+
+def test_openvino_engine_async_queue_preserves_request_id(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ completion = queue.Queue()
+ with engine.create_async_queue(num_requests=2) as q:
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-123",
+ completion_queue=completion,
+ )
+ outputs = fut.result(timeout=1)
+ req_id, event_outputs = completion.get(timeout=1)
+
+ assert fut.request_id == "req-123"
+ assert req_id == "req-123"
+ assert event_outputs is not outputs
+
+
+def test_openvino_engine_async_queue_completion_queue(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ completion = queue.Queue()
+ with engine.create_async_queue(num_requests=2) as q:
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-1",
+ completion_queue=completion,
+ )
+ outputs = fut.result(timeout=1)
+ req_id, event_outputs = completion.get(timeout=1)
+
+ assert req_id == "req-1"
+ assert event_outputs is not outputs
+ outputs["mutated"] = np.zeros((1,), dtype=np.float32)
+ assert "mutated" not in event_outputs
+ assert outputs["output"].shape == (1, 2)
+ assert event_outputs["output"].shape == (1, 2)
+
+
+def test_openvino_engine_async_queue_completion_queue_full_does_not_block(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ completion = queue.Queue(maxsize=1)
+ completion.put(("sentinel", {}))
+
+ result: dict[str, object] = {}
+
+ def submit_request():
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ with engine.create_async_queue(num_requests=2) as q:
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-1",
+ completion_queue=completion,
+ )
+ result["outputs"] = fut.result(timeout=1)
+
+ thread = threading.Thread(target=submit_request, daemon=True)
+ thread.start()
+ thread.join(timeout=0.5)
+
+ if thread.is_alive():
+ completion.get_nowait()
+ thread.join(timeout=0.5)
+ pytest.fail("q.submit() blocked when completion_queue was full")
+
+ assert completion.qsize() == 1
+ assert completion.get_nowait()[0] == "sentinel"
+ assert isinstance(result.get("outputs"), dict)
+
+
+def test_openvino_engine_async_queue_completion_queue_full_emits_warning(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ completion = queue.Queue(maxsize=1)
+ completion.put(("sentinel", {}))
+
+ with engine.create_async_queue(num_requests=2) as q:
+ with pytest.warns(RuntimeWarning, match="completion_queue is full"):
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-1",
+ completion_queue=completion,
+ )
+ outputs = fut.result(timeout=1)
+
+ assert outputs["output"].shape == (1, 2)
+ assert completion.qsize() == 1
+ assert completion.get_nowait()[0] == "sentinel"
+
+
+def test_openvino_device_from_any_accepts_strings_and_rejects_unknown(
+ openvino_engine_module,
+):
+ device_enum = openvino_engine_module.OpenVINODevice
+
+ assert device_enum.from_any(device_enum.cpu) is device_enum.cpu
+ assert device_enum.from_any("cpu") is device_enum.cpu
+
+ with pytest.raises(ValueError, match="Unsupported OpenVINO device"):
+ device_enum.from_any("tpu")
+
+
+def test_openvino_engine_accepts_wrapped_feed_dict(openvino_engine_module):
+ """__call__ supports passing a single mapping payload as a kwarg."""
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ feed = {"input": np.ones((1, 3), dtype=np.float32)}
+ outputs = engine(payload=feed)
+ assert outputs["output"].shape == (1, 2)
+
+
+def test_openvino_engine_run_replaces_failed_request(openvino_engine_module):
+ """Infer failures should not leave a broken request in the pool."""
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ feed = {"input": np.ones((1, 3), dtype=np.float32)}
+
+ failing_request = engine._request_pool.get_nowait()
+
+ def boom(_prepared):
+ raise RuntimeError("infer boom")
+
+ failing_request.infer = boom # type: ignore[assignment]
+ engine._request_pool.put(failing_request)
+
+ with pytest.raises(RuntimeError, match="infer boom"):
+ engine.run(feed)
+
+ replacement = engine._request_pool.get_nowait()
+ assert replacement is not failing_request
+ engine._request_pool.put(replacement)
+
+ outputs = engine.run(feed)
+ assert outputs["output"].shape == (1, 2)
+
+
+def test_openvino_engine_benchmark_sync(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ stats = engine.benchmark(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ repeat=3,
+ warmup=1,
+ )
+
+ assert stats["repeat"] == 3
+ assert "latency_ms" in stats
+
+
+def test_openvino_engine_benchmark_validates_repeat_and_warmup(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ feed = {"input": np.ones((1, 3), dtype=np.float32)}
+
+ with pytest.raises(ValueError, match="repeat must be >= 1"):
+ engine.benchmark(feed, repeat=0)
+
+ with pytest.raises(ValueError, match="warmup must be >= 0"):
+ engine.benchmark(feed, warmup=-1)
+
+ with pytest.raises(ValueError, match="repeat must be >= 1"):
+ engine.benchmark_async(feed, repeat=0)
+
+ with pytest.raises(ValueError, match="warmup must be >= 0"):
+ engine.benchmark_async(feed, warmup=-1)
+
+
+def test_openvino_engine_num_requests_validation(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ config_cls = openvino_engine_module.OpenVINOConfig
+
+ engine = engine_cls(
+ model_path="model.xml",
+ device="CPU",
+ config=config_cls(num_requests=0),
+ )
+ assert engine._request_pool.maxsize == 1
+
+ with pytest.raises(ValueError, match="num_requests must be >= 0"):
+ engine_cls(
+ model_path="model.xml",
+ device="CPU",
+ config=config_cls(num_requests=-1),
+ )
+
+ with pytest.raises(ValueError, match="num_requests must be >= 0"):
+ engine.create_async_queue(num_requests=-1)
+
+
+def test_openvino_engine_input_shapes_reject_none_dims(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ with pytest.raises(ValueError, match="must use concrete dimensions"):
+ engine_cls(
+ model_path="model.xml",
+ device="CPU",
+ input_shapes={"input": (1, None, 3)},
+ )
+
+
+def test_openvino_engine_requires_model_reshape_when_input_shapes(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ class NoReshapeCore:
+ def read_model(self, path):
+ return object()
+
+ def compile_model(self, model, device, properties): # pragma: no cover
+ raise AssertionError(
+ "compile_model should not run when reshape fails"
+ )
+
+ with pytest.raises(RuntimeError, match="does not support reshape"):
+ engine_cls(
+ model_path="model.xml",
+ device="CPU",
+ core=NoReshapeCore(),
+ input_shapes={"input": (1, 3, 5)},
+ )
+
+
+def test_openvino_engine_prepare_feed_validates_and_casts(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ with pytest.raises(KeyError, match="Missing required input"):
+ engine.run({"missing": np.ones((1, 3), dtype=np.float32)})
+
+ feed = {"input": np.ones((1, 3), dtype=np.float64)}
+ engine.run(feed)
+ req = engine._request_pool.get_nowait()
+ assert req._last_feed["input"].dtype == np.float32 # type: ignore[attr-defined]
+ engine._request_pool.put(req)
+
+
+def test_openvino_partial_shape_to_tuple_accepts_int_dims(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ assert engine._partial_shape_to_tuple([1, 2, "dyn"]) == (1, 2, None)
+
+
+def test_openvino_build_type_map_handles_missing_type(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ assert engine._build_type_map(types.SimpleNamespace()) == {}
+
+
+def test_openvino_async_queue_requires_asyncinferqueue(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+ engine._ov = types.SimpleNamespace() # no AsyncInferQueue
+
+ with pytest.raises(RuntimeError, match="AsyncInferQueue"):
+ engine.create_async_queue(num_requests=2)
+
+
+def test_openvino_async_queue_submit_after_close_raises(openvino_engine_module):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ q = engine.create_async_queue(num_requests=1)
+ q.close()
+ q.close() # idempotent
+ with pytest.raises(RuntimeError, match="Async queue is closed"):
+ q.submit({"input": np.ones((1, 3), dtype=np.float32)})
+
+
+def test_openvino_async_queue_completion_queue_put_fallback(
+ openvino_engine_module,
+):
+ """Fallback to completion_queue.put(block=False) when put_nowait is absent."""
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ class PutOnlyQueue:
+ def __init__(self):
+ self.items = []
+
+ def put(self, item, *, block=False):
+ self.items.append((item, block))
+
+ completion = PutOnlyQueue()
+ with engine.create_async_queue(num_requests=1) as q:
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-1",
+ completion_queue=completion,
+ )
+ fut.result(timeout=1)
+
+ assert completion.items[0][0][0] == "req-1"
+ assert completion.items[0][1] is False
+
+
+def test_openvino_async_queue_completion_queue_signature_mismatch_warns(
+ openvino_engine_module,
+):
+ """Completion queues without non-blocking put semantics should warn once."""
+ engine_cls = openvino_engine_module.OpenVINOEngine
+ engine = engine_cls(model_path="model.xml", device="CPU")
+
+ class BadQueue:
+ def put(self, item):
+ self.item = item
+
+ completion = BadQueue()
+ with engine.create_async_queue(num_requests=1) as q:
+ with pytest.warns(
+ RuntimeWarning,
+ match="does not support non-blocking put",
+ ):
+ fut = q.submit(
+ {"input": np.ones((1, 3), dtype=np.float32)},
+ request_id="req-1",
+ completion_queue=completion,
+ )
+ fut.result(timeout=1)
+
+
+def test_openvino_engine_request_pool_respects_num_requests(
+ openvino_engine_module,
+):
+ config_cls = openvino_engine_module.OpenVINOConfig
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ cfg = config_cls(num_requests=3)
+ engine = engine_cls(model_path="model.xml", device="CPU", config=cfg)
+
+ assert engine._request_pool.maxsize == 3
+
+
+def test_openvino_engine_input_shapes_accepts_non_sequence_values(
+ openvino_engine_module,
+):
+ engine_cls = openvino_engine_module.OpenVINOEngine
+
+ engine = engine_cls(
+ model_path="model.xml",
+ device="CPU",
+ input_shapes={"input": {"dim": 3}},
+ )
+
+ assert engine._core._last_reshape["input"] == {"dim": 3}
diff --git a/tests/resources/make_onnx.py b/tests/resources/make_onnx.py
index 52e6dac..cb4a5a4 100644
--- a/tests/resources/make_onnx.py
+++ b/tests/resources/make_onnx.py
@@ -4,6 +4,7 @@
try:
import torch
+
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 32, 3, 2, 1),
torch.nn.BatchNorm2d(32),
@@ -12,20 +13,24 @@
torch.nn.BatchNorm2d(64),
torch.nn.ReLU(),
)
+ dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
- torch.randn(1, 3, 224, 224),
- cur_folder/'model_shape=224x224.onnx',
- input_names=['input'],
- output_names=['output'],
+ (dummy_input,),
+ cur_folder / "model_shape=224x224.onnx",
+ input_names=["input"],
+ output_names=["output"],
)
torch.onnx.export(
model,
- torch.randn(1, 3, 224, 224),
- cur_folder/'model_dynamic-axes.onnx',
- input_names=['input'],
- output_names=['output'],
- dynamic_axes={'input': {0: 'b', 2: 'h', 3: 'w'}, 'output': {0: 'b', 2: 'h', 3: 'w'}},
+ (dummy_input,),
+ cur_folder / "model_dynamic-axes.onnx",
+ input_names=["input"],
+ output_names=["output"],
+ dynamic_axes={
+ "input": {0: "b", 2: "h", 3: "w"},
+ "output": {0: "b", 2: "h", 3: "w"},
+ },
)
except ImportError:
print("PyTorch is not installed.")
diff --git a/tests/structures/test_boxes.py b/tests/structures/test_boxes.py
index 438cc2a..6c188af 100644
--- a/tests/structures/test_boxes.py
+++ b/tests/structures/test_boxes.py
@@ -1,35 +1,39 @@
+from typing import Any
+
import numpy as np
import pytest
from capybara import Box, Boxes, BoxMode
-def test_invalid_input_type():
+def test_box_invalid_input_type():
with pytest.raises(TypeError):
- Box("invalid_input")
+ invalid_input: Any = "invalid_input"
+ Box(invalid_input)
-def test_invalid_input_shape():
+def test_box_invalid_input_shape():
with pytest.raises(TypeError):
- Box([1, 2, 3, 4, 5]) # 長度為5而非4,不符合預期的box格式
+ Box([1, 2, 3, 4, 5])  # length is 5 instead of 4, not the expected box format
-def test_normalized_array():
+def test_box_accepts_is_normalized_flag():
array = np.array([0.1, 0.2, 0.3, 0.4])
box = Box(array, is_normalized=True)
assert box.is_normalized is True
-def test_invalid_box_mode():
+def test_box_invalid_box_mode():
with pytest.raises(KeyError):
array = np.array([1, 2, 3, 4])
Box(array, box_mode="invalid_mode")
-def test_array_conversion():
+def test_box_array_conversion():
array = [1, 2, 3, 4]
box = Box(array)
- assert np.allclose(box._array, np.array(array, dtype='float32'))
+ assert np.allclose(box._array, np.array(array, dtype="float32"))
+
# Test Box initialization
@@ -39,13 +43,17 @@ def test_box_init():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
assert isinstance(box, Box), "Initialization of Box failed."
+
# Test conversion of Box format
def test_box_convert():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
converted_box = box.convert(BoxMode.XYWH)
- assert np.allclose(converted_box.numpy(), np.array([50, 50, 50, 50])), "Box conversion failed."
+ assert np.allclose(converted_box.numpy(), np.array([50, 50, 50, 50])), (
+ "Box conversion failed."
+ )
+
# Test calculation of area of Box
@@ -54,13 +62,17 @@ def test_box_area():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
assert box.area == 2500, "Box area calculation failed."
+
# Test Box.copy() method
def test_box_copy():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
copied_box = box.copy()
- assert copied_box is not box and (copied_box._array == box._array).all(), "Box copy failed."
+ assert copied_box is not box and (copied_box._array == box._array).all(), (
+ "Box copy failed."
+ )
+
# Test Box conversion to numpy array
@@ -69,7 +81,9 @@ def test_box_numpy():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
arr = box.numpy()
assert isinstance(arr, np.ndarray) and np.allclose(
- arr, np.array([50, 50, 100, 100])), "Box to numpy conversion failed."
+ arr, np.array([50, 50, 100, 100])
+ ), "Box to numpy conversion failed."
+
# Test Box normalization
@@ -77,7 +91,10 @@ def test_box_numpy():
def test_box_normalize():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
normalized_box = box.normalize(200, 200)
- assert np.allclose(normalized_box.numpy(), np.array([0.25, 0.25, 0.5, 0.5])), "Box normalization failed."
+ assert np.allclose(
+ normalized_box.numpy(), np.array([0.25, 0.25, 0.5, 0.5])
+ ), "Box normalization failed."
+
# Test Box denormalization
@@ -85,7 +102,10 @@ def test_box_normalize():
def test_box_denormalize():
box = Box((0.25, 0.25, 0.5, 0.5), box_mode=BoxMode.XYXY, is_normalized=True)
denormalized_box = box.denormalize(200, 200)
- assert np.allclose(denormalized_box.numpy(), np.array([50, 50, 100, 100])), "Box denormalization failed."
+ assert np.allclose(
+ denormalized_box.numpy(), np.array([50, 50, 100, 100])
+ ), "Box denormalization failed."
+
# Test Box clipping
@@ -93,7 +113,23 @@ def test_box_denormalize():
def test_box_clip():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
clipped_box = box.clip(60, 60, 90, 90)
- assert np.allclose(clipped_box.numpy(), np.array([60, 60, 90, 90])), "Box clipping failed."
+ assert np.allclose(clipped_box.numpy(), np.array([60, 60, 90, 90])), (
+ "Box clipping failed."
+ )
+
+
+def test_box_clip_preserves_box_mode_for_xywh_inputs():
+ box = Box((10, 20, 5, 5), box_mode=BoxMode.XYWH)
+ clipped_box = box.clip(0, 0, 12, 30)
+ assert clipped_box.box_mode == BoxMode.XYWH
+ assert np.allclose(clipped_box.numpy(), np.array([10, 20, 2, 5]))
+
+
+def test_box_convert_preserves_is_normalized_flag():
+ box = Box((0.1, 0.2, 0.3, 0.4), box_mode=BoxMode.XYXY, is_normalized=True)
+ converted = box.convert(BoxMode.XYWH)
+ assert converted.is_normalized is True
+
# Test Box shifting
@@ -101,7 +137,10 @@ def test_box_clip():
def test_box_shift():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
shifted_box = box.shift(10, -10)
- assert np.allclose(shifted_box.numpy(), np.array([60, 40, 110, 90])), "Box shifting failed."
+ assert np.allclose(shifted_box.numpy(), np.array([60, 40, 110, 90])), (
+ "Box shifting failed."
+ )
+
# Test Box scaling
@@ -109,7 +148,10 @@ def test_box_shift():
def test_box_scale():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
scaled_box = box.scale(dsize=(20, 0))
- assert np.allclose(scaled_box.numpy(), np.array([40, 50, 110, 100])), "Box scaling failed."
+ assert np.allclose(scaled_box.numpy(), np.array([40, 50, 110, 100])), (
+ "Box scaling failed."
+ )
+
# Test Box to_list
@@ -118,127 +160,187 @@ def test_box_to_list():
box = Box((50, 50, 50, 50), box_mode=BoxMode.XYXY)
assert box.to_list() == [50, 50, 50, 50], "Boxes tolist failed."
+
# Test Box to_polygon
def test_box_to_polygon():
box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY)
polygon = box.to_polygon()
- assert np.allclose(polygon.numpy(), np.array(
- [[50, 50], [100, 50], [100, 100], [50, 100]])), "Box convert_to_polygon failed."
+ assert np.allclose(
+ polygon.numpy(), np.array([[50, 50], [100, 50], [100, 100], [50, 100]])
+ ), "Box convert_to_polygon failed."
+
# Test Boxes initialization
-def test_invalid_input_type():
+def test_boxes_invalid_input_type():
with pytest.raises(TypeError):
- Boxes("invalid_input")
+ invalid_input: Any = "invalid_input"
+ Boxes(invalid_input)
-def test_invalid_input_shape():
+def test_boxes_invalid_input_shape():
with pytest.raises(TypeError):
Boxes([[1, 2, 3, 4, 5]])
-def test_normalized_array():
+def test_boxes_accepts_is_normalized_flag():
array = np.array([0.1, 0.2, 0.3, 0.4])
box = Boxes([array], is_normalized=True)
assert box.is_normalized is True
-def test_invalid_box_mode():
+def test_boxes_invalid_box_mode():
with pytest.raises(KeyError):
array = np.array([1, 2, 3, 4])
- Box(array, box_mode="invalid_mode")
+ Boxes([array], box_mode="invalid_mode")
-def test_array_conversion():
+def test_boxes_array_conversion():
array = [[1, 2, 3, 4]]
box = Boxes(array)
- assert np.allclose(box._array, np.array(array, dtype='float32'))
+ assert np.allclose(box._array, np.array(array, dtype="float32"))
def test_boxes_init():
# Create boxes in XYXY format
- boxes = Boxes([(50, 50, 100, 100), [60, 60, 120, 120]], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), [60, 60, 120, 120]], box_mode=BoxMode.XYXY
+ )
assert isinstance(boxes, Boxes), "Initialization of Boxes failed."
+
# Test conversion of Boxes format
def test_boxes_convert():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
converted_boxes = boxes.convert(BoxMode.XYWH)
- assert np.allclose(converted_boxes.numpy(), np.array(
- [[50, 50, 50, 50], [60, 60, 60, 60]])), "Boxes conversion failed."
+ assert np.allclose(
+ converted_boxes.numpy(), np.array([[50, 50, 50, 50], [60, 60, 60, 60]])
+ ), "Boxes conversion failed."
+
# Test calculation of area of Boxes
def test_boxes_area():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
- assert np.allclose(boxes.area, np.array([2500, 3600])), "Boxes area calculation failed."
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
+ assert np.allclose(boxes.area, np.array([2500, 3600])), (
+ "Boxes area calculation failed."
+ )
+
# Test Boxes.copy() method
def test_boxes_copy():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
copied_boxes = boxes.copy()
- assert copied_boxes is not boxes and (copied_boxes._array == boxes._array).all(), "Boxes copy failed."
+ assert (
+ copied_boxes is not boxes
+ and (copied_boxes._array == boxes._array).all()
+ ), "Boxes copy failed."
+
# Test Boxes conversion to numpy array
def test_boxes_numpy():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
arr = boxes.numpy()
- assert isinstance(arr, np.ndarray) and np.allclose(arr, np.array(
- [(50, 50, 100, 100), (60, 60, 120, 120)])), "Boxes to numpy conversion failed."
+ assert isinstance(arr, np.ndarray) and np.allclose(
+ arr, np.array([(50, 50, 100, 100), (60, 60, 120, 120)])
+ ), "Boxes to numpy conversion failed."
+
# Test Boxes normalization
def test_boxes_normalize():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
normalized_boxes = boxes.normalize(200, 200)
- assert np.allclose(normalized_boxes.numpy(), np.array(
- [[0.25, 0.25, 0.5, 0.5], [0.3, 0.3, 0.6, 0.6]])), "Boxes normalization failed."
+ assert np.allclose(
+ normalized_boxes.numpy(),
+ np.array([[0.25, 0.25, 0.5, 0.5], [0.3, 0.3, 0.6, 0.6]]),
+ ), "Boxes normalization failed."
+
# Test Boxes denormalization
def test_boxes_denormalize():
- boxes = Boxes([(0.25, 0.25, 0.5, 0.5), (0.3, 0.3, 0.6, 0.6)], box_mode=BoxMode.XYXY, is_normalized=True)
+ boxes = Boxes(
+ [(0.25, 0.25, 0.5, 0.5), (0.3, 0.3, 0.6, 0.6)],
+ box_mode=BoxMode.XYXY,
+ is_normalized=True,
+ )
denormalized_boxes = boxes.denormalize(200, 200)
- assert np.allclose(denormalized_boxes.numpy(), np.array(
- [(50, 50, 100, 100), (60, 60, 120, 120)])), "Boxes denormalization failed."
+ assert np.allclose(
+ denormalized_boxes.numpy(),
+ np.array([(50, 50, 100, 100), (60, 60, 120, 120)]),
+ ), "Boxes denormalization failed."
+
# Test Boxes clipping
def test_boxes_clip():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
clipped_boxes = boxes.clip(60, 60, 90, 90)
- assert np.allclose(clipped_boxes.numpy(), np.array([(60, 60, 90, 90), (60, 60, 90, 90)])), "Boxes clipping failed."
+ assert np.allclose(
+ clipped_boxes.numpy(), np.array([(60, 60, 90, 90), (60, 60, 90, 90)])
+ ), "Boxes clipping failed."
+
# Test Boxes shifting
def test_boxes_shift():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
shifted_boxes = boxes.shift(10, -10)
- assert np.allclose(shifted_boxes.numpy(), np.array(
- [(60, 40, 110, 90), (70, 50, 130, 110)])), "Boxes shifting failed."
+ assert np.allclose(
+ shifted_boxes.numpy(), np.array([(60, 40, 110, 90), (70, 50, 130, 110)])
+ ), "Boxes shifting failed."
+
# Test Boxes scaling
def test_boxes_scale():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
scaled_boxes = boxes.scale(dsize=(20, 0))
- assert np.allclose(scaled_boxes.numpy(), np.array(
- [(40, 50, 110, 100), (50, 60, 130, 120)])), "Boxes scaling failed."
+ assert np.allclose(
+ scaled_boxes.numpy(), np.array([(40, 50, 110, 100), (50, 60, 130, 120)])
+ ), "Boxes scaling failed."
+
+
+def test_boxes_scale_supports_fy_with_single_box():
+ """Regression: fy scaling should not index the 4th row for small N."""
+ boxes = Boxes([(0, 0, 10, 10)], box_mode=BoxMode.XYWH)
+ scaled = boxes.scale(fy=2.0).convert(BoxMode.XYWH).numpy()
+ np.testing.assert_allclose(
+ scaled, np.array([[0, -5, 10, 20]], dtype=np.float32)
+ )
+
# Test Boxes get_empty_index
@@ -247,26 +349,183 @@ def test_boxes_get_empty_index():
boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
assert boxes.get_empty_index() == 0, "Boxes get_empty_index failed."
+
# Test Boxes drop_empty
def test_boxes_drop_empty():
boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
boxes = boxes.drop_empty()
- assert np.allclose(boxes.numpy(), np.array([(60, 60, 120, 120)])), "Boxes drop_empty failed."
+ assert np.allclose(boxes.numpy(), np.array([(60, 60, 120, 120)])), (
+ "Boxes drop_empty failed."
+ )
+
# Test Boxes tolist
def test_boxes_tolist():
boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
- assert boxes.tolist() == [[50, 50, 50, 50], [60, 60, 120, 120]], "Boxes tolist failed."
+ assert boxes.tolist() == [[50, 50, 50, 50], [60, 60, 120, 120]], (
+ "Boxes tolist failed."
+ )
+
# Test Boxes to_polygons
def test_boxes_to_polygons():
- boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)
+ boxes = Boxes(
+ [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY
+ )
polygons = boxes.to_polygons()
- assert np.allclose(polygons.numpy(), np.array([[[50, 50], [100, 50], [100, 100], [50, 100]], [
- [60, 60], [120, 60], [120, 120], [60, 120]]])), "Boxes convert_to_polygons failed."
+ assert np.allclose(
+ polygons.numpy(),
+ np.array(
+ [
+ [[50, 50], [100, 50], [100, 100], [50, 100]],
+ [[60, 60], [120, 60], [120, 120], [60, 120]],
+ ]
+ ),
+ ), "Boxes convert_to_polygons failed."
+
+
+def test_boxmode_convert_supports_cxcywh_and_align_code_int_and_invalid():
+ xywh = np.array([10, 20, 30, 40], dtype=np.float32)
+
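+    # The integer code 1 and the string "cxcywh" are assumed to map to
+    # BoxMode.XYWH and BoxMode.CXCYWH respectively.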
+ cxcywh = BoxMode.convert(xywh, 1, "cxcywh")
+ np.testing.assert_allclose(
+ cxcywh, np.array([25, 40, 30, 40], dtype=np.float32)
+ )
+
+ xyxy = BoxMode.convert(cxcywh, BoxMode.CXCYWH, BoxMode.XYXY)
+ np.testing.assert_allclose(
+ xyxy, np.array([10, 20, 40, 60], dtype=np.float32)
+ )
+
+ back_xywh = BoxMode.convert(xyxy, BoxMode.XYXY, BoxMode.XYWH)
+ np.testing.assert_allclose(back_xywh, xywh)
+
+ with pytest.raises(TypeError, match="not int, str, or BoxMode"):
+ invalid_box_mode: Any = object()
+ BoxMode.align_code(invalid_box_mode)
+
+
+def test_box_repr_getitem_slice_and_eq_non_box():
+ box = Box((1, 2, 3, 4), box_mode=BoxMode.XYXY)
+ assert "Box(" in repr(box)
+
+ np.testing.assert_allclose(box[:2], np.array([1, 2], dtype=np.float32))
+ assert (box == object()) is False
+
+
+def test_box_invalid_numpy_array_shape_raises():
+ with pytest.raises(TypeError, match="got shape"):
+ Box(np.zeros((2, 2), dtype=np.float32))
+
+
+def test_box_square_warns_on_normalize_denormalize_clip_nan_and_scale_fx_fy():
+ box_xywh = Box((0, 0, 10, 5), box_mode=BoxMode.XYWH)
+ square = box_xywh.square().convert(BoxMode.XYWH).numpy()
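+    # Expected values assume square() keeps the box center and uses the
+    # shorter side, turning the 10x5 box into a 5x5 box at (2.5, 0).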
+ np.testing.assert_allclose(square, np.array([2.5, 0.0, 5.0, 5.0]))
+
+ with pytest.warns(UserWarning, match="forced to do normalization"):
+ _ = Box((0.1, 0.1, 0.2, 0.2), is_normalized=True).normalize(10, 10)
+
+ with pytest.warns(UserWarning, match="forced to do denormalization"):
+ _ = Box((1, 2, 3, 4), is_normalized=False).denormalize(10, 10)
+
+ with pytest.raises(ValueError, match="infinite or NaN"):
+ Box((np.nan, 0, 1, 1)).clip(0, 0, 10, 10)
+
+ scaled = Box((10, 10, 10, 20), box_mode=BoxMode.XYWH).scale(fx=2.0, fy=0.5)
+ np.testing.assert_allclose(
+ scaled.convert(BoxMode.XYWH).numpy(), np.array([5, 15, 20, 10])
+ )
+
+
+def test_box_to_polygon_rejects_non_positive_size_and_properties_work():
+ with pytest.raises(ValueError, match="invaild value"):
+ Box((0, 0, -1, 2), box_mode=BoxMode.XYWH).to_polygon()
+
+ box = Box((10, 20, 30, 40), box_mode=BoxMode.XYWH)
+ assert box.width == 30
+ assert box.height == 40
+ np.testing.assert_allclose(
+ box.left_top, np.array([10, 20], dtype=np.float32)
+ )
+ np.testing.assert_allclose(
+ box.right_bottom, np.array([40, 60], dtype=np.float32)
+ )
+ np.testing.assert_allclose(
+ box.left_bottom, np.array([10, 60], dtype=np.float32)
+ )
+ np.testing.assert_allclose(
+ box.right_top, np.array([40, 20], dtype=np.float32)
+ )
+ assert box.aspect_ratio == 30 / 40
+ np.testing.assert_allclose(box.center, np.array([25, 40], dtype=np.float32))
+
+
+def test_boxes_repr_indexing_eq_and_constructor_from_boxes():
+ boxes = Boxes([[0, 0, 10, 10], [20, 20, 30, 30]], box_mode=BoxMode.XYXY)
+ assert "Boxes(" in repr(boxes)
+
+ assert isinstance(boxes[[1]], Boxes)
+ assert len(boxes[[1]]) == 1
+ assert isinstance(boxes[:1], Boxes)
+ assert len(boxes[:1]) == 1
+
+ mask = np.array([True, False])
+ assert len(boxes[mask]) == 1
+
+ with pytest.raises(TypeError, match="Boxes indices"):
+ _ = boxes["0"] # type: ignore[index]
+
+ assert (boxes == object()) is False
+
+ converted = Boxes(boxes, box_mode=BoxMode.XYWH)
+ assert converted.box_mode == BoxMode.XYWH
+ assert len(converted) == len(boxes)
+
+
+def test_boxes_square_warns_clip_nan_scale_fx_and_to_polygons_invalid():
+ boxes_xywh = Boxes([[0, 0, 10, 5], [10, 10, 6, 8]], box_mode=BoxMode.XYWH)
+ squared = boxes_xywh.square().convert(BoxMode.XYWH).numpy()
+ assert np.allclose(squared[:, 2], squared[:, 3])
+
+ with pytest.warns(UserWarning, match="forced to do normalization"):
+ _ = Boxes([[0.1, 0.1, 0.2, 0.2]], is_normalized=True).normalize(10, 10)
+
+ with pytest.warns(UserWarning, match="forced to do denormalization"):
+ _ = Boxes([[1, 2, 3, 4]], is_normalized=False).denormalize(10, 10)
+
+ with pytest.raises(ValueError, match="infinite or NaN"):
+ Boxes([[np.nan, 0, 1, 1]], box_mode=BoxMode.XYXY).clip(0, 0, 10, 10)
+
+ scaled = Boxes([[10, 10, 10, 20]], box_mode=BoxMode.XYWH).scale(fx=2.0)
+ np.testing.assert_allclose(
+ scaled.convert(BoxMode.XYWH).numpy(), np.array([[5, 10, 20, 20]])
+ )
+
+ with pytest.raises(ValueError, match="invaild value"):
+ Boxes([[0, 0, -1, 2]], box_mode=BoxMode.XYWH).to_polygons()
+
+
+def test_boxes_properties_work():
+ boxes = Boxes([[10, 20, 30, 40], [0, 0, 2, 4]], box_mode=BoxMode.XYWH)
+ np.testing.assert_allclose(boxes.width, np.array([30, 2], dtype=np.float32))
+ np.testing.assert_allclose(
+ boxes.height, np.array([40, 4], dtype=np.float32)
+ )
+ np.testing.assert_allclose(
+ boxes.left_top, np.array([[10, 20], [0, 0]], dtype=np.float32)
+ )
+ np.testing.assert_allclose(
+ boxes.right_bottom,
+ np.array([[40, 60], [2, 4]], dtype=np.float32),
+ )
+ np.testing.assert_allclose(boxes.aspect_ratio, np.array([0.75, 0.5]))
+ np.testing.assert_allclose(
+ boxes.center, np.array([[25, 40], [1, 2]], dtype=np.float32)
+ )
diff --git a/tests/structures/test_functionals.py b/tests/structures/test_functionals.py
index 4f8816f..7d5d51b 100644
--- a/tests/structures/test_functionals.py
+++ b/tests/structures/test_functionals.py
@@ -1,67 +1,86 @@
import numpy as np
import pytest
-from capybara import (Boxes, Polygon, jaccard_index, pairwise_ioa,
- pairwise_iou, polygon_iou)
+from capybara import (
+ Box,
+ Boxes,
+ Keypoints,
+ Polygon,
+ jaccard_index,
+ pairwise_ioa,
+ pairwise_iou,
+ polygon_iou,
+)
+from capybara.structures.functionals import (
+ calc_angle,
+ is_inside_box,
+ pairwise_intersection,
+ poly_angle,
+)
test_functionals_error_param = [
(
pairwise_iou,
([(1, 2, 3, 4)], [(1, 2, 3, 4)]),
TypeError,
- 'Input type of boxes1 and boxes2 must be Boxes'
+ "Input type of boxes1 and boxes2 must be Boxes",
),
(
pairwise_iou,
[Boxes([(1, 1, 0, 2)], "XYWH"), Boxes([(1, 1, 0, 2)], "XYWH")],
ValueError,
- 'Some boxes in Boxes has invaild value'
+ "Some boxes in Boxes has invaild value",
),
(
pairwise_ioa,
([(1, 2, 3, 4)], [(1, 2, 3, 4)]),
TypeError,
- 'Input type of boxes1 and boxes2 must be Boxes'
+ "Input type of boxes1 and boxes2 must be Boxes",
),
(
pairwise_ioa,
[Boxes([(1, 1, 0, 2)], "XYWH"), Boxes([(1, 1, 0, 2)], "XYWH")],
ValueError,
- 'Some boxes in Boxes has invaild value'
+ "Some boxes in Boxes has invaild value",
),
]
-@pytest.mark.parametrize('fn, test_input, error, match', test_functionals_error_param)
+@pytest.mark.parametrize(
+ "fn, test_input, error, match", test_functionals_error_param
+)
def test_functionals_error(fn, test_input, error, match):
with pytest.raises(error, match=match):
fn(*test_input)
-test_pairwise_iou_param = [(
- Boxes(np.array([[10, 10, 20, 20], [15, 15, 25, 25]]), "XYXY"),
- Boxes(
- np.array([[10, 10, 20, 20], [15, 15, 25, 25], [25, 25, 10, 10]]), "XYWH"),
- np.array([
- [1 / 4, 1 / 28, 0],
- [1 / 4, 4 / 25, 0]
- ], dtype='float32')
-)]
+test_pairwise_iou_param = [
+ (
+ Boxes(np.array([[10, 10, 20, 20], [15, 15, 25, 25]]), "XYXY"),
+ Boxes(
+ np.array([[10, 10, 20, 20], [15, 15, 25, 25], [25, 25, 10, 10]]),
+ "XYWH",
+ ),
+ np.array([[1 / 4, 1 / 28, 0], [1 / 4, 4 / 25, 0]], dtype="float32"),
+ )
+]
-@pytest.mark.parametrize('boxes1, boxes2, expected', test_pairwise_iou_param)
+@pytest.mark.parametrize("boxes1, boxes2, expected", test_pairwise_iou_param)
def test_pairwise_iou(boxes1, boxes2, expected):
assert (pairwise_iou(boxes1, boxes2) == expected).all()
-test_pairwise_ioa_param = [(
- Boxes(np.array([[10, 10, 20, 20]]), "XYXY"),
- Boxes(np.array([[15, 15, 20, 20], [20, 20, 10, 10]]), "XYWH"),
- np.array([[1 / 16, 0]], dtype='float32')
-)]
+test_pairwise_ioa_param = [
+ (
+ Boxes(np.array([[10, 10, 20, 20]]), "XYXY"),
+ Boxes(np.array([[15, 15, 20, 20], [20, 20, 10, 10]]), "XYWH"),
+ np.array([[1 / 16, 0]], dtype="float32"),
+ )
+]
-@pytest.mark.parametrize('boxes1, boxes2, expected', test_pairwise_ioa_param)
+@pytest.mark.parametrize("boxes1, boxes2, expected", test_pairwise_ioa_param)
def test_pairwise_ioa(boxes1, boxes2, expected):
assert (pairwise_ioa(boxes1, boxes2) == expected).all()
@@ -70,12 +89,12 @@ def test_pairwise_ioa(boxes1, boxes2, expected):
(
Polygon(np.array([[0, 0], [0, 10], [10, 10], [10, 0]])),
Polygon(np.array([[5, 5], [5, 15], [15, 15], [15, 5]])),
- 25 / 175
+ 25 / 175,
)
]
-@pytest.mark.parametrize('poly1, poly2, expected', test_polygon_iou_param)
+@pytest.mark.parametrize("poly1, poly2, expected", test_polygon_iou_param)
def test_polygon_iou(poly1, poly2, expected):
assert polygon_iou(poly1, poly2) == expected
@@ -85,12 +104,14 @@ def test_polygon_iou(poly1, poly2, expected):
np.array([[0, 0], [0, 10], [10, 10], [10, 0]]),
np.array([[5, 5], [5, 15], [15, 15], [15, 5]]),
(100, 100),
- 25 / 175
+ 25 / 175,
)
]
-@pytest.mark.parametrize('pred_poly, gt_poly, img_size, expected', test_jaccard_index_param)
+@pytest.mark.parametrize(
+ "pred_poly, gt_poly, img_size, expected", test_jaccard_index_param
+)
def test_jaccard_index(pred_poly, gt_poly, img_size, expected):
assert jaccard_index(pred_poly, gt_poly, img_size) == expected
@@ -101,19 +122,154 @@ def test_jaccard_index(pred_poly, gt_poly, img_size, expected):
np.array([[5, 5], [5, 15], [15, 15], [15, 5]]),
(100, 100),
ValueError,
- 'Input polygon must be 4-point polygon.'
+ "Input polygon must be 4-point polygon.",
),
(
np.array([[0, 0], [0, 10], [10, 10], [10, 0]]),
np.array([[5, 5], [5, 15], [15, 15], [15, 5], [5, 5]]),
(100, 100),
ValueError,
- 'Input polygon must be 4-point polygon.'
+ "Input polygon must be 4-point polygon.",
),
]
-@pytest.mark.parametrize('pred_poly, gt_poly, img_size, error, match', test_jaccard_index_error_param)
+@pytest.mark.parametrize(
+ "pred_poly, gt_poly, img_size, error, match", test_jaccard_index_error_param
+)
def test_jaccard_index_error(pred_poly, gt_poly, img_size, error, match):
with pytest.raises(error, match=match):
jaccard_index(pred_poly, gt_poly, img_size)
+
+
+def test_pairwise_intersection_rejects_non_boxes():
+ with pytest.raises(TypeError, match="must be Boxes"):
+ pairwise_intersection([(0, 0, 1, 1)], [(0, 0, 1, 1)]) # type: ignore[arg-type]
+
+
+def test_jaccard_index_requires_image_size():
+ pred = np.zeros((4, 2), dtype=np.float32)
+ gt = np.zeros((4, 2), dtype=np.float32)
+ with pytest.raises(ValueError, match="image size"):
+ jaccard_index(pred, gt, None) # type: ignore[arg-type]
+
+
+def test_jaccard_index_returns_zero_when_shapely_raises(monkeypatch):
+ import capybara.structures.functionals as fn_mod
+
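+    # Replace ShapelyPolygon with a callable that always raises ValueError
+    # (the generator .throw() trick lets a lambda raise).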
+ monkeypatch.setattr(
+ fn_mod,
+ "ShapelyPolygon",
+ lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")),
+ )
+ pred = np.zeros((4, 2), dtype=np.float32)
+ gt = np.zeros((4, 2), dtype=np.float32)
+ assert jaccard_index(pred, gt, (100, 100)) == 0
+
+
+def test_jaccard_index_clamps_intersection_area_close_to_min(monkeypatch):
+ import capybara.structures.functionals as fn_mod
+
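+    # Stub the perspective transform with an identity mapping so only the
+    # area/intersection handling is exercised.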
+ monkeypatch.setattr(
+ fn_mod.cv2, "getPerspectiveTransform", lambda *_: np.eye(3)
+ )
+ monkeypatch.setattr(fn_mod.cv2, "perspectiveTransform", lambda pts, _m: pts)
+
+ class _FakePoly:
+ def __init__(
+ self, *, area: float, intersection_area: float | None = None
+ ) -> None:
+ self._area = area
+ self._intersection_area = intersection_area
+
+ @property
+ def area(self) -> float:
+ return self._area
+
+ def __and__(self, _other):
+ assert self._intersection_area is not None
+ return _FakePoly(area=self._intersection_area)
+
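+    # The fake intersection area is marginally larger than either polygon's
+    # area; jaccard_index is expected to clamp it so the result stays at 1.0.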
+ target = _FakePoly(area=1.0, intersection_area=1.00000000005)
+ pred = _FakePoly(area=1.0)
+ factory = iter([target, pred])
+ monkeypatch.setattr(
+ fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory)
+ )
+
+ pred_poly = np.zeros((4, 2), dtype=np.float32)
+ gt_poly = np.zeros((4, 2), dtype=np.float32)
+ assert jaccard_index(pred_poly, gt_poly, (10, 10)) == pytest.approx(1.0)
+
+
+def test_polygon_iou_rejects_non_polygon_and_returns_zero_on_errors(
+ monkeypatch,
+):
+ with pytest.raises(TypeError, match="must be Polygon"):
+ polygon_iou("bad", Polygon(np.zeros((4, 2), dtype=np.float32))) # type: ignore[arg-type]
+
+ import capybara.structures.functionals as fn_mod
+
+ class _FakePoly:
+ def __init__(
+ self,
+ *,
+ area: float,
+ intersection_area: float | None = None,
+ raise_intersection: bool = False,
+ ) -> None:
+ self._area = area
+ self._intersection_area = intersection_area
+ self._raise_intersection = raise_intersection
+
+ @property
+ def area(self) -> float:
+ return self._area
+
+ def intersection(self, _other):
+ if self._raise_intersection:
+ raise ValueError("boom")
+ assert self._intersection_area is not None
+ return _FakePoly(area=self._intersection_area)
+
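+    # First pass: the intersection is clamped to the smaller area (1.0), so
+    # IoU = 1 / (2 + 1 - 1) = 0.5.  Second pass: intersection() raises and
+    # polygon_iou is expected to return 0.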
+ poly1_shape = _FakePoly(area=2.0, intersection_area=1.00000000005)
+ poly2_shape = _FakePoly(area=1.0)
+ factory = iter([poly1_shape, poly2_shape])
+ monkeypatch.setattr(
+ fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory)
+ )
+
+ poly1 = Polygon(np.zeros((4, 2), dtype=np.float32))
+ poly2 = Polygon(np.zeros((4, 2), dtype=np.float32))
+ assert polygon_iou(poly1, poly2) == pytest.approx(0.5)
+
+ poly1_shape_err = _FakePoly(area=1.0, raise_intersection=True)
+ poly2_shape_err = _FakePoly(area=1.0)
+ factory_err = iter([poly1_shape_err, poly2_shape_err])
+ monkeypatch.setattr(
+ fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory_err)
+ )
+ assert polygon_iou(poly1, poly2) == 0
+
+
+def test_is_inside_box_calc_angle_and_poly_angle():
+ box = Box((0, 0, 10, 10), box_mode="XYXY")
+ assert is_inside_box(Keypoints([(1, 1), (9, 9)]), box)
+ assert not is_inside_box(Keypoints([(1, 1), (11, 9)]), box)
+
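+    # Expected angles assume calc_angle returns a directed angle in degrees
+    # within [0, 360), hence 90 vs. 270 for mirrored vectors.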
+ assert calc_angle(np.array([0, 1]), np.array([-1, 0])) == pytest.approx(
+ 90.0
+ )
+ assert calc_angle(np.array([0, 1]), np.array([1, 0])) == pytest.approx(
+ 270.0
+ )
+
+ poly1 = Polygon(
+ np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
+ )
+ assert poly_angle(poly1) == pytest.approx(90.0)
+
+ poly2 = Polygon(
+ np.array([[0, 0], [1, 0], [0, -1], [1, -1]], dtype=np.float32)
+ )
+ assert poly_angle(poly1, poly2) == pytest.approx(270.0)
diff --git a/tests/structures/test_keypoints.py b/tests/structures/test_keypoints.py
index 20509c9..d967ee0 100644
--- a/tests/structures/test_keypoints.py
+++ b/tests/structures/test_keypoints.py
@@ -1,23 +1,29 @@
+from typing import Any
+
import numpy as np
import pytest
-from capybara import Box, Boxes, BoxMode, Keypoints, KeypointsList
+from capybara import Keypoints, KeypointsList
def test_invalid_input_type():
with pytest.raises(TypeError):
- Keypoints("invalid_input")
+ invalid_input: Any = "invalid_input"
+ Keypoints(invalid_input)
def test_invalid_input_shape():
with pytest.raises(ValueError):
- Keypoints([(1, 2, 3, 4), (1, 2, 3, 4)])
+ invalid_input: Any = [(1, 2, 3, 4), (1, 2, 3, 4)]
+ Keypoints(invalid_input)
def test_keypoints_eat_itself():
keypoints1 = Keypoints([(1, 2), (3, 4)])
keypoints2 = Keypoints(keypoints1)
- assert np.allclose(keypoints1.numpy(), keypoints2.numpy()), "Keypoints eat itself failed."
+ assert np.allclose(keypoints1.numpy(), keypoints2.numpy()), (
+ "Keypoints eat itself failed."
+ )
def test_normalized_array():
@@ -29,91 +35,233 @@ def test_normalized_array():
def test_keypoints_numpy():
array = np.array([[1, 2], [3, 4]])
keypoints = Keypoints(array)
- assert np.allclose(keypoints.numpy(), array), "Keypoints numpy conversion failed."
+ assert np.allclose(keypoints.numpy(), array), (
+ "Keypoints numpy conversion failed."
+ )
def test_keypoints_copy():
keypoints = Keypoints([(1, 2), (3, 4)])
copied_keypoints = keypoints.copy()
- assert np.allclose(keypoints.numpy(), copied_keypoints.numpy()), "Keypoints copy failed."
+ assert np.allclose(keypoints.numpy(), copied_keypoints.numpy()), (
+ "Keypoints copy failed."
+ )
def test_keypoints_shift():
keypoints = Keypoints([(1, 2), (3, 4)])
shifted_keypoints = keypoints.shift(10, 10)
- assert np.allclose(shifted_keypoints.numpy(), np.array(
- [[11, 12], [13, 14]])), "Keypoints shift failed."
+ assert np.allclose(
+ shifted_keypoints.numpy(), np.array([[11, 12], [13, 14]])
+ ), "Keypoints shift failed."
def test_keypoints_scale():
keypoints = Keypoints([(1, 2), (3, 4)])
scaled_keypoints = keypoints.scale(10, 10)
- assert np.allclose(scaled_keypoints.numpy(), np.array(
- [[10, 20], [30, 40]])), "Keypoints scale failed."
+ assert np.allclose(
+ scaled_keypoints.numpy(), np.array([[10, 20], [30, 40]])
+ ), "Keypoints scale failed."
def test_keypoints_normalize():
keypoints = Keypoints([(1, 2), (3, 4)])
normalized_keypoints = keypoints.normalize(100, 100)
- assert np.allclose(normalized_keypoints.numpy(), np.array(
- [[0.01, 0.02], [0.03, 0.04]])), "Keypoints normalization failed."
+ assert np.allclose(
+ normalized_keypoints.numpy(), np.array([[0.01, 0.02], [0.03, 0.04]])
+ ), "Keypoints normalization failed."
def test_keypoints_denormalize():
keypoints = Keypoints([(0.01, 0.02), (0.03, 0.04)], is_normalized=True)
denormalized_keypoints = keypoints.denormalize(100, 100)
- assert np.allclose(denormalized_keypoints.numpy(), np.array(
- [[1, 2], [3, 4]])), "Keypoints denormalization failed."
+ assert np.allclose(
+ denormalized_keypoints.numpy(), np.array([[1, 2], [3, 4]])
+ ), "Keypoints denormalization failed."
def test_keypoints_list_empty_input():
keypoints_list = KeypointsList([])
- np.testing.assert_allclose(keypoints_list.numpy(), np.array([], dtype='float32'))
+ np.testing.assert_allclose(
+ keypoints_list.numpy(), np.array([], dtype="float32")
+ )
def test_keypoints_list_numpy():
array = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
keypoints_list = KeypointsList(array)
- assert np.allclose(keypoints_list.numpy(), array), "KeypointsList numpy conversion failed."
+ assert np.allclose(keypoints_list.numpy(), array), (
+ "KeypointsList numpy conversion failed."
+ )
def test_keypoints_list_copy():
keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]])
copied_keypoints_list = keypoints_list.copy()
- assert np.allclose(keypoints_list.numpy(), copied_keypoints_list.numpy()), "KeypointsList copy failed."
+ assert np.allclose(keypoints_list.numpy(), copied_keypoints_list.numpy()), (
+ "KeypointsList copy failed."
+ )
def test_keypoints_list_shift():
keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]])
shifted_keypoints_list = keypoints_list.shift(10, 10)
- assert np.allclose(shifted_keypoints_list.numpy(), np.array(
- [[[11, 12], [13, 14]], [[15, 16], [17, 18]]])), "KeypointsList shift failed."
+ assert np.allclose(
+ shifted_keypoints_list.numpy(),
+ np.array([[[11, 12], [13, 14]], [[15, 16], [17, 18]]]),
+ ), "KeypointsList shift failed."
def test_keypoints_list_scale():
keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]])
scaled_keypoints_list = keypoints_list.scale(10, 10)
- assert np.allclose(scaled_keypoints_list.numpy(), np.array(
- [[[10, 20], [30, 40]], [[50, 60], [70, 80]]])), "KeypointsList scale failed."
+ assert np.allclose(
+ scaled_keypoints_list.numpy(),
+ np.array([[[10, 20], [30, 40]], [[50, 60], [70, 80]]]),
+ ), "KeypointsList scale failed."
def test_keypoints_list_normalize():
keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]])
normalized_keypoints_list = keypoints_list.normalize(100, 100)
- assert np.allclose(normalized_keypoints_list.numpy(), np.array(
- [[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, 0.08]]])), "KeypointsList normalization failed."
+ assert np.allclose(
+ normalized_keypoints_list.numpy(),
+ np.array([[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, 0.08]]]),
+ ), "KeypointsList normalization failed."
def test_keypoints_list_denormalize():
- keypoints_list = KeypointsList([[(0.01, 0.02), (0.03, 0.04)], [(0.05, 0.06), (0.07, 0.08)]], is_normalized=True)
+ keypoints_list = KeypointsList(
+ [[(0.01, 0.02), (0.03, 0.04)], [(0.05, 0.06), (0.07, 0.08)]],
+ is_normalized=True,
+ )
denormalized_keypoints_list = keypoints_list.denormalize(100, 100)
- assert np.allclose(denormalized_keypoints_list.numpy(), np.array(
- [[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), "KeypointsList denormalization failed."
+ assert np.allclose(
+ denormalized_keypoints_list.numpy(),
+ np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
+ ), "KeypointsList denormalization failed."
def test_keypoints_list_cat():
keypoints_list1 = KeypointsList([[(1, 2), (3, 4)]])
keypoints_list2 = KeypointsList([[(5, 6), (7, 8)]])
cat_keypoints_list = KeypointsList.cat([keypoints_list1, keypoints_list2])
- assert np.allclose(cat_keypoints_list.numpy(), np.array(
- [[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), "KeypointsList concatenation failed."
+ assert np.allclose(
+ cat_keypoints_list.numpy(),
+ np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
+ ), "KeypointsList concatenation failed."
+
+
+def test_keypoints_repr_eq_and_validation_and_warnings():
+ kpts = Keypoints([(1, 2), (3, 4)])
+ assert "Keypoints(" in repr(kpts)
+ assert (kpts == object()) is False
+
+ with pytest.raises(ValueError, match="ndim"):
+ Keypoints(np.zeros((2, 2, 2), dtype=np.float32))
+
+ with pytest.raises(ValueError, match="labels"):
+ Keypoints(np.array([[1, 2, 3], [4, 5, -1]], dtype=np.float32))
+
+ with pytest.warns(UserWarning, match="forced to do normalization"):
+ Keypoints([(0.1, 0.2)], is_normalized=True).normalize(10, 10)
+
+ with pytest.warns(UserWarning, match="forced to do denormalization"):
+ Keypoints([(1, 2)], is_normalized=False).denormalize(10, 10)
+
+
+def test_keypoints_point_colors_can_be_updated():
+ kpts = Keypoints([(1, 2), (3, 4)])
+ colors = kpts.point_colors
+ assert len(colors) == 2
+ assert all(isinstance(c, tuple) and len(c) == 3 for c in colors)
+
+ kpts.point_colors = "viridis"
+ colors2 = kpts.point_colors
+ assert len(colors2) == 2
+
+
+def test_keypoints_list_getitem_setitem_repr_eq_and_point_colors():
+ kpts_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]])
+ assert isinstance(kpts_list[0], Keypoints)
+ assert isinstance(kpts_list[:1], KeypointsList)
+
+ with pytest.raises(TypeError, match="not a keypoint"):
+ kpts_list[0] = "bad" # type: ignore[assignment]
+
+ kpts_list[0] = Keypoints([(9, 10), (11, 12)])
+ np.testing.assert_allclose(
+ kpts_list[0].numpy(), np.array([[9, 10], [11, 12]], dtype=np.float32)
+ )
+
+ assert "KeypointsList(" in repr(kpts_list)
+ assert (kpts_list == object()) is False
+
+ assert len(kpts_list.point_colors) == 2
+ kpts_list.point_colors = "viridis"
+ assert len(kpts_list.point_colors) == 2
+
+
+def test_keypoints_list_validation_errors_and_warnings_and_empty_point_colors():
+ kpts_list = KeypointsList([[(1, 2), (3, 4)]])
+ kpts_list2 = KeypointsList(kpts_list)
+ np.testing.assert_allclose(kpts_list2.numpy(), kpts_list.numpy())
+
+ with pytest.raises(ValueError, match="ndim"):
+ KeypointsList(np.zeros((2, 2), dtype=np.float32))
+
+ with pytest.raises(ValueError, match="shape\\[-1\\]"):
+ KeypointsList(np.zeros((1, 2, 4), dtype=np.float32))
+
+ with pytest.raises(ValueError, match="labels"):
+ KeypointsList(np.array([[[1, 2, 3], [4, 5, 9]]], dtype=np.float32))
+
+ with pytest.warns(UserWarning, match="keypoints_list"):
+ KeypointsList([[(0.1, 0.2), (0.3, 0.4)]], is_normalized=True).normalize(
+ 10, 10
+ )
+
+ with pytest.warns(UserWarning, match="forced to do denormalization"):
+ KeypointsList([[(1, 2), (3, 4)]], is_normalized=False).denormalize(
+ 10, 10
+ )
+
+ assert KeypointsList([]).point_colors == []
+
+
+def test_keypoints_list_cat_validation_errors():
+ with pytest.raises(TypeError, match="should be a list"):
+ KeypointsList.cat("not-a-list") # type: ignore[arg-type]
+
+ with pytest.raises(ValueError, match="is empty"):
+ KeypointsList.cat([])
+
+ with pytest.raises(TypeError, match="must be KeypointsList"):
+ KeypointsList.cat([KeypointsList([]), "bad"]) # type: ignore[list-item]
+
+
+def test_keypoints_colormap_falls_back_without_matplotlib(monkeypatch):
+ import builtins
+
+ real_import = builtins.__import__
+
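+    # Simulate an environment without matplotlib; point_colors should fall
+    # back to a built-in palette instead of raising.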
+ def fake_import(name, *args, **kwargs):
+ if name == "matplotlib":
+ raise ModuleNotFoundError("boom")
+ return real_import(name, *args, **kwargs)
+
+ monkeypatch.setattr(builtins, "__import__", fake_import)
+
+ kpts = Keypoints([(0, 0), (1, 1)])
+ assert len(kpts.point_colors) == 2
+
+
+def test_keypoints_list_rejects_invalid_types():
+ with pytest.raises(TypeError, match="Input array is not"):
+ KeypointsList(123) # type: ignore[arg-type]
+
+
+def test_keypoints_list_set_point_colors_noops_when_empty():
+ keypoints_list = KeypointsList([])
+ keypoints_list.set_point_colors("rainbow")
+ assert keypoints_list.point_colors == []
diff --git a/tests/structures/test_polygons.py b/tests/structures/test_polygons.py
index c4ddb89..0f9b325 100644
--- a/tests/structures/test_polygons.py
+++ b/tests/structures/test_polygons.py
@@ -1,3 +1,5 @@
+from typing import Any, cast
+
import numpy as np
import pytest
@@ -13,45 +15,49 @@ def assert_almost_equal(actual, expected, tolerance=1e-5):
tl, bl, br, tr = (0, 0), (0, 1), (1, 1), (1, 0)
-POINTS_SET = np.array([(tl, bl, br, tr),
- (tl, bl, tr, br),
- (tl, br, bl, tr),
- (tl, br, tr, bl),
- (tl, tr, bl, br),
- (tl, tr, br, bl),
- (bl, tl, br, tr),
- (bl, tl, tr, br),
- (bl, br, tl, tr),
- (bl, br, tr, tl),
- (bl, tr, tl, br),
- (bl, tr, br, tl),
- (br, tl, bl, tr),
- (br, tl, tr, bl),
- (br, bl, tl, tr),
- (br, bl, tr, tl),
- (br, tr, tl, bl),
- (br, tr, bl, tl),
- (tr, tl, bl, br),
- (tr, tl, br, bl),
- (tr, bl, tl, br),
- (tr, bl, br, tl),
- (tr, br, tl, bl),
- (tr, br, bl, tl)])
+POINTS_SET = np.array(
+ [
+ (tl, bl, br, tr),
+ (tl, bl, tr, br),
+ (tl, br, bl, tr),
+ (tl, br, tr, bl),
+ (tl, tr, bl, br),
+ (tl, tr, br, bl),
+ (bl, tl, br, tr),
+ (bl, tl, tr, br),
+ (bl, br, tl, tr),
+ (bl, br, tr, tl),
+ (bl, tr, tl, br),
+ (bl, tr, br, tl),
+ (br, tl, bl, tr),
+ (br, tl, tr, bl),
+ (br, bl, tl, tr),
+ (br, bl, tr, tl),
+ (br, tr, tl, bl),
+ (br, tr, bl, tl),
+ (tr, tl, bl, br),
+ (tr, tl, br, bl),
+ (tr, bl, tl, br),
+ (tr, bl, br, tl),
+ (tr, br, tl, bl),
+ (tr, br, bl, tl),
+ ]
+)
CLOCKWISE_PTS = np.array([tl, tr, br, bl])
COUNTER_CLOCKWISE_PTS = np.array([tl, bl, br, tr])
-@pytest.mark.parametrize("pts, expected", [
- (pts, CLOCKWISE_PTS) for pts in POINTS_SET
-])
+@pytest.mark.parametrize(
+ "pts, expected", [(pts, CLOCKWISE_PTS) for pts in POINTS_SET]
+)
def test_order_points_clockwise(pts, expected):
ordered_pts = order_points_clockwise(pts)
np.testing.assert_allclose(ordered_pts, expected)
-@pytest.mark.parametrize("pts, expected", [
- (pts, COUNTER_CLOCKWISE_PTS) for pts in POINTS_SET
-])
+@pytest.mark.parametrize(
+ "pts, expected", [(pts, COUNTER_CLOCKWISE_PTS) for pts in POINTS_SET]
+)
def test_order_points_counter_clockwise(pts, expected):
ordered_pts = order_points_clockwise(pts, inverse=True)
np.testing.assert_allclose(ordered_pts, expected)
@@ -80,7 +86,7 @@ def test_polygon_repr():
# Test __repr__ method
array = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
poly = Polygon(array)
- assert repr(poly) == f"Polygon({str(array)})"
+ assert repr(poly) == f"Polygon({array!s})"
def test_polygon_len():
@@ -103,14 +109,14 @@ def test_polygon_normalized():
# Test initialization with is_normalized=True
array = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
poly = Polygon(array, is_normalized=True)
- assert_array_equal(poly._array, array.astype('float32'))
+ assert_array_equal(poly._array, array.astype("float32"))
assert poly.is_normalized
def test_polygon_invalid_array():
# Test initialization with invalid arrays
with pytest.raises(TypeError):
- invalid_array = "invalid" # Invalid type
+ invalid_array: Any = "invalid" # Invalid type
Polygon(invalid_array)
with pytest.raises(TypeError):
@@ -118,15 +124,15 @@ def test_polygon_invalid_array():
Polygon(invalid_array)
with pytest.raises(TypeError):
- invalid_array = [1, 2, 3] # Invalid type
+ invalid_array: Any = [1, 2, 3] # Invalid type
Polygon(invalid_array)
with pytest.raises(TypeError):
- invalid_array = [1, 2, 3, 4] # Invalid type
+ invalid_array: Any = [1, 2, 3, 4] # Invalid type
Polygon(invalid_array)
with pytest.raises(TypeError):
- invalid_array = Polygon() # Invalid type (empty Polygon instance)
+ invalid_array = cast(Any, Polygon)() # Invalid type (missing args)
Polygon(invalid_array)
@@ -152,7 +158,8 @@ def test_polygon_normalize():
poly = Polygon(array)
normalized_poly = poly.normalize(100.0, 200.0)
expected_normalized_array = np.array(
- [[0.1, 0.1], [0.3, 0.2], [0.5, 0.3]]).astype('float32')
+ [[0.1, 0.1], [0.3, 0.2], [0.5, 0.3]]
+ ).astype("float32")
assert_array_equal(normalized_poly._array, expected_normalized_array)
assert normalized_poly.is_normalized
@@ -163,9 +170,11 @@ def test_polygon_denormalize():
poly = Polygon(normalized_array, is_normalized=True)
denormalized_poly = poly.denormalize(100.0, 200.0)
expected_denormalized_array = np.array(
- [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]).astype('float32')
+ [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
+ ).astype("float32")
np.testing.assert_allclose(
- denormalized_poly._array, expected_denormalized_array)
+ denormalized_poly._array, expected_denormalized_array
+ )
assert not denormalized_poly.is_normalized
@@ -173,7 +182,8 @@ def test_polygon_denormalize_non_normalized():
# Test denormalize method for non-is_normalized Polygon
array = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]])
poly = Polygon(array)
- denormalized_poly = poly.denormalize(100.0, 200.0)
+ with pytest.warns(UserWarning, match="Non-normalized polygon"):
+ denormalized_poly = poly.denormalize(100.0, 200.0)
assert not np.array_equal(denormalized_poly._array, array)
assert not denormalized_poly.is_normalized
@@ -191,7 +201,8 @@ def test_polygon_clip():
# Test clipping outside the range
clipped_poly = poly.clip(10, 20, 30, 40)
expected_clipped_array = np.array(
- [[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]])
+ [[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]]
+ )
assert_array_equal(clipped_poly._array, expected_clipped_array)
@@ -237,7 +248,7 @@ def test_polygon_scale_empty():
with pytest.raises(ValueError):
array = np.array([])
poly = Polygon(array)
- scaled_poly = poly.scale(1)
+ poly.scale(1)
def test_polygon_to_convexhull():
@@ -268,7 +279,7 @@ def test_polygon_to_box():
poly = Polygon(array)
# Test bounding box of the polygon in "xyxy" format
- box = poly.to_box(box_mode='xyxy')
+ box = poly.to_box(box_mode="xyxy")
expected_box = [10, 7, 23, 23]
assert box.tolist() == expected_box
@@ -323,27 +334,22 @@ def test_polygon_is_empty_with_threshold():
assert non_empty_poly.is_empty(threshold=3) is True
-test_props_input = np.array([
- [5, 0],
- [10, 5],
- [5, 10],
- [0, 5]
-])
+test_props_input = np.array([[5, 0], [10, 5], [5, 10], [0, 5]])
test_props_params = [
(
"moments",
{
- 'm00': 50.0,
- 'm10': 250.0,
- 'm01': 250.0,
- 'm20': 1458.3333333333333,
- 'm11': 1250.0,
- 'm02': 1458.3333333333333,
- 'm30': 9375.0,
- 'm21': 7291.666666666667,
- 'm12': 7291.666666666667,
- 'm03': 9375.0,
+ "m00": 50.0,
+ "m10": 250.0,
+ "m01": 250.0,
+ "m20": 1458.3333333333333,
+ "m11": 1250.0,
+ "m02": 1458.3333333333333,
+ "m30": 9375.0,
+ "m21": 7291.666666666667,
+ "m12": 7291.666666666667,
+ "m03": 9375.0,
},
),
("area", 50),
@@ -359,7 +365,7 @@ def test_polygon_is_empty_with_threshold():
]
-@ pytest.mark.parametrize('prop, expected', test_props_params)
+@pytest.mark.parametrize("prop, expected", test_props_params)
def test_polygon_property(prop, expected):
value = getattr(Polygon(test_props_input), prop)
if isinstance(value, (int, float)):
@@ -368,7 +374,7 @@ def test_polygon_property(prop, expected):
for k, v in expected.items():
np.testing.assert_allclose(value[k], v, rtol=1e-4)
elif isinstance(value, (list, tuple)):
- for v, e in zip(value, expected):
+ for v, e in zip(value, expected, strict=True):
np.testing.assert_allclose(v, e, rtol=1e-4)
@@ -396,7 +402,7 @@ def test_polygons_init():
# Test initialization with invalid input
with pytest.raises(TypeError):
- invalid_input = "invalid"
+ invalid_input: Any = "invalid"
polygons = Polygons(invalid_input)
@@ -441,7 +447,7 @@ def test_polygons_getitem():
# Test invalid input
with pytest.raises(TypeError):
- invalid_input = 1.5
+ invalid_input: Any = 1.5
polygons[invalid_input]
@@ -465,10 +471,14 @@ def test_polygons_to_min_boxpoints():
polygons = Polygons(polygons_list)
min_boxpoints_polygons = polygons.to_min_boxpoints()
assert len(min_boxpoints_polygons) == 2
- assert_array_equal(min_boxpoints_polygons[0]._array, np.array(
- [[5, 0], [10, 5], [5, 10], [0, 5]]))
- assert_array_equal(min_boxpoints_polygons[1]._array, np.array(
- [[50, 0], [100, 50], [50, 100], [0, 50]]))
+ assert_array_equal(
+ min_boxpoints_polygons[0]._array,
+ np.array([[5, 0], [10, 5], [5, 10], [0, 5]]),
+ )
+ assert_array_equal(
+ min_boxpoints_polygons[1]._array,
+ np.array([[50, 0], [100, 50], [50, 100], [0, 50]]),
+ )
def test_polygons_to_convexhull():
@@ -480,10 +490,14 @@ def test_polygons_to_convexhull():
convexhull_polygons = polygons.to_convexhull()
assert len(convexhull_polygons) == 2
- assert_array_equal(convexhull_polygons[0]._array, np.array(
- [[5, 0], [10, 5], [5, 10], [0, 5]]))
- assert_array_equal(convexhull_polygons[1]._array, np.array(
- [[50, 0], [100, 50], [50, 100], [0, 50]]))
+ assert_array_equal(
+ convexhull_polygons[0]._array,
+ np.array([[5, 0], [10, 5], [5, 10], [0, 5]]),
+ )
+ assert_array_equal(
+ convexhull_polygons[1]._array,
+ np.array([[50, 0], [100, 50], [50, 100], [0, 50]]),
+ )
def test_polygons_to_boxes():
@@ -536,11 +550,17 @@ def test_polygons_normalize():
w, h = 10.0, 10.0
normalized_polygons = polygons.normalize(w, h)
assert len(normalized_polygons) == 2
- assert_array_equal(normalized_polygons[0]._array, np.array(
- [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype='float32'))
- assert_array_equal(normalized_polygons[1]._array, np.array(
- [[0.7, 0.8], [0.9, 1.0]], dtype='float32'))
- assert normalized_polygons.is_normalized # Check if the is_normalized flag is True
+ assert_array_equal(
+ normalized_polygons[0]._array,
+ np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype="float32"),
+ )
+ assert_array_equal(
+ normalized_polygons[1]._array,
+ np.array([[0.7, 0.8], [0.9, 1.0]], dtype="float32"),
+ )
+    # Check if the is_normalized flag is True
+    assert normalized_polygons.is_normalized
def test_polygons_denormalize():
@@ -553,10 +573,13 @@ def test_polygons_denormalize():
w, h = 10.0, 10.0
denormalized_polygons = polygons.denormalize(w, h)
assert len(denormalized_polygons) == 2
- assert_array_equal(denormalized_polygons[0]._array, np.array(
- [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
- assert_array_equal(denormalized_polygons[1]._array, np.array(
- [[7.0, 8.0], [9.0, 10.0]]))
+ assert_array_equal(
+ denormalized_polygons[0]._array,
+ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
+ )
+ assert_array_equal(
+ denormalized_polygons[1]._array, np.array([[7.0, 8.0], [9.0, 10.0]])
+ )
# Check if the is_normalized flag is False
assert not denormalized_polygons.is_normalized
@@ -571,10 +594,14 @@ def test_polygons_scale():
# Test scaling with distance=1 and default join_style (mitre)
scaled_polygons = polygons.scale(1)
assert len(scaled_polygons) == 2
- assert_array_equal(scaled_polygons[0]._array, np.array(
- [[9, 9], [9, 21], [21, 21], [21, 9]]))
- assert_array_equal(scaled_polygons[1]._array, np.array(
- [[9, 9], [9, 21], [21, 21], [21, 9]]))
+ assert_array_equal(
+ scaled_polygons[0]._array,
+ np.array([[9, 9], [9, 21], [21, 21], [21, 9]]),
+ )
+ assert_array_equal(
+ scaled_polygons[1]._array,
+ np.array([[9, 9], [9, 21], [21, 21], [21, 9]]),
+ )
# Check if no empty polygons after scaling
assert not scaled_polygons.is_empty().any()
@@ -587,7 +614,6 @@ def test_polygons_numpy():
polygons = Polygons(polygons_list)
non_flattened_numpy_array = polygons.numpy(flatten=False)
- expected_non_flattened_array = np.array([array1, array2], dtype=object)
for i in range(len(polygons)):
assert_array_equal(non_flattened_numpy_array[i], polygons[i]._array)
@@ -601,7 +627,9 @@ def test_polygons_to_list():
flattened_list = polygons.to_list(flatten=True)
expected_flattened_list = [
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+ [7.0, 8.0, 9.0, 10.0],
+ ]
assert flattened_list == expected_flattened_list
non_flattened_list = polygons.to_list(flatten=False)
@@ -617,26 +645,32 @@ def test_polygons_to_list():
test_ploygons_props_params = [
(
- "moments", [
+ "moments",
+ [
{
- 'm00': 50.0,
- 'm10': 250.0,
- 'm01': 250.0,
+ "m00": 50.0,
+ "m10": 250.0,
+ "m01": 250.0,
},
{
- 'm00': 50.0,
- 'm10': 250.0,
- 'm01': 250.0,
+ "m00": 50.0,
+ "m10": 250.0,
+ "m01": 250.0,
},
- ]
+ ],
),
("area", np.array([50, 50])),
("arclength", np.array([28.28427, 28.28427])),
("centroid", np.array([(5, 5), (5, 5)])),
("boundingbox", np.array([(0, 0, 10, 10), (0, 0, 10, 10)])),
("min_circle", [((5.0, 5.0), 5.0), ((5.0, 5.0), 5.0)]),
- ("min_box", [((5.0, 5.0), (7.07106, 7.07106), 45.0),
- ((5.0, 5.0), (7.07106, 7.07106), 45.0)]),
+ (
+ "min_box",
+ [
+ ((5.0, 5.0), (7.07106, 7.07106), 45.0),
+ ((5.0, 5.0), (7.07106, 7.07106), 45.0),
+ ],
+ ),
("orientation", np.array([45, 45])),
("min_box_wh", np.array([(7.07106, 7.07106), (7.07106, 7.07106)])),
("extent", np.array([0.5, 0.5])),
@@ -644,19 +678,199 @@ def test_polygons_to_list():
]
-@ pytest.mark.parametrize('prop, expected', test_ploygons_props_params)
+@pytest.mark.parametrize("prop, expected", test_ploygons_props_params)
def test_polygons_property(prop, expected):
value = getattr(Polygons(test_ploygons_props_input), prop)
if isinstance(value, (int, float, np.ndarray)):
np.testing.assert_allclose(value, expected, rtol=1e-4)
elif isinstance(value, list):
- for v, e in zip(value, expected):
+ for v, e in zip(value, expected, strict=True):
if isinstance(v, (list, tuple)):
- for vv, ee in zip(v, e):
+ for vv, ee in zip(v, e, strict=True):
np.testing.assert_allclose(
- np.array(vv), np.array(ee), rtol=1e-4)
+ np.array(vv), np.array(ee), rtol=1e-4
+ )
elif isinstance(v, dict):
for key, val in e.items():
assert v[key] == val
else:
np.testing.assert_allclose(np.array(v), np.array(e), rtol=1e-4)
+
+
+def test_order_points_clockwise_rejects_invalid_shape():
+ with pytest.raises(ValueError, match=r"shape \(4, 2\)"):
+ order_points_clockwise(np.zeros((3, 2), dtype=np.float32))
+
+
+def test_polygon_accepts_nx1x2_contour_array_and_eq_non_polygon():
+ contour = np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=np.float32)
+ poly = Polygon(contour)
+ assert poly.numpy().shape == (2, 2)
+ assert (poly == object()) is False
+
+
+def test_polygon_warns_on_double_normalize_and_clip_rejects_nan():
+ poly = Polygon(
+ np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
+ is_normalized=True,
+ )
+ with pytest.warns(UserWarning, match="forced to do normalization"):
+ _ = poly.normalize(10.0, 10.0)
+
+ with pytest.raises(ValueError, match="infinite or NaN"):
+ Polygon(np.array([[np.nan, 0.0], [1.0, 1.0], [2.0, 2.0]])).clip(
+ 0, 0, 10, 10
+ )
+
+
+def test_polygon_is_empty_validates_threshold_type():
+ poly = Polygon(np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]))
+ with pytest.raises(TypeError, match='expected "int"'):
+ poly.is_empty(threshold="3") # type: ignore[arg-type]
+
+
+def test_polygons_init_from_polygons_repr_and_eq_edge_cases():
+ polygons = Polygons(
+ [
+ np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]),
+ np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 2.0]]),
+ ]
+ )
+ polygons2 = Polygons(polygons)
+ assert polygons2 == polygons
+ assert "Polygons(" in repr(polygons2)
+
+ assert (polygons == object()) is False
+
+ polygons_small = Polygons([np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])])
+ assert (polygons_small == polygons) is False
+
+
+def test_polygons_warns_on_double_normalize_denormalize_and_supports_clip_shift_tolist():
+ poly = Polygon(
+ np.array([[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]),
+ is_normalized=True,
+ )
+ polygons = Polygons([poly], is_normalized=True)
+ with pytest.warns(UserWarning, match="forced to do normalization"):
+ _ = polygons.normalize(10.0, 10.0)
+
+ with pytest.warns(UserWarning, match="forced to do denormalization"):
+ _ = Polygons([poly], is_normalized=False).denormalize(10.0, 10.0)
+
+ clipped = polygons.clip(0, 0, 1, 1)
+ shifted = polygons.shift(1.0, -1.0)
+ assert isinstance(clipped, Polygons)
+ assert isinstance(shifted, Polygons)
+ assert clipped.tolist() == clipped.to_list()
+
+
+def test_polygons_from_image_validates_type_and_filters_short_contours(
+ monkeypatch,
+):
+ import capybara.structures.polygons as poly_mod
+
+ with pytest.raises(TypeError, match=r"np\.ndarray"):
+ Polygons.from_image("not-an-array") # type: ignore[arg-type]
+
+ def fake_find_contours(image, *, mode, method):
+ assert isinstance(image, np.ndarray)
+ assert isinstance(mode, int)
+ assert isinstance(method, int)
+ return (
+ [
+ np.zeros((1, 1, 2), dtype=np.int32),
+ np.array([[[1, 2]], [[3, 4]]], dtype=np.int32),
+ ],
+ None,
+ )
+
+ monkeypatch.setattr(poly_mod.cv2, "findContours", fake_find_contours)
+ polys = Polygons.from_image(np.zeros((10, 10), dtype=np.uint8))
+ assert len(polys) == 1
+ assert polys[0].numpy().shape == (2, 2)
+
+
+def test_polygons_cat_validation_and_happy_path():
+ with pytest.raises(TypeError, match="should be a list"):
+ Polygons.cat("bad") # type: ignore[arg-type]
+
+ with pytest.raises(ValueError, match="is empty"):
+ Polygons.cat([])
+
+ with pytest.raises(TypeError, match="must be Polygon"):
+ Polygons.cat([Polygons([]), "bad"]) # type: ignore[list-item]
+
+ polys1 = Polygons([np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])])
+ polys2 = Polygons([np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 2.0]])])
+ cat = Polygons.cat([polys1, polys2])
+ assert len(cat) == 2
+
+
+def test_polygon_scale_handles_multipolygon_and_empty_exterior(monkeypatch):
+ import capybara.structures.polygons as poly_mod
+
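+    # Fake Shapely objects whose buffer() returns a MultiPolygon, forcing
+    # scale() to pick a single exterior ring; the exact selection rule
+    # (e.g. largest area) is an assumption of this test.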
+ class _FakeExterior:
+ def __init__(self, *, xy, is_empty: bool) -> None:
+ self.xy = xy
+ self.is_empty = is_empty
+
+ class _FakeMultiPolygon:
+ def __init__(self, geoms) -> None:
+ self.geoms = geoms
+
+ class _FakeShapelyPolygon:
+ def __init__(
+ self,
+ arr,
+ *,
+ area: float = 0.0,
+ exterior_empty: bool = False,
+ xy=None,
+ ) -> None:
+ self._area = area
+ self.exterior = _FakeExterior(
+ xy=xy or ([], []), is_empty=exterior_empty
+ )
+ self._arr = np.array(arr, dtype=np.float32)
+
+ @property
+ def area(self) -> float:
+ return self._area
+
+ def buffer(self, *_args, **_kwargs):
+ p1 = _FakeShapelyPolygon(
+ self._arr,
+ area=1.0,
+ exterior_empty=False,
+ xy=([0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]),
+ )
+ p2 = _FakeShapelyPolygon(
+ self._arr,
+ area=2.0,
+ exterior_empty=False,
+ xy=([0.0, 0.0, 2.0, 2.0], [0.0, 2.0, 2.0, 0.0]),
+ )
+ return _FakeMultiPolygon([p1, p2])
+
+ monkeypatch.setattr(poly_mod, "_Polygon_shapely", _FakeShapelyPolygon)
+ monkeypatch.setattr(poly_mod, "MultiPolygon", _FakeMultiPolygon)
+
+ poly = Polygon(np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]))
+ scaled = poly.scale(1)
+ assert scaled.numpy().shape == poly.numpy().shape
+
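+    # When the buffered geometry has an empty exterior, scale() is expected
+    # to yield an empty polygon rather than raising.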
+ class _FakeShapelyPolygonEmptyExterior(_FakeShapelyPolygon):
+ def buffer(self, *_args, **_kwargs):
+ return _FakeShapelyPolygon(
+ self._arr,
+ area=1.0,
+ exterior_empty=True,
+ xy=([], []),
+ )
+
+ monkeypatch.setattr(
+ poly_mod, "_Polygon_shapely", _FakeShapelyPolygonEmptyExterior
+ )
+ empty_scaled = poly.scale(1)
+ assert empty_scaled.is_empty()
diff --git a/tests/test_hidden_bugs_regressions.py b/tests/test_hidden_bugs_regressions.py
new file mode 100644
index 0000000..0374b1a
--- /dev/null
+++ b/tests/test_hidden_bugs_regressions.py
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+
+def test_pyproject_onnxruntime_gpu_extra_uses_hyphenated_package_name():
+ text = (Path(__file__).resolve().parents[1] / "pyproject.toml").read_text(
+ encoding="utf-8"
+ )
+ assert "onnxruntime-gpu>=1.22.0,<2" in text
+ assert "onnxruntime_gpu>=1.22.0,<2" not in text
+
+
+def test_get_files_does_not_return_directories_when_suffix_none(tmp_path):
+ from capybara.utils.files_utils import get_files
+
+ (tmp_path / "a.txt").write_text("a", encoding="utf-8")
+ (tmp_path / "sub").mkdir()
+ (tmp_path / "sub" / "b.txt").write_text("b", encoding="utf-8")
+
+ paths = get_files(tmp_path, suffix=None, recursive=True, sort_path=False)
+ assert paths
+ assert all(Path(p).is_file() for p in paths)
+
+
+def test_get_onnx_infos_preserve_symbolic_dim_params(tmp_path):
+ import onnx
+ from onnx import TensorProto, helper
+
+ from capybara.onnxengine import utils
+
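+    # Build a minimal Identity model whose first dimension is the symbolic
+    # name "batch" and check that the name survives in the reported shapes.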
+ input_info = helper.make_tensor_value_info(
+ "input", TensorProto.FLOAT, ["batch", 3]
+ )
+ output_info = helper.make_tensor_value_info(
+ "output", TensorProto.FLOAT, ["batch", 3]
+ )
+ node = helper.make_node("Identity", ["input"], ["output"])
+ graph = helper.make_graph([node], "sym-graph", [input_info], [output_info])
+ model = helper.make_model(graph)
+ path = tmp_path / "sym.onnx"
+ onnx.save(model, path)
+
+ inputs = utils.get_onnx_input_infos(path)
+ outputs = utils.get_onnx_output_infos(path)
+
+ assert inputs["input"]["shape"][0] == "batch"
+ assert outputs["output"]["shape"][0] == "batch"
+
+
+def test_pad_constant_value_fills_all_channels_for_rgba():
+ from capybara.vision.functionals import pad
+
+ img = np.zeros((10, 10, 4), dtype=np.uint8)
+ out = pad(img, pad_size=1, pad_value=128)
+ assert out.shape == (12, 12, 4)
+ assert out.dtype == np.uint8
+ assert out[0, 0].tolist() == [128, 128, 128, 128]
+
+
+def test_imrotate_constant_border_fills_all_channels_for_rgba():
+ from capybara import BORDER
+ from capybara.vision.geometric import imrotate
+
+ img = np.zeros((20, 20, 4), dtype=np.uint8)
+ out = imrotate(
+ img,
+ angle=45,
+ expand=False,
+ bordertype=BORDER.CONSTANT,
+ bordervalue=128,
+ )
+ assert out.shape == img.shape
+ assert out[0, 0].tolist() == [128, 128, 128, 128]
+
+
+def test_imrotate_preserves_dtype_for_float32_inputs():
+ from capybara import BORDER
+ from capybara.vision.geometric import imrotate
+
+ img = np.zeros((20, 20, 3), dtype=np.float32)
+ out = imrotate(
+ img,
+ angle=10,
+ expand=False,
+ bordertype=BORDER.CONSTANT,
+ bordervalue=0,
+ )
+ assert out.dtype == np.float32
+
+
+def test_visualization_package_import_is_lazy():
+ repo_root = Path(__file__).resolve().parents[1]
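+    # Run the check in a fresh interpreter so submodules already imported by
+    # this test session cannot mask a non-lazy import.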
+ code = """
+import sys
+sys.modules.pop("capybara.vision.visualization.draw", None)
+sys.modules.pop("capybara.vision.visualization.utils", None)
+sys.modules.pop("capybara.vision.visualization", None)
+import capybara.vision.visualization as vis
+assert "capybara.vision.visualization.draw" not in sys.modules
+assert "capybara.vision.visualization.utils" not in sys.modules
+"""
+ env = {**os.environ, "PYTHONPATH": str(repo_root)}
+ subprocess.run([sys.executable, "-c", code], env=env, check=True)
+
+
+def test_draw_text_falls_back_when_font_files_missing(monkeypatch, tmp_path):
+ import capybara.vision.visualization.draw as draw_mod
+
+ monkeypatch.setattr(draw_mod, "DEFAULT_FONT_PATH", tmp_path / "missing.ttf")
+
+ img = np.full((40, 160, 3), 255, dtype=np.uint8)
+ out = draw_mod.draw_text(
+ img.copy(),
+ "hello",
+ location=(5, 5),
+ color=(0, 0, 255),
+ text_size=18,
+ font_path=tmp_path / "also_missing.ttf",
+ )
+ assert out.shape == img.shape
+ assert not np.array_equal(out, img)
+
+
+def test_draw_line_handles_zero_length_and_validates_gap():
+ import capybara.vision.visualization.draw as draw_mod
+
+ img = np.zeros((40, 40, 3), dtype=np.uint8)
+ out = draw_mod.draw_line(
+ img.copy(),
+ pt1=(10, 10),
+ pt2=(10, 10),
+ color=(0, 255, 0),
+ thickness=2,
+ style="line",
+ gap=8,
+ inplace=False,
+ )
+ assert out.shape == img.shape
+ assert out.sum() > 0
+
+ with pytest.raises(ValueError, match="gap must be > 0"):
+ draw_mod.draw_line(img.copy(), (0, 0), (10, 10), gap=0)
+
+
+def test_draw_mask_minmax_normalize_constant_mask_is_safe():
+ import capybara.vision.visualization.draw as draw_mod
+
+ img = np.zeros((20, 30, 3), dtype=np.uint8)
+ mask = np.full((20, 30), 7, dtype=np.uint8)
+
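+    # A constant mask makes (max - min) zero; raising on divide/invalid
+    # errors ensures the normalization guards against division by zero.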
+ with np.errstate(divide="raise", invalid="raise"):
+ out = draw_mod.draw_mask(img, mask, min_max_normalize=True)
+ assert out.shape == img.shape
diff --git a/tests/test_mixins.py b/tests/test_mixins.py
index 062055c..a89d1e1 100644
--- a/tests/test_mixins.py
+++ b/tests/test_mixins.py
@@ -1,82 +1,75 @@
from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, ClassVar
import numpy as np
import pytest
import capybara as cb
-from capybara import DataclassCopyMixin, DataclassToJsonMixin, EnumCheckMixin, dict_to_jsonable
-
-MockImage = np.zeros((5, 5, 3), dtype="uint8")
-base64png_Image = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAIAAAACDbGyAAAADElEQVQIHWNgoC4AAABQAAFhFZyBAAAAAElFTkSuQmCC"
-base64npy_Image = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
+from capybara import (
+ DataclassCopyMixin,
+ DataclassToJsonMixin,
+ EnumCheckMixin,
+ dict_to_jsonable,
+)
+
+MOCK_IMAGE = np.zeros((5, 5, 3), dtype="uint8")
+BASE64_PNG_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAIAAAACDbGyAAAADElEQVQIHWNgoC4AAABQAAFhFZyBAAAAAElFTkSuQmCC"
+BASE64_NPY_IMAGE = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
data = [
(
- dict(
- box=cb.Box((0, 0, 1, 1)),
- boxes=cb.Boxes([(0, 0, 1, 1)]),
- keypoints=cb.Keypoints([(0, 1), (1, 0)]),
- keypoints_list=cb.KeypointsList([[(0, 1), (1, 0)], [(0, 1), (2, 0)]]),
- polygon=cb.Polygon([(0, 0), (1, 0), (1, 1)]),
- polygons=cb.Polygons([[(0, 0), (1, 0), (1, 1)]]),
- np_bool=np.bool_(True),
- np_float=np.float64(1),
- np_number=np.array(1),
- np_array=np.array([1, 2]),
- image=MockImage,
- dict=dict(box=cb.Box((0, 0, 1, 1))),
- str="test",
- int=1,
- float=0.6,
- tuple=(1, 1),
- pow=1e10,
- ),
- dict(
- image=lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG),
- ),
- dict(
- box=[0, 0, 1, 1],
- boxes=[[0, 0, 1, 1]],
- keypoints=[[0, 1], [1, 0]],
- keypoints_list=[[[0, 1], [1, 0]], [[0, 1], [2, 0]]],
- polygon=[[0, 0], [1, 0], [1, 1]],
- polygons=[[[0, 0], [1, 0], [1, 1]]],
- np_bool=True,
- np_float=1.0,
- np_number=1,
- np_array=[1, 2],
- image=base64png_Image,
- dict=dict(box=[0, 0, 1, 1]),
- str="test",
- int=1,
- float=0.6,
- tuple=[1, 1],
- pow=1e10,
- ),
+ {
+ "box": cb.Box((0, 0, 1, 1)),
+ "boxes": cb.Boxes([(0, 0, 1, 1)]),
+ "keypoints": cb.Keypoints([(0, 1), (1, 0)]),
+ "keypoints_list": cb.KeypointsList(
+ [[(0, 1), (1, 0)], [(0, 1), (2, 0)]]
+ ),
+ "polygon": cb.Polygon([(0, 0), (1, 0), (1, 1)]),
+ "polygons": cb.Polygons([[(0, 0), (1, 0), (1, 1)]]),
+ "np_bool": np.bool_(True),
+ "np_float": np.float64(1),
+ "np_number": np.array(1),
+ "np_array": np.array([1, 2]),
+ "image": MOCK_IMAGE,
+ "dict": {"box": cb.Box((0, 0, 1, 1))},
+ "str": "test",
+ "int": 1,
+ "float": 0.6,
+ "tuple": (1, 1),
+ "pow": 1e10,
+ },
+ {"image": lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG)},
+ {
+ "box": [0, 0, 1, 1],
+ "boxes": [[0, 0, 1, 1]],
+ "keypoints": [[0, 1], [1, 0]],
+ "keypoints_list": [[[0, 1], [1, 0]], [[0, 1], [2, 0]]],
+ "polygon": [[0, 0], [1, 0], [1, 1]],
+ "polygons": [[[0, 0], [1, 0], [1, 1]]],
+ "np_bool": True,
+ "np_float": 1.0,
+ "np_number": 1,
+ "np_array": [1, 2],
+ "image": BASE64_PNG_IMAGE,
+ "dict": {"box": [0, 0, 1, 1]},
+ "str": "test",
+ "int": 1,
+ "float": 0.6,
+ "tuple": [1, 1],
+ "pow": 1e10,
+ },
),
(
- dict(
- image=MockImage,
- ),
- dict(
- image=lambda x: cb.npy_to_b64str(x),
- ),
- dict(
- image=base64npy_Image,
- ),
+ {"image": MOCK_IMAGE},
+ {"image": lambda x: cb.npy_to_b64str(x)},
+ {"image": BASE64_NPY_IMAGE},
),
(
- dict(
- images=[dict(image=MockImage)],
- ),
- dict(
- image=lambda x: cb.npy_to_b64str(x),
- ),
- dict(
- images=[dict(image=base64npy_Image)],
- ),
+ {"images": [{"image": MOCK_IMAGE}]},
+ {"image": lambda x: cb.npy_to_b64str(x)},
+ {"images": [{"image": BASE64_NPY_IMAGE}]},
),
]
@@ -87,6 +80,8 @@ def test_dict_to_jsonable(x, jsonable_func, expected):
class TestEnum(EnumCheckMixin, Enum):
+ __test__ = False
+
FIRST = 1
SECOND = "two"
@@ -112,8 +107,10 @@ def test_obj_to_enum_with_invalid_int(self):
@dataclass
class TestDataclass(DataclassCopyMixin):
+ __test__ = False
+
int_field: int
- list_field: List[Any]
+ list_field: list[Any]
class TestDataclassCopyMixin:
@@ -131,12 +128,19 @@ def test_deep_copy(self, test_dataclass_instance):
deepcopy_instance = deepcopy(test_dataclass_instance)
assert deepcopy_instance is not test_dataclass_instance
assert deepcopy_instance.int_field == test_dataclass_instance.int_field
- assert deepcopy_instance.list_field is not test_dataclass_instance.list_field
- assert deepcopy_instance.list_field == test_dataclass_instance.list_field
+ assert (
+ deepcopy_instance.list_field
+ is not test_dataclass_instance.list_field
+ )
+ assert (
+ deepcopy_instance.list_field == test_dataclass_instance.list_field
+ )
@dataclass
class TestDataclass2(DataclassToJsonMixin):
+ __test__ = False
+
box: cb.Box
boxes: cb.Boxes
keypoints: cb.Keypoints
@@ -157,7 +161,7 @@ class TestDataclass2(DataclassToJsonMixin):
py_pow: float
# based on mixin
- jsonable_func = {
+ jsonable_func: ClassVar[dict[str, Any]] = {
"image": lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG),
"np_array_to_b64str": lambda x: cb.npy_to_b64str(x),
}
@@ -174,16 +178,18 @@ def test_dataclass_instance(self):
box=cb.Box((0, 0, 1, 1)),
boxes=cb.Boxes([(0, 0, 1, 1)]),
keypoints=cb.Keypoints([(0, 1), (1, 0)]),
- keypoints_list=cb.KeypointsList([[(0, 1), (1, 0)], [(0, 1), (2, 0)]]),
+ keypoints_list=cb.KeypointsList(
+ [[(0, 1), (1, 0)], [(0, 1), (2, 0)]]
+ ),
polygon=cb.Polygon([(0, 0), (1, 0), (1, 1)]),
polygons=cb.Polygons([[(0, 0), (1, 0), (1, 1)]]),
np_bool=np.bool_(True),
np_float=np.float64(1),
np_number=np.array(1),
np_array=np.array([1, 2]),
- image=MockImage,
+ image=MOCK_IMAGE,
np_array_to_b64str=np_array_to_b64str,
- py_dict=dict(box=cb.Box((0, 0, 1, 1))),
+ py_dict={"box": cb.Box((0, 0, 1, 1))},
py_str="test",
py_int=1,
py_float=0.6,
@@ -194,26 +200,49 @@ def test_dataclass_instance(self):
@pytest.fixture
def test_expected(self):
- return dict(
- box=[0, 0, 1, 1],
- boxes=[[0, 0, 1, 1]],
- keypoints=[[0, 1], [1, 0]],
- keypoints_list=[[[0, 1], [1, 0]], [[0, 1], [2, 0]]],
- polygon=[[0, 0], [1, 0], [1, 1]],
- polygons=[[[0, 0], [1, 0], [1, 1]]],
- np_bool=True,
- np_float=1.0,
- np_number=1,
- np_array=[1, 2],
- image=base64png_Image,
- np_array_to_b64str=np_array_to_b64str_b64,
- py_dict=dict(box=[0, 0, 1, 1]),
- py_str="test",
- py_int=1,
- py_float=0.6,
- py_tuple=[1, 1],
- py_pow=1e10,
- )
+ return {
+ "box": [0, 0, 1, 1],
+ "boxes": [[0, 0, 1, 1]],
+ "keypoints": [[0, 1], [1, 0]],
+ "keypoints_list": [[[0, 1], [1, 0]], [[0, 1], [2, 0]]],
+ "polygon": [[0, 0], [1, 0], [1, 1]],
+ "polygons": [[[0, 0], [1, 0], [1, 1]]],
+ "np_bool": True,
+ "np_float": 1.0,
+ "np_number": 1,
+ "np_array": [1, 2],
+ "image": BASE64_PNG_IMAGE,
+ "np_array_to_b64str": np_array_to_b64str_b64,
+ "py_dict": {"box": [0, 0, 1, 1]},
+ "py_str": "test",
+ "py_int": 1,
+ "py_float": 0.6,
+ "py_tuple": [1, 1],
+ "py_pow": 1e10,
+ }
def test_be_jsonable(self, test_dataclass_instance, test_expected):
assert test_dataclass_instance.be_jsonable() == test_expected
+
+
+def test_dict_to_jsonable_converts_enum_to_name():
+ class LocalEnum(Enum):
+ A = 1
+
+ out = dict_to_jsonable({"e": LocalEnum.A})
+ assert out["e"] == "A"
+
+
+def test_dict_to_jsonable_warns_when_output_is_not_jsonable():
+ with pytest.warns(UserWarning, match="not JSON serializable"):
+ out = dict_to_jsonable({"bad": {1, 2}})
+ assert out["bad"] == {1, 2}
+
+
+def test_dataclass_copy_mixin_requires_dataclass_instance():
+ class NotADataclass(DataclassCopyMixin):
+ def __init__(self) -> None:
+ self.x = 1
+
+ with pytest.raises(TypeError, match="not a dataclass"):
+ NotADataclass().__copy__()
diff --git a/tests/test_runtime.py b/tests/test_runtime.py
new file mode 100644
index 0000000..0bd5637
--- /dev/null
+++ b/tests/test_runtime.py
@@ -0,0 +1,306 @@
+import builtins
+import sys
+import types
+from typing import Any
+
+import pytest
+
+from capybara import runtime as runtime_module
+from capybara.runtime import Backend, Runtime
+
+
+def test_backend_from_any_respects_runtime_boundaries():
+ cuda = Backend.from_any("cuda", runtime="onnx")
+ assert cuda.name == "cuda"
+
+ with pytest.raises(ValueError):
+ Backend.from_any("cuda", runtime="openvino")
+
+
+def test_runtime_from_any_accepts_runtime_instances():
+ rt = Runtime.onnx
+ assert Runtime.from_any(rt) is rt
+
+
+def test_runtime_normalize_backend_defaults():
+ rt = Runtime.from_any("onnx")
+ backend = rt.normalize_backend(None)
+
+ assert backend.name == rt.default_backend_name
+ assert [b.name for b in rt.available_backends()] == list(rt.backend_names)
+
+
+def test_runtime_accepts_backend_instances():
+ rt = Runtime.from_any("openvino")
+ backend = rt.normalize_backend(Backend.ov_gpu)
+
+ assert backend.device == "GPU"
+ assert backend.runtime == rt.name
+
+
+def test_auto_backend_prefers_cuda_over_tensorrt(monkeypatch):
+ rt = Runtime.from_any("onnx")
+ monkeypatch.setattr(
+ "capybara.runtime._get_available_onnx_providers",
+ lambda: {
+ "TensorrtExecutionProvider",
+ "CUDAExecutionProvider",
+ "CPUExecutionProvider",
+ },
+ )
+
+ # When both TensorRT and CUDA providers are visible, prefer CUDA by default.
+ assert rt.auto_backend_name() == "cuda"
+
+
+def test_auto_backend_prefers_rtx(monkeypatch):
+ rt = Runtime.from_any("onnx")
+ monkeypatch.setattr(
+ "capybara.runtime._get_available_onnx_providers",
+ lambda: {"NvTensorRTRTXExecutionProvider"},
+ )
+
+ assert rt.auto_backend_name() == "tensorrt_rtx"
+
+
+def test_auto_backend_falls_back_to_cpu(monkeypatch):
+ rt = Runtime.from_any("onnx")
+ monkeypatch.setattr(
+ "capybara.runtime._get_available_onnx_providers",
+ lambda: {"CPUExecutionProvider"},
+ )
+
+ assert rt.auto_backend_name() == "cpu"
+
+
+def test_auto_backend_pt_prefers_cuda(monkeypatch):
+ rt = Runtime.from_any("pt")
+ monkeypatch.setattr(
+ runtime_module,
+ "_get_torch_capabilities",
+ lambda: (True, True),
+ )
+ assert rt.auto_backend_name() == "cuda"
+
+
+def test_auto_backend_pt_defaults_to_cpu(monkeypatch):
+ rt = Runtime.from_any("pt")
+ monkeypatch.setattr(
+ runtime_module,
+ "_get_torch_capabilities",
+ lambda: (False, False),
+ )
+ assert rt.auto_backend_name() == rt.default_backend_name
+
+
+def test_auto_backend_openvino_uses_default_when_no_devices(monkeypatch):
+ rt = Runtime.from_any("openvino")
+
+ monkeypatch.setattr(
+ runtime_module,
+ "_get_openvino_devices",
+ lambda: set(),
+ )
+
+ assert rt.auto_backend_name() == rt.default_backend_name
+
+
+def test_auto_backend_openvino_prefers_gpu(monkeypatch):
+ rt = Runtime.from_any("openvino")
+
+ monkeypatch.setattr(
+ runtime_module,
+ "_get_openvino_devices",
+ lambda: {"GPU.0", "CPU"},
+ )
+
+ assert rt.auto_backend_name() == "gpu"
+
+
+def test_auto_backend_openvino_prefers_npu(monkeypatch):
+ rt = Runtime.from_any("openvino")
+
+ monkeypatch.setattr(
+ runtime_module,
+ "_get_openvino_devices",
+ lambda: {"NPU", "CPU"},
+ )
+
+ assert rt.auto_backend_name() == "npu"
+
+
+def test_auto_backend_returns_default_for_unknown_runtime():
+ runtime_key = "custom_auto"
+ backend_name = "alpha"
+ Backend(name=backend_name, runtime_key=runtime_key)
+ try:
+ rt = Runtime(
+ name=runtime_key,
+ backend_names=(backend_name,),
+ default_backend_name=backend_name,
+ )
+ assert rt.auto_backend_name() == rt.default_backend_name
+ finally:
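+        # Remove the ad-hoc registrations so later tests see a clean registry.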
+ Runtime._REGISTRY.pop(runtime_key, None)
+ namespace = Backend._REGISTRY.get(runtime_key, {})
+ namespace.pop(backend_name, None)
+ if not namespace:
+ Backend._REGISTRY.pop(runtime_key, None)
+
+
+def test_backend_registration_rejects_duplicates():
+ runtime_key = "temp_runtime"
+ name = "temp_backend"
+ Backend(name=name, runtime_key=runtime_key)
+ try:
+ with pytest.raises(ValueError):
+ Backend(name=name, runtime_key=runtime_key)
+ finally:
+ namespace = Backend._REGISTRY.get(runtime_key, {})
+ namespace.pop(name, None)
+ if not namespace:
+ Backend._REGISTRY.pop(runtime_key, None)
+
+
+def test_backend_from_any_requires_runtime_when_many_registered():
+ with pytest.raises(ValueError):
+ Backend.from_any("cpu")
+
+
+def test_backend_instance_must_match_runtime():
+ cuda = Backend.from_any("cuda", runtime="onnx")
+ with pytest.raises(ValueError):
+ Backend.from_any(cuda, runtime="openvino")
+
+
+def test_runtime_duplicate_registration_rejected():
+ with pytest.raises(ValueError):
+ Runtime(
+ name="onnx",
+ backend_names=("cpu",),
+ default_backend_name="cpu",
+ )
+
+
+def test_runtime_unknown_backend_reference():
+ with pytest.raises(ValueError):
+ Runtime(
+ name="ghost",
+ backend_names=("missing",),
+ default_backend_name="missing",
+ )
+
+
+def test_runtime_default_backend_must_be_known():
+ runtime_key = "custom_runtime"
+ backend_name = "alpha"
+ Backend(name=backend_name, runtime_key=runtime_key)
+ try:
+ with pytest.raises(ValueError):
+ Runtime(
+ name=runtime_key,
+ backend_names=(backend_name,),
+ default_backend_name="beta",
+ )
+ finally:
+ namespace = Backend._REGISTRY.get(runtime_key, {})
+ namespace.pop(backend_name, None)
+ if not namespace:
+ Backend._REGISTRY.pop(runtime_key, None)
+
+
+def test_get_available_onnx_providers_handles_import_failure(monkeypatch):
+ real_import = builtins.__import__
+
+ def fake_import(name, *args, **kwargs):
+ if name == "onnxruntime":
+ raise ModuleNotFoundError("boom")
+ return real_import(name, *args, **kwargs)
+
+ monkeypatch.setattr(builtins, "__import__", fake_import)
+
+ assert runtime_module._get_available_onnx_providers() == set()
+
+
+def test_get_available_onnx_providers_reads_module(monkeypatch):
+ module = types.SimpleNamespace(
+ get_available_providers=lambda: [
+ "CUDAExecutionProvider",
+ "CPUExecutionProvider",
+ ]
+ )
+ monkeypatch.setitem(sys.modules, "onnxruntime", module)
+ try:
+ providers = runtime_module._get_available_onnx_providers()
+ assert providers == {"CUDAExecutionProvider", "CPUExecutionProvider"}
+ finally:
+ monkeypatch.delitem(sys.modules, "onnxruntime", raising=False)
+
+
+def test_get_torch_capabilities_handles_import_failure(monkeypatch):
+ real_import = builtins.__import__
+
+ def fake_import(name, *args, **kwargs):
+ if name == "torch":
+ raise ModuleNotFoundError("boom")
+ return real_import(name, *args, **kwargs)
+
+ monkeypatch.setattr(builtins, "__import__", fake_import)
+
+ assert runtime_module._get_torch_capabilities() == (False, False)
+
+
+def test_get_torch_capabilities_reads_cuda_available(monkeypatch):
+ fake_torch: Any = types.ModuleType("torch")
+
+ class FakeCuda:
+ @staticmethod
+ def is_available() -> bool:
+ return True
+
+ fake_torch.cuda = FakeCuda
+ monkeypatch.setitem(sys.modules, "torch", fake_torch)
+ try:
+ assert runtime_module._get_torch_capabilities() == (True, True)
+ finally:
+ monkeypatch.delitem(sys.modules, "torch", raising=False)
+
+
+def test_get_openvino_devices_handles_import_failure(monkeypatch):
+ real_import = builtins.__import__
+
+ def fake_import(name, *args, **kwargs):
+ if name.startswith("openvino"):
+ raise ModuleNotFoundError("boom")
+ return real_import(name, *args, **kwargs)
+
+ monkeypatch.setattr(builtins, "__import__", fake_import)
+
+ assert runtime_module._get_openvino_devices() == set()
+
+
+def test_get_openvino_devices_reads_core_devices(monkeypatch):
+ fake_ov_runtime: Any = types.ModuleType("openvino.runtime")
+
+ class FakeCore:
+ available_devices = ("GPU.0", "CPU")
+
+ fake_ov_runtime.Core = FakeCore
+
+ fake_ov: Any = types.ModuleType("openvino")
+ fake_ov.runtime = fake_ov_runtime
+
+ monkeypatch.setitem(sys.modules, "openvino", fake_ov)
+ monkeypatch.setitem(sys.modules, "openvino.runtime", fake_ov_runtime)
+ try:
+ assert runtime_module._get_openvino_devices() == {"GPU.0", "CPU"}
+ finally:
+ monkeypatch.delitem(sys.modules, "openvino.runtime", raising=False)
+ monkeypatch.delitem(sys.modules, "openvino", raising=False)
+
+
+def test_backend_from_any_infers_runtime_when_single(monkeypatch):
+ monkeypatch.setattr(Backend, "_REGISTRY", {})
+ Backend(name="alpha", runtime_key="solo")
+ backend = Backend.from_any("alpha")
+ assert backend.runtime == "solo"
diff --git a/tests/test_torchengine.py b/tests/test_torchengine.py
new file mode 100644
index 0000000..8472684
--- /dev/null
+++ b/tests/test_torchengine.py
@@ -0,0 +1,444 @@
+from __future__ import annotations
+
+import sys
+import types
+from types import SimpleNamespace
+from typing import cast
+
+import numpy as np
+import pytest
+
+from capybara.torchengine import TorchEngine, TorchEngineConfig
+from capybara.torchengine import engine as engine_module
+
+
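+# No-op context manager standing in for torch.no_grad()/torch.inference_mode().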
+class _DummyContext:
+ def __enter__(self):
+ return None
+
+ def __exit__(self, exc_type, exc, tb):
+ return False
+
+
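+# Minimal torch.device stand-in: records the spec string and a "cuda"/"cpu" type.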
+class _FakeDevice:
+ def __init__(self, spec: str | _FakeDevice = "cpu"):
+ if isinstance(spec, _FakeDevice):
+ self.type = spec.type
+ self._spec = spec._spec
+ else:
+ spec_str = str(spec)
+ self._spec = spec_str
+ self.type = "cuda" if spec_str.startswith("cuda") else "cpu"
+
+ def __str__(self) -> str:
+ return self._spec
+
+
+class _FakeTensor:
+ def __init__(self, array: np.ndarray, dtype, device: _FakeDevice):
+ self._array = np.asarray(array, dtype=np.float32)
+ self.dtype = dtype
+ self.device = device
+
+ def to(self, target):
+ if isinstance(target, _FakeDevice):
+ self.device = target
+ elif isinstance(target, str):
+ self.device = _FakeDevice(target)
+ else:
+ self.dtype = target
+ return self
+
+ def detach(self):
+ return self
+
+ def contiguous(self):
+ return self
+
+ def numpy(self) -> np.ndarray:
+ return np.array(self._array, copy=True)
+
+
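+# Fake TorchScript model: every call returns two deterministic tensors so the
+# tests can verify output naming, dtype handling, and call counts.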
+class _FakeTorchModel:
+ def __init__(self, torch_ref):
+ self._torch = torch_ref
+ self.dtype = torch_ref.float32
+ self.device = _FakeDevice("cpu")
+ self.eval_called = False
+ self.calls = 0
+
+ def eval(self):
+ self.eval_called = True
+ return self
+
+ def to(self, device=None, dtype=None):
+ if isinstance(device, _FakeDevice):
+ self.device = device
+ if dtype in (self._torch.float16, self._torch.float32):
+ self.dtype = dtype
+ return self
+
+ def half(self):
+ self.dtype = self._torch.float16
+ return self
+
+ def float(self):
+ self.dtype = self._torch.float32
+ return self
+
+ def __call__(self, *tensors):
+ self.calls += 1
+ outputs = []
+ for idx in range(2):
+ arr = np.full(
+ (1, 2, 2, 2), fill_value=self.calls + idx, dtype=np.float32
+ )
+ outputs.append(_FakeTensor(arr, self.dtype, self.device))
+ return tuple(outputs)
+
+
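+# Stub of the torch API surface used here: dtypes, jit.load, cuda.synchronize
+# (recorded in _sync_calls), no_grad/inference_mode, from_numpy, and is_tensor.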
+class _FakeTorchModule:
+ def __init__(self):
+ class _FakeDType:
+ pass
+
+ self.dtype = _FakeDType
+ self.float16 = _FakeDType()
+ self.float32 = _FakeDType()
+ self.cuda = SimpleNamespace(
+ synchronize=lambda device: self._sync_calls.append(str(device)),
+ is_available=lambda: True,
+ )
+ self._sync_calls: list[str] = []
+ self._loaded_paths: list[str] = []
+ self.jit = SimpleNamespace(load=self._load)
+
+ def _load(self, path, map_location=None):
+ self._loaded_paths.append(str(path))
+ model = _FakeTorchModel(self)
+ model.to(map_location)
+ return model
+
+ def device(self, spec):
+ return _FakeDevice(spec)
+
+ def no_grad(self):
+ return _DummyContext()
+
+ def inference_mode(self):
+ return _DummyContext()
+
+ def from_numpy(self, array):
+ return _FakeTensor(
+ np.asarray(array, dtype=np.float32),
+ self.float32,
+ _FakeDevice("cpu"),
+ )
+
+ def is_tensor(self, obj):
+ return isinstance(obj, _FakeTensor)
+
+
+@pytest.fixture
+def fake_torch(monkeypatch):
+ stub = _FakeTorchModule()
+ monkeypatch.setattr(engine_module, "_lazy_import_torch", lambda: stub)
+ return stub
+
+
+def test_lazy_import_torch_prefers_sysmodules(monkeypatch):
+ fake = types.ModuleType("torch")
+ monkeypatch.setitem(sys.modules, "torch", fake)
+ assert engine_module._lazy_import_torch() is fake
+
+
+def test_torch_engine_formats_outputs(fake_torch, tmp_path):
+ model_path = tmp_path / "fake.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(
+ model_path,
+ device="cuda:1",
+ output_names=("feat_s8", "feat_s16"),
+ )
+ inputs = {
+ "image": np.zeros((1, 3, 4, 4), dtype=np.float32),
+ }
+ outputs = engine.run(inputs)
+
+ assert set(outputs.keys()) == {"feat_s8", "feat_s16"}
+ for value in outputs.values():
+ assert value.dtype == np.float32
+ assert value.shape == (1, 2, 2, 2)
+
+    # Ensure TorchEngine selected the CUDA device and cast the dtype.
+ assert engine.device.type == "cuda"
+ assert engine.dtype == fake_torch.float32
+ assert fake_torch._loaded_paths == [str(model_path)]
+
+
+def test_torch_engine_call_accepts_kwargs_feed(fake_torch, tmp_path):
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+
+ engine = TorchEngine(
+ model_path,
+ device="cpu",
+ output_names=("feat_s8", "feat_s16"),
+ )
+ outputs = engine(image=np.zeros((1, 3, 4, 4), dtype=np.float32))
+
+ assert set(outputs.keys()) == {"feat_s8", "feat_s16"}
+
+
+def test_torch_engine_auto_dtype_selects_fp16_when_name_and_cuda(
+ fake_torch, tmp_path
+):
+ """Auto dtype picks fp16 for '*fp16*' model names when running on CUDA."""
+ model_path = tmp_path / "demo_fp16.pt"
+ model_path.write_bytes(b"torchscript")
+
+ engine = TorchEngine(model_path, device="cuda:0")
+
+ assert engine.device.type == "cuda"
+ assert engine.dtype == fake_torch.float16
+ assert engine._model.dtype == fake_torch.float16
+
+
+def test_torch_engine_explicit_fp32_dtype(fake_torch, tmp_path):
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+
+ config = TorchEngineConfig(dtype="fp32")
+ engine = TorchEngine(model_path, device="cpu", config=config)
+
+ assert engine.dtype == fake_torch.float32
+ assert engine._model.dtype == fake_torch.float32
+
+
+def test_torch_engine_prepare_feed_requires_mapping(fake_torch, tmp_path):
+ """Run rejects non-mapping feeds to avoid accidental positional ordering bugs."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ with pytest.raises(TypeError, match="feed must be a mapping"):
+ engine.run(["not", "a", "mapping"]) # type: ignore[arg-type]
+
+
+def test_torch_engine_output_names_mismatch_raises(fake_torch, tmp_path):
+ """Output key schema must match model outputs to prevent silent mislabeling."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu", output_names=("only_one",))
+
+ with pytest.raises(ValueError, match="model produced 2 outputs"):
+ engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)})
+
+
+def test_torch_engine_formats_mapping_and_tensor_outputs(fake_torch, tmp_path):
+ """TorchEngine supports dict outputs and single tensor outputs."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cuda:0")
+
+ def _mapping_model(*_tensors):
+ return {
+ "feat": _FakeTensor(
+ np.ones((1, 2, 2, 2)),
+ fake_torch.float16,
+ cast(_FakeDevice, engine.device),
+ )
+ }
+
+ engine._model = _mapping_model # type: ignore[assignment]
+ out = engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)})
+ assert set(out) == {"feat"}
+ assert out["feat"].dtype == np.float32
+
+ def _tensor_model(*_tensors):
+ return _FakeTensor(
+ np.ones((1, 2, 2, 2)),
+ fake_torch.float16,
+ cast(_FakeDevice, engine.device),
+ )
+
+ engine._model = _tensor_model # type: ignore[assignment]
+ out = engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)})
+ assert set(out) == {"output"}
+
+
+def test_torch_engine_benchmark_honors_cuda_sync_override(fake_torch, tmp_path):
+ """Benchmark uses synchronize() only when CUDA + cuda_sync is enabled."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cuda:1")
+
+ stats = engine.benchmark(
+ {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)},
+ repeat=2,
+ warmup=1,
+ cuda_sync=True,
+ )
+
+ assert stats["repeat"] == 2
+ assert stats["warmup"] == 1
+ assert "latency_ms" in stats
+ assert len(fake_torch._sync_calls) == 5
+
+
+def test_torch_engine_benchmark_validates_repeat_and_warmup(
+ fake_torch, tmp_path
+):
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+ feed = {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}
+
+ with pytest.raises(ValueError, match="repeat must be >= 1"):
+ engine.benchmark(feed, repeat=0)
+
+ with pytest.raises(ValueError, match="warmup must be >= 0"):
+ engine.benchmark(feed, warmup=-1)
+
+
+def test_torch_engine_call_accepts_wrapped_mapping_and_generates_names(
+ fake_torch, tmp_path
+):
+ """__call__ supports passing a mapping payload and auto-generating output names."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ outputs = engine(
+ payload={"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}
+ )
+ assert set(outputs) == {"output_0", "output_1"}
+
+
+def test_torch_engine_multiple_inputs_use_positional_forward(
+ fake_torch, tmp_path
+):
+ """Multiple inputs are forwarded positionally into the TorchScript model."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ outputs = engine.run(
+ {
+ "a": np.zeros((1, 3, 4, 4), dtype=np.float32),
+ "b": np.zeros((1, 3, 4, 4), dtype=np.float32),
+ }
+ )
+ assert set(outputs) == {"output_0", "output_1"}
+
+
+def test_torch_engine_summary_reports_core_fields(fake_torch, tmp_path):
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ summary = engine.summary()
+ assert summary["model"] == str(model_path)
+ assert summary["device"] == "cpu"
+
+
+def test_torch_engine_benchmark_respects_config_cuda_sync_default(
+ fake_torch, tmp_path
+):
+ """cuda_sync defaults to config when override is omitted."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(
+ model_path,
+ device="cuda:0",
+ config=TorchEngineConfig(cuda_sync=False),
+ )
+
+ engine.benchmark(
+ {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)},
+ repeat=1,
+ warmup=0,
+ )
+ assert fake_torch._sync_calls == []
+
+
+def test_torch_engine_accepts_preconstructed_tensors(
+ fake_torch, monkeypatch, tmp_path
+):
+ """Existing torch tensors should pass through without from_numpy conversion."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ tensor = fake_torch.from_numpy(np.zeros((1, 3, 4, 4), dtype=np.float32))
+ monkeypatch.setattr(
+ fake_torch,
+ "from_numpy",
+ lambda *_args, **_kwargs: pytest.fail(
+ "from_numpy should not be called"
+ ),
+ )
+
+ outputs = engine.run({"image": tensor})
+ assert set(outputs) == {"output_0", "output_1"}
+
+
+def test_torch_engine_device_instance_short_circuits_normalization(
+ fake_torch, monkeypatch, tmp_path
+):
+ """Passing an existing torch.device object should be preserved."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ monkeypatch.setattr(fake_torch, "device", _FakeDevice)
+
+ device = _FakeDevice("cuda:0")
+ engine = TorchEngine(model_path, device=device)
+ assert engine.device is device
+
+
+def test_torch_engine_dtype_string_and_custom_dtype_handling(
+ fake_torch, monkeypatch, tmp_path
+):
+ """dtype supports explicit strings, custom torch.dtype instances, and errors."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+
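+    # Swap in a stub torch.dtype class so a custom dtype instance can be used below.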
+ fake_dtype_type = type("dtype", (), {})
+ monkeypatch.setattr(fake_torch, "dtype", fake_dtype_type, raising=False)
+
+ engine = TorchEngine(
+ model_path,
+ device="cpu",
+ config=TorchEngineConfig(dtype="fp16"),
+ )
+ assert engine.dtype == fake_torch.float16
+
+ custom_dtype = fake_dtype_type()
+ engine = TorchEngine(
+ model_path,
+ device="cpu",
+ config=TorchEngineConfig(dtype=custom_dtype),
+ )
+ assert engine.dtype is custom_dtype
+
+ with pytest.raises(ValueError, match="Unsupported dtype specification"):
+ TorchEngine(
+ model_path,
+ device="cpu",
+ config=TorchEngineConfig(dtype="weird"),
+ )
+
+
+def test_torch_engine_rejects_unsupported_outputs(fake_torch, tmp_path):
+ """Unexpected model outputs should raise a clear error."""
+ model_path = tmp_path / "demo.pt"
+ model_path.write_bytes(b"torchscript")
+ engine = TorchEngine(model_path, device="cpu")
+
+ engine._model = lambda *_args: 123 # type: ignore[assignment]
+ with pytest.raises(TypeError, match="Unsupported TorchScript output"):
+ engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)})
+
+ engine._model = lambda *_args: {"feat": "bad"} # type: ignore[assignment]
+ with pytest.raises(TypeError, match=r"Model outputs must be torch\.Tensor"):
+ engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)})
diff --git a/tests/utils/test_custom_path.py b/tests/utils/test_custom_path.py
new file mode 100644
index 0000000..1ed8212
--- /dev/null
+++ b/tests/utils/test_custom_path.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from capybara.utils.custom_path import rm_path
+
+
+def test_rm_path_removes_non_empty_directories(tmp_path):
+ folder = tmp_path / "nested"
+ folder.mkdir()
+ (folder / "file.txt").write_text("hello", encoding="utf-8")
+
+ rm_path(folder)
+ assert not folder.exists()
+
+
+def test_rm_path_removes_files(tmp_path):
+ file_path = tmp_path / "file.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ rm_path(file_path)
+ assert not file_path.exists()
diff --git a/tests/utils/test_custom_tqdm.py b/tests/utils/test_custom_tqdm.py
new file mode 100644
index 0000000..ea30b9c
--- /dev/null
+++ b/tests/utils/test_custom_tqdm.py
@@ -0,0 +1,15 @@
+from capybara.utils.custom_tqdm import Tqdm
+
+
+def test_custom_tqdm_respects_explicit_total_and_infers_total():
+ bar = Tqdm(range(3), total=10, disable=True)
+ try:
+ assert bar.total == 10
+ finally:
+ bar.close()
+
+ inferred = Tqdm([1, 2, 3], disable=True)
+ try:
+ assert inferred.total == 3
+ finally:
+ inferred.close()
diff --git a/tests/utils/test_download_from_google.py b/tests/utils/test_download_from_google.py
new file mode 100644
index 0000000..7b424c8
--- /dev/null
+++ b/tests/utils/test_download_from_google.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+from collections.abc import Iterator
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+import capybara.utils.utils as utils_mod
+
+
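+# Minimal fakes for requests.Session/Response and tqdm so the Google Drive
+# download flow can be exercised without network access or progress output.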
+class _FakeResponse:
+ def __init__(
+ self,
+ *,
+ headers: dict[str, str] | None = None,
+ cookies: dict[str, str] | None = None,
+ text: str = "",
+ chunks: list[bytes] | None = None,
+ iter_raises: Exception | None = None,
+ ) -> None:
+ self.headers = headers or {}
+ self.cookies = cookies or {}
+ self.text = text
+ self._chunks = chunks or []
+ self._iter_raises = iter_raises
+
+ def iter_content(self, *, chunk_size: int) -> Iterator[bytes]:
+ if self._iter_raises is not None:
+ raise self._iter_raises
+ yield from self._chunks
+
+
+class _FakeSession:
+ def __init__(self, responses: list[_FakeResponse]) -> None:
+ self._responses = list(responses)
+ self.calls: list[tuple[str, dict[str, Any] | None]] = []
+
+ def get(self, url: str, params: dict[str, Any] | None = None, stream=True):
+ self.calls.append((url, params))
+ if not self._responses:
+ raise AssertionError("No more fake responses configured")
+ return self._responses.pop(0)
+
+
+class _DummyTqdm:
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ self.total = kwargs.get("total", 0)
+ self.updated = 0
+
+ def update(self, n: int) -> None:
+ self.updated += n
+
+ def __enter__(self) -> _DummyTqdm:
+ return self
+
+ def __exit__(self, exc_type, exc, tb) -> bool:
+ return False
+
+
+def _install_fakes(monkeypatch, session: _FakeSession) -> None:
+ monkeypatch.setattr(utils_mod.requests, "Session", lambda: session)
+ monkeypatch.setattr(utils_mod, "tqdm", _DummyTqdm)
+
+
+def test_download_from_google_direct_content_disposition(tmp_path, monkeypatch):
+ session = _FakeSession(
+ [
+ _FakeResponse(
+ headers={
+ "content-disposition": "attachment; filename=x.bin",
+ "content-length": "3",
+ },
+ chunks=[b"abc"],
+ )
+ ]
+ )
+ _install_fakes(monkeypatch, session)
+
+ out = utils_mod.download_from_google(
+ file_id="id",
+ file_name="x.bin",
+ target=tmp_path,
+ )
+ assert out == Path(tmp_path) / "x.bin"
+ assert out.read_bytes() == b"abc"
+ assert session.calls[0][0] == "https://docs.google.com/uc"
+
+
+def test_download_from_google_uses_cookie_confirm_token(tmp_path, monkeypatch):
+ session = _FakeSession(
+ [
+ _FakeResponse(cookies={"download_warning_foo": "TOKEN"}, text=""),
+ _FakeResponse(
+ headers={
+ "content-disposition": "attachment",
+ "content-length": "1",
+ },
+ chunks=[b"x"],
+ ),
+ ]
+ )
+ _install_fakes(monkeypatch, session)
+
+ out = utils_mod.download_from_google("id", "y.bin", target=tmp_path)
+ assert out.read_bytes() == b"x"
+ assert session.calls[1][1] is not None
+ assert session.calls[1][1]["confirm"] == "TOKEN"
+
+
+def test_download_from_google_parses_html_download_form(tmp_path, monkeypatch):
+ html = """
+
+
+
+ """
+ session = _FakeSession(
+ [
+ _FakeResponse(text=html),
+ _FakeResponse(
+ headers={
+ "content-disposition": "attachment",
+ "content-length": "2",
+ },
+ chunks=[b"hi"],
+ ),
+ ]
+ )
+ _install_fakes(monkeypatch, session)
+
+ out = utils_mod.download_from_google("id", "z.bin", target=tmp_path)
+ assert out.read_bytes() == b"hi"
+ assert session.calls[1][0] == "https://example.com/download"
+
+
+def test_download_from_google_parses_confirm_param(tmp_path, monkeypatch):
+ session = _FakeSession(
+ [
+ _FakeResponse(text="... confirm=ABC123 ..."),
+ _FakeResponse(
+ headers={
+ "content-disposition": "attachment",
+ "content-length": "1",
+ },
+ chunks=[b"1"],
+ ),
+ ]
+ )
+ _install_fakes(monkeypatch, session)
+
+ out = utils_mod.download_from_google("id", "c.bin", target=tmp_path)
+ assert out.read_bytes() == b"1"
+ assert session.calls[1][1] is not None
+ assert session.calls[1][1]["confirm"] == "ABC123"
+
+
+def test_download_from_google_raises_when_no_link_found(tmp_path, monkeypatch):
+ session = _FakeSession([_FakeResponse(text="")])
+ _install_fakes(monkeypatch, session)
+
+ with pytest.raises(Exception, match="無法在回應中找到下載連結"):
+ utils_mod.download_from_google("id", "x.bin", target=tmp_path)
+
+
+def test_download_from_google_wraps_streaming_errors(tmp_path, monkeypatch):
+ session = _FakeSession(
+ [
+ _FakeResponse(
+ headers={
+ "content-disposition": "attachment",
+ "content-length": "1",
+ },
+ iter_raises=ValueError("boom"),
+ )
+ ]
+ )
+ _install_fakes(monkeypatch, session)
+
+ with pytest.raises(RuntimeError, match="File download failed"):
+ utils_mod.download_from_google("id", "x.bin", target=tmp_path)
diff --git a/tests/utils/test_files_utils.py b/tests/utils/test_files_utils.py
new file mode 100644
index 0000000..dfaf0c6
--- /dev/null
+++ b/tests/utils/test_files_utils.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import hashlib
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from capybara.utils import custom_path
+from capybara.utils import files_utils as futils
+
+
+def test_rm_path_removes_file_and_directory(tmp_path: Path):
+ f = tmp_path / "a.txt"
+ f.write_text("x", encoding="utf-8")
+ assert f.exists()
+ custom_path.rm_path(f)
+ assert not f.exists()
+
+ d = tmp_path / "empty_dir"
+ d.mkdir()
+ assert d.exists()
+ custom_path.rm_path(d)
+ assert not d.exists()
+
+
+def test_copy_path_copies_and_validates_source(tmp_path: Path):
+ src = tmp_path / "src.txt"
+ src.write_text("hello", encoding="utf-8")
+ dst = tmp_path / "dst.txt"
+ custom_path.copy_path(src, dst)
+ assert dst.read_text(encoding="utf-8") == "hello"
+
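+    # "invaild" deliberately matches the spelling used in the library's error message.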
+ with pytest.raises(ValueError, match="invaild"):
+ custom_path.copy_path(tmp_path / "missing.txt", dst)
+
+
+def test_gen_md5_and_img_to_md5(tmp_path: Path):
+ p = tmp_path / "blob.bin"
+ payload = b"capybara"
+ p.write_bytes(payload)
+
+ assert futils.gen_md5(p) == hashlib.md5(payload).hexdigest()
+
+ img = np.arange(12, dtype=np.uint8).reshape(3, 4)
+ assert futils.img_to_md5(img) == hashlib.md5(img.tobytes()).hexdigest()
+
+ with pytest.raises(TypeError, match="numpy array"):
+ futils.img_to_md5("not-an-array") # type: ignore[arg-type]
+
+
+def test_dump_and_load_json_yaml_and_pickle(tmp_path: Path, monkeypatch):
+ obj = {"a": 1, "b": [1, 2, 3]}
+
+ json_path = tmp_path / "x.json"
+ futils.dump_json(obj, json_path)
+ assert futils.load_json(json_path) == obj
+
+ yaml_path = tmp_path / "x.yaml"
+ futils.dump_yaml(obj, yaml_path)
+ assert futils.load_yaml(yaml_path) == obj
+
+ pkl_path = tmp_path / "x.pkl"
+ futils.dump_pickle(obj, pkl_path)
+ assert futils.load_pickle(pkl_path) == obj
+
+ # Default-path behavior should not pollute repo root.
+ monkeypatch.chdir(tmp_path)
+ futils.dump_json(obj, path=None)
+ assert (tmp_path / "tmp.json").exists()
+ futils.dump_yaml(obj, path=None)
+ assert (tmp_path / "tmp.yaml").exists()
+
+
+def test_get_files_filters_suffix_and_options(tmp_path: Path):
+ # Create a few files with varying cases.
+ (tmp_path / "a.txt").write_text("a", encoding="utf-8")
+ (tmp_path / "b.TXT").write_text("b", encoding="utf-8")
+ (tmp_path / "c.jpg").write_text("c", encoding="utf-8")
+ sub = tmp_path / "sub"
+ sub.mkdir()
+ (sub / "d.txt").write_text("d", encoding="utf-8")
+
+ with pytest.raises(FileNotFoundError):
+ futils.get_files(tmp_path / "missing")
+
+ with pytest.raises(TypeError, match="suffix must be"):
+ futils.get_files(tmp_path, suffix=123) # type: ignore[arg-type]
+
+ files = futils.get_files(tmp_path, suffix=".txt", recursive=True)
+ assert all(Path(p).suffix.lower() == ".txt" for p in files)
+ assert len(files) == 3
+
+ files_no_case = futils.get_files(
+ tmp_path,
+ suffix=[".txt"],
+ recursive=False,
+ ignore_letter_case=False,
+ sort_path=False,
+ return_pathlib=False,
+ )
+ assert all(isinstance(p, str) for p in files_no_case)
+ assert len(files_no_case) == 1
diff --git a/tests/utils/test_powerdict.py b/tests/utils/test_powerdict.py
index ae3679d..b27f55a 100644
--- a/tests/utils/test_powerdict.py
+++ b/tests/utils/test_powerdict.py
@@ -1,161 +1,172 @@
+from typing import Any, cast
+
import pytest
from capybara import PowerDict
-power_dict_attr_data = [
- (
- {
- 'key': 'value',
- 'name': 'mock_name'
- },
- ['key', 'name']
- )
-]
+def test_powerdict_init_accepts_none_and_kwargs():
+ assert PowerDict() == {}
+ assert PowerDict(None) == {}
+ pd = PowerDict(None, new="value")
+ assert pd == {"new": "value"}
+ assert pd.new == "value"
-@pytest.mark.parametrize('x, match', power_dict_attr_data)
-def test_power_dict_attr(x, match):
- test_dict = PowerDict(x)
- for attr in match:
- assert hasattr(test_dict, attr)
+ pd = PowerDict({"a": 1}, b=2)
+ assert pd == {"a": 1, "b": 2}
+ assert pd.a == 1
+ assert pd.b == 2
-power_dict_freeze_melt_data = [
- (
- {
- 'key': 'value',
- },
- ['key']
- ),
- (
- {
- 'PowerDict': PowerDict({'A': 1}),
- },
- ['PowerDict']
- ),
- (
- {
- 'list': [1, 2, 3],
- 'tuple': (1, 2, 3)
- },
- ['list', 'tuple']
- ),
- (
+def test_powerdict_attribute_and_item_access_are_kept_in_sync():
+ pd = PowerDict()
+ pd.alpha = 1
+ assert pd["alpha"] == 1
+
+ pd["beta"] = 2
+ assert pd.beta == 2
+
+ del pd["alpha"]
+ assert "alpha" not in pd
+ assert not hasattr(pd, "alpha")
+
+ pd.gamma = 3
+ del pd.gamma
+ assert "gamma" not in pd
+
+
+def test_powerdict_recursively_wraps_nested_mappings_and_sequences():
+ pd = PowerDict(
{
- 'PowerDict_in_list': [PowerDict({'A': 1}), PowerDict({'B': 2})],
- 'PowerDict_in_tuple': (PowerDict({'A': 1}), PowerDict({'B': 2}))
- },
- ['PowerDict_in_list', 'PowerDict_in_tuple']
+ "cfg": {"x": 1},
+ "items": [{"y": 2}, {"z": 3}],
+ "numbers_tuple": (1, 2, 3),
+ }
)
-]
+ assert isinstance(pd.cfg, PowerDict)
+ assert pd.cfg.x == 1
+
+ assert isinstance(pd.items, list)
+ assert [type(x) for x in pd.items] == [PowerDict, PowerDict]
+ assert pd.items[0].y == 2
+ assert pd.items[1].z == 3
+
+ # Tuples are normalized to lists to keep internal mutation rules simple.
+ assert pd.numbers_tuple == [1, 2, 3]
+
+
+def test_powerdict_freeze_blocks_mutation_and_melt_restores():
+ pd = PowerDict({"a": 1, "nested": {"b": 2}})
+ pd.freeze()
+
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ pd.a = 10
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ pd["a"] = 10
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ del pd.a
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ del pd["a"]
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ pd.update({"c": 3})
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ pd.pop("a")
+
+ # Nested PowerDict is also frozen.
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ pd.nested.b = 20
+
+ pd.melt()
+ pd.a = 10
+ pd.update({"c": 3})
+ assert pd == {"a": 10, "nested": {"b": 2}, "c": 3}
+
+
+def test_powerdict_pop_behaves_like_dict_pop():
+ pd = PowerDict({"a": 1})
+ assert pd.pop("a") == 1
+ assert "a" not in pd
-@pytest.mark.parametrize('x, match', power_dict_freeze_melt_data)
-def test_power_dict_freeze_melt(x, match):
- test_dict = PowerDict(x)
- test_dict.freeze()
- for attr in match:
- try:
- test_dict[attr] = None
- assert False
- except ValueError:
- pass
+ assert pd.pop("missing", None) is None
- test_dict.melt()
- for attr in match:
- test_dict[attr] = None
+ with pytest.raises(KeyError):
+ pd.pop("missing")
-power_dict_init_data = [
- {
- 'd': None,
- 'kwargs': None,
- 'match': {}
- },
- {
- 'd': None,
- 'kwargs': {'new': 'update'},
- 'match': {'kwargs': {'new': 'update'}}
- }
-]
+def test_powerdict_reserved_frozen_key_behaviors():
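+    # "_frozen" is reserved for PowerDict's internal freeze flag, so it cannot be
+    # set or deleted through the mapping or attribute interfaces.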
+ pd = PowerDict()
+ with pytest.raises(KeyError, match="_frozen"):
+ pd["_frozen"] = True
+ with pytest.raises(KeyError, match="_frozen"):
+ del pd["_frozen"]
+ with pytest.raises(KeyError, match="_frozen"):
+ del pd._frozen
-@pytest.mark.parametrize('test_data', power_dict_init_data)
-def test_dict_init(test_data: dict):
- if test_data['kwargs'] is not None:
- assert PowerDict(d=test_data['d'], kwargs=test_data['kwargs']) == test_data['match']
- else:
- assert PowerDict(d=test_data['d']) == test_data['match']
+def test_powerdict_missing_attr_raises_attribute_error():
+ pd = PowerDict()
+ with pytest.raises(AttributeError):
+ _ = pd.missing
-power_dict_set_data = [
- ({'int': 1}, 'int', 2, {'int': 2}),
- ({'list': [1, 2, 3]}, 'list', [4, 5, 6], {'list': [4, 5, 6]}),
- ({'tuple': (1, 2, 3)}, 'tuple', (7, 8, 9), {'tuple': [7, 8, 9]})
-]
+def test_powerdict_serialization_helpers(tmp_path):
+ pd = PowerDict({"a": 1, "nested": {"b": 2}})
+ json_path = tmp_path / "x.json"
+ yaml_path = tmp_path / "x.yaml"
+ pkl_path = tmp_path / "x.pkl"
+ txt_path = tmp_path / "x.txt"
-@pytest.mark.parametrize('x, new_key, new_value, match', power_dict_set_data)
-def test_power_dict_set(x, new_key, new_value, match):
- test_dict = PowerDict(x)
- test_dict[new_key] = new_value
- assert test_dict == match
+ assert pd.to_json(json_path) is None
+ assert json_path.exists()
+ assert PowerDict.load_json(json_path) == pd
+ assert pd.to_yaml(yaml_path) is None
+ assert yaml_path.exists()
+ assert PowerDict.load_yaml(yaml_path) == pd
-power_dict_set_raises_data = [
- ({'int': 1}, 'int', 2, ValueError, "PowerDict is frozen. 'int' cannot be set."),
- ({'list': [1, 2, 3]}, 'list', [3, 4], ValueError, "PowerDict is frozen. 'list' cannot be set."),
- ({'tuple': (1, 2, 3)}, 'tuple', (3, 4), ValueError, "PowerDict is frozen. 'tuple' cannot be set.")
-]
+ assert pd.to_pickle(pkl_path) is None
+ assert pkl_path.exists()
+ assert PowerDict.load_pickle(pkl_path) == pd
+ pd.to_txt(txt_path)
+ assert "nested" in txt_path.read_text(encoding="utf-8")
-@pytest.mark.parametrize('x, new_key, new_value, error, match', power_dict_set_raises_data)
-def test_power_dict_set_raises(x, new_key, new_value, error, match):
- test_dict = PowerDict(x)
- test_dict.freeze()
- with pytest.raises(error, match=match):
- test_dict[new_key] = new_value
+def test_powerdict_deepcopy_is_blocked_when_frozen():
+ from copy import deepcopy
-power_dict_update_data = [
- ({'a': 1, 'b': 2}, {'b': 4, 'c': 3}, {'a': 1, 'b': 4, 'c': 3}),
- ({'a': 1, 'b': 2}, {'c': [1, 2]}, {'a': 1, 'b': 2, 'c': [1, 2]}),
- ({'a': 1, 'b': 2}, {'c': (1, 2)}, {'a': 1, 'b': 2, 'c': [1, 2]})
-]
+ pd = PowerDict({"a": 1})
+ pd.freeze()
+ with pytest.raises(Warning, match="cannot be copy"):
+ _ = deepcopy(pd)
-@pytest.mark.parametrize('x, e, match', power_dict_update_data)
-def test_power_dict_update(x, e, match):
- test_dict = PowerDict(x)
- test_dict.update(e)
- assert test_dict == match
+def test_powerdict_update_without_args_and_deepcopy_when_unfrozen():
+ from copy import deepcopy
+ pd = PowerDict({"a": 1})
+ pd.update()
+ assert pd == {"a": 1}
-power_dict_pop_data = [
- ({'a': 1, 'b': 2, 'c': 3}, 'b', {'a': 1, 'c': 3}),
- ({'a': 1, 'b': 2, 'c': [1, 2]}, 'b', {'a': 1, 'c': [1, 2]}),
- ({'a': 1, 'b': 2, 'c': (1, 2)}, 'b', {'a': 1, 'c': [1, 2]}),
-]
+ clone = deepcopy(pd)
+ assert clone == pd
+ assert clone is not pd
-@pytest.mark.parametrize('x, key, match', power_dict_pop_data)
-def test_power_dict_pop(x, key, match):
- test_dict = PowerDict(x)
- test_dict.pop(key)
- assert test_dict == match
+def test_powerdict_freeze_and_to_dict_handle_powerdict_inside_lists():
+ pd = PowerDict({"items": [{"x": 1}, {"y": 2}]})
+ pd.freeze()
+ with pytest.raises(ValueError, match="PowerDict is frozen"):
+ items = cast(list[Any], pd["items"])
+ cast(Any, items[0]).x = 10
+ pd.melt()
+ items = cast(list[Any], pd["items"])
+ cast(Any, items[0]).x = 10
+ assert cast(Any, items[0]).x == 10
-test_power_dict_del_raises = [
- ({'a': 1, 'b': 2, 'c': 3}, 'b', ValueError, "PowerDict is frozen. 'b' cannot be del."),
- ({'a': 1, 'b': 2, 'c': [1, 2]}, 'b', ValueError, "PowerDict is frozen. 'b' cannot be del."),
- ({'a': 1, 'b': 2, 'c': (1, 2)}, 'b', ValueError, "PowerDict is frozen. 'b' cannot be del."),
-]
-
-
-@pytest.mark.parametrize('x, key, error, match', test_power_dict_del_raises)
-def test_power_dict_del_raises(x, key, error, match):
- test_dict = PowerDict(x)
- test_dict.freeze()
- with pytest.raises(error, match=match):
- del test_dict[key]
+ out = pd.to_dict()
+ assert out == {"items": [{"x": 10}, {"y": 2}]}
diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py
new file mode 100644
index 0000000..71a5047
--- /dev/null
+++ b/tests/utils/test_time.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import time as py_time
+from datetime import datetime, timezone
+
+import numpy as np
+import pytest
+
+import capybara.utils.time as time_mod
+
+
+def test_timer_requires_tic_before_toc():
+ timer = time_mod.Timer()
+ with pytest.raises(ValueError, match="has not been started"):
+ timer.toc()
+
+
+def test_timer_records_and_stats(monkeypatch, capsys):
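+    # Fake perf_counter readings: two tic/toc pairs measuring 0.5 s and 0.25 s.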
+ perf = iter([10.0, 10.5, 20.0, 20.25])
+
+ monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf))
+ timer = time_mod.Timer(precision=2, desc="bench", verbose=True)
+
+ timer.tic()
+ dt1 = timer.toc(verbose=True)
+ assert dt1 == 0.5
+
+ timer.tic()
+ dt2 = timer.toc()
+ assert dt2 == 0.25
+
+ assert timer.mean == np.array([0.5, 0.25]).mean().round(2)
+ assert timer.max == 0.5
+ assert timer.min == 0.25
+ assert timer.std == np.array([0.5, 0.25]).std().round(2)
+
+ out = capsys.readouterr().out
+ assert "bench" in out
+ assert "Cost:" in out
+
+ timer.clear_record()
+ assert timer.mean is None
+
+
+def test_timer_as_context_manager(monkeypatch):
+ perf = iter([1.0, 1.2])
+ monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf))
+ timer = time_mod.Timer(precision=3)
+ with timer:
+ pass
+
+ assert timer.dt == 0.2
+
+
+def test_timer_context_manager_returns_self(monkeypatch):
+ perf = iter([1.0, 1.1])
+ monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf))
+
+ with time_mod.Timer() as timer:
+ assert isinstance(timer, time_mod.Timer)
+
+
+def test_timer_as_decorator(monkeypatch):
+ perf = iter([5.0, 5.1])
+ monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf))
+
+ timer = time_mod.Timer()
+
+ @timer
+ def add1(x: int) -> int:
+ return x + 1
+
+ assert add1(2) == 3
+ assert timer.mean == 0.1
+
+
+def test_now_supports_timestamp_datetime_time_and_fmt(monkeypatch):
+ monkeypatch.setattr(time_mod.time, "time", lambda: 123.0)
+ assert time_mod.now("timestamp") == 123.0
+
+ assert isinstance(time_mod.now("datetime"), datetime)
+ assert isinstance(time_mod.now("time"), py_time.struct_time)
+
+ # fmt takes precedence and returns a string.
+ monkeypatch.setattr(time_mod.time, "localtime", py_time.gmtime)
+ assert time_mod.now(fmt="%Y-%m-%d") == "1970-01-01"
+
+ with pytest.raises(ValueError, match="Unsupported input"):
+ time_mod.now("invalid")
+
+
+def test_time_and_datetime_converters_validate_types(monkeypatch):
+ fixed = py_time.gmtime(0)
+
+ assert time_mod.time2datetime(fixed).year == 1970
+ assert isinstance(time_mod.timestamp2datetime(0), datetime)
+
+ with pytest.raises(TypeError):
+ time_mod.time2datetime("not-a-time") # type: ignore[arg-type]
+
+ monkeypatch.setattr(time_mod.time, "mktime", lambda _: 42.0)
+ assert time_mod.time2timestamp(fixed) == 42.0
+ with pytest.raises(TypeError):
+ time_mod.time2timestamp("not-a-time") # type: ignore[arg-type]
+
+ assert time_mod.time2str(fixed, "%Y") == "1970"
+ with pytest.raises(TypeError):
+ time_mod.time2str("not-a-time", "%Y") # type: ignore[arg-type]
+
+ dt = datetime(2000, 1, 1, tzinfo=timezone.utc)
+ assert isinstance(time_mod.datetime2time(dt), py_time.struct_time)
+ assert time_mod.datetime2timestamp(dt) == 946684800.0
+ assert time_mod.datetime2str(dt, "%Y") == "2000"
+
+ with pytest.raises(TypeError):
+ time_mod.datetime2time("not-a-dt") # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ time_mod.datetime2timestamp("not-a-dt") # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ time_mod.datetime2str("not-a-dt", "%Y") # type: ignore[arg-type]
+
+
+def test_str_converters_validate_types(monkeypatch):
+ monkeypatch.setattr(time_mod.time, "mktime", lambda _: 99.0)
+ assert time_mod.str2time("1970-01-01", "%Y-%m-%d").tm_year == 1970
+ assert time_mod.str2datetime("1970-01-01", "%Y-%m-%d").year == 1970
+ assert time_mod.str2timestamp("1970-01-01", "%Y-%m-%d") == 99.0
+
+ with pytest.raises(TypeError):
+ time_mod.str2time(123, "%Y") # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ time_mod.str2datetime(123, "%Y") # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ time_mod.str2timestamp(123, "%Y") # type: ignore[arg-type]
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 88039ac..c13783d 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -1,34 +1,20 @@
from capybara import COLORSTR, FORMATSTR, colorstr, make_batch
-def number_generator(start, end):
- for i in range(start, end + 1):
- yield i
+def number_generator(start: int, end: int):
+ yield from range(start, end + 1)
def test_make_batch():
- # Create a number generator from 1 to 10
data_generator = number_generator(1, 10)
-
- # Batch size of 3
batch_size = 3
-
- # Generate batched data
batched_data_generator = make_batch(data_generator, batch_size)
-
- # Check the first batch
batch = next(batched_data_generator)
assert batch == [1, 2, 3]
-
- # Check the second batch
batch = next(batched_data_generator)
assert batch == [4, 5, 6]
-
- # Check the third batch
batch = next(batched_data_generator)
assert batch == [7, 8, 9]
-
- # Check the fourth batch (last batch with remaining data)
batch = next(batched_data_generator)
assert batch == [10]
@@ -36,19 +22,22 @@ def test_make_batch():
def test_colorstr_blue_bold():
obj = "Hello, colorful world!"
expected_output = "\033[1;34mHello, colorful world!\033[0m"
- assert colorstr(obj, color=COLORSTR.BLUE, fmt=FORMATSTR.BOLD) == expected_output
+ assert (
+ colorstr(obj, color=COLORSTR.BLUE, fmt=FORMATSTR.BOLD)
+ == expected_output
+ )
def test_colorstr_red():
obj = "Error: Something went wrong!"
expected_output = "\033[1;31mError: Something went wrong!\033[0m"
- assert colorstr(obj, color='red') == expected_output
+ assert colorstr(obj, color="red") == expected_output
def test_colorstr_underline_green():
obj = "Important message"
expected_output = "\033[4;32mImportant message\033[0m"
- assert colorstr(obj, color=32, fmt='underline') == expected_output
+ assert colorstr(obj, color=32, fmt="underline") == expected_output
def test_colorstr_default():
diff --git a/tests/vision/ipcam/test_camera.py b/tests/vision/ipcam/test_camera.py
new file mode 100644
index 0000000..b6c834e
--- /dev/null
+++ b/tests/vision/ipcam/test_camera.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import itertools
+
+import numpy as np
+import pytest
+
+
+def test_ipcam_capture_is_an_iterator_and_yields_frames(monkeypatch):
+ import capybara.vision.ipcam.camera as cam_mod
+
+ class _FakeCapture:
+ def get(self, prop):
+ # 4: height, 3: width (as used by the implementation)
+ if prop == 4:
+ return 10
+ if prop == 3:
+ return 20
+ return 0
+
+ def read(self):
+ return False, None
+
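+    # Run the capture's reader target synchronously instead of spawning a thread.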
+ class _FakeThread:
+ def __init__(self, target, daemon=True):
+ self._target = target
+
+ def start(self):
+ self._target()
+
+ monkeypatch.setattr(
+ cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+ monkeypatch.setattr(cam_mod, "Thread", _FakeThread)
+
+ cap = cam_mod.IpcamCapture(url="fake", color_base="BGR")
+ assert iter(cap) is cap
+
+ frames = list(itertools.islice(cap, 3))
+ assert len(frames) == 3
+ assert all(isinstance(frame, np.ndarray) for frame in frames)
+
+
+def test_ipcam_capture_rejects_unsupported_image_size(monkeypatch):
+ import capybara.vision.ipcam.camera as cam_mod
+
+ class _FakeCapture:
+ def get(self, prop):
+ if prop == 4:
+ return 0
+ if prop == 3:
+ return 0
+ return 0
+
+ def read(self):
+ return False, None
+
+ class _FakeThread:
+ def __init__(self, target, daemon=True):
+ self._target = target
+
+ def start(self):
+ self._target()
+
+ monkeypatch.setattr(
+ cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+ monkeypatch.setattr(cam_mod, "Thread", _FakeThread)
+
+    with pytest.raises(ValueError):
+ cam_mod.IpcamCapture(url="fake", color_base="BGR")
+
+
+def test_ipcam_capture_converts_color_and_returns_frame_copy(monkeypatch):
+ import capybara.vision.ipcam.camera as cam_mod
+
+ calls: list[str] = []
+
+ def fake_imcvtcolor(frame: np.ndarray, *, cvt_mode: str) -> np.ndarray:
+ calls.append(cvt_mode)
+ return frame + 1
+
+ class _FakeCapture:
+ def __init__(self):
+ self._calls = 0
+
+ def get(self, prop):
+ if prop == 4:
+ return 10
+ if prop == 3:
+ return 20
+ return 0
+
+ def read(self):
+ if self._calls == 0:
+ self._calls += 1
+ return True, np.zeros((10, 20, 3), dtype=np.uint8)
+ return False, None
+
+ class _FakeThread:
+ def __init__(self, target, daemon=True):
+ self._target = target
+
+ def start(self):
+ self._target()
+
+ monkeypatch.setattr(
+ cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+ monkeypatch.setattr(cam_mod, "Thread", _FakeThread)
+ monkeypatch.setattr(cam_mod, "imcvtcolor", fake_imcvtcolor)
+
+ cap = cam_mod.IpcamCapture(url="fake", color_base="RGB")
+ frame1 = cap.get_frame()
+ assert calls == ["BGR2RGB"]
+ assert frame1.sum() > 0
+
+ frame1[0, 0, 0] = 123
+ frame2 = cap.get_frame()
+ assert frame2[0, 0, 0] != 123
diff --git a/tests/vision/test_functional.py b/tests/vision/test_functional.py
index 9077514..c2107c0 100644
--- a/tests/vision/test_functional.py
+++ b/tests/vision/test_functional.py
@@ -1,10 +1,24 @@
+from typing import Any
+
import cv2
import numpy as np
import pytest
-from capybara import (BORDER, Box, Boxes, gaussianblur, imbinarize, imcropbox,
- imcropboxes, imcvtcolor, imresize_and_pad_if_need,
- meanblur, medianblur, pad)
+from capybara import (
+ BORDER,
+ Box,
+ Boxes,
+ gaussianblur,
+ imbinarize,
+ imcropbox,
+ imcropboxes,
+ imcvtcolor,
+ imresize_and_pad_if_need,
+ meanblur,
+ medianblur,
+ pad,
+)
+from capybara.vision.functionals import centercrop, imadjust
def test_meanblur():
@@ -31,8 +45,8 @@ def test_gaussianblur():
# 測試指定ksize和sigmaX
ksize = (7, 7)
- sigmaX = 1
- blurred_img_custom = gaussianblur(img, ksize=ksize, sigmaX=sigmaX)
+ sigma_x = 1
+ blurred_img_custom = gaussianblur(img, ksize=ksize, sigma_x=sigma_x)
assert blurred_img_custom.shape == img.shape
@@ -55,16 +69,24 @@ def test_imcvtcolor():
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
# 測試RGB轉灰階
- gray_img = imcvtcolor(img, 'RGB2GRAY')
+ gray_img = imcvtcolor(img, "RGB2GRAY")
assert gray_img.shape == (100, 100)
+    # The COLOR_-prefixed OpenCV spelling is also accepted
+ gray_img_prefixed = imcvtcolor(img, "COLOR_RGB2GRAY")
+ assert gray_img_prefixed.shape == (100, 100)
+
+    # An OpenCV conversion code can also be passed directly
+ gray_img2 = imcvtcolor(img, cv2.COLOR_RGB2GRAY)
+ assert gray_img2.shape == (100, 100)
+
# 測試RGB轉BGR
- bgr_img = imcvtcolor(img, 'RGB2BGR')
+ bgr_img = imcvtcolor(img, "RGB2BGR")
assert bgr_img.shape == img.shape
# 測試轉換為不支援的色彩空間
with pytest.raises(ValueError):
- imcvtcolor(img, 'RGB2WWW') # XYZ為不支援的色彩空間
+ imcvtcolor(img, "RGB2WWW") # XYZ為不支援的色彩空間
def test_pad_constant_gray():
@@ -76,13 +98,22 @@ def test_pad_constant_gray():
pad_value = 128
padded_img = pad(img, pad_size=pad_size, pad_value=pad_value)
assert padded_img.shape == (
- img.shape[0] + 2 * pad_size, img.shape[1] + 2 * pad_size)
+ img.shape[0] + 2 * pad_size,
+ img.shape[1] + 2 * pad_size,
+ )
assert np.all(padded_img[:pad_size, :] == pad_value)
assert np.all(padded_img[-pad_size:, :] == pad_value)
assert np.all(padded_img[:, :pad_size] == pad_value)
assert np.all(padded_img[:, -pad_size:] == pad_value)
+def test_pad_constant_gray_accepts_singleton_tuple_pad_value():
+ img = np.random.randint(0, 256, size=(10, 10), dtype=np.uint8)
+ padded_img = pad(img, pad_size=1, pad_value=(123,))
+ assert padded_img.shape == (12, 12)
+ assert int(padded_img[0, 0]) == 123
+
+
def test_pad_constant_color():
# 測試用的彩色圖片
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
@@ -92,13 +123,26 @@ def test_pad_constant_color():
pad_value = (255, 0, 0) # 紅色
padded_img = pad(img, pad_size=pad_size, pad_value=pad_value)
assert padded_img.shape == (
- img.shape[0] + 2 * pad_size, img.shape[1] + 2 * pad_size, img.shape[2])
+ img.shape[0] + 2 * pad_size,
+ img.shape[1] + 2 * pad_size,
+ img.shape[2],
+ )
assert np.all(padded_img[:pad_size, :, :] == pad_value)
assert np.all(padded_img[-pad_size:, :, :] == pad_value)
assert np.all(padded_img[:, :pad_size, :] == pad_value)
assert np.all(padded_img[:, -pad_size:, :] == pad_value)
+def test_pad_rejects_invalid_pad_value_type_and_invalid_image_ndim():
+ img = np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8)
+ with pytest.raises(ValueError, match="pad_value must be"):
+ pad(img, pad_size=1, pad_value=[1, 2, 3]) # type: ignore[arg-type]
+
+ bad = np.zeros((1, 2, 3, 4), dtype=np.uint8)
+ with pytest.raises(ValueError, match="img must be a 2D or 3D"):
+ pad(bad, pad_size=1, pad_value=0)
+
+
def test_pad_replicate():
# 測試用的圖片
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
@@ -107,7 +151,10 @@ def test_pad_replicate():
pad_size = (5, 10)
padded_img = pad(img, pad_size=pad_size, pad_mode=BORDER.REPLICATE)
assert padded_img.shape == (
- img.shape[0] + 2 * pad_size[0], img.shape[1] + 2 * pad_size[1], img.shape[2])
+ img.shape[0] + 2 * pad_size[0],
+ img.shape[1] + 2 * pad_size[1],
+ img.shape[2],
+ )
def test_pad_reflect():
@@ -120,7 +167,7 @@ def test_pad_reflect():
assert padded_img.shape == (
img.shape[0] + pad_size[0] + pad_size[1],
img.shape[1] + pad_size[2] + pad_size[3],
- img.shape[2]
+ img.shape[2],
)
@@ -128,12 +175,13 @@ def test_pad_invalid_input():
# 測試不支援的填充模式
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
with pytest.raises(ValueError):
- pad(img, pad_size=5, pad_mode='invalid_mode')
+ pad(img, pad_size=5, pad_mode="invalid_mode")
# 測試不合法的填充大小
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
with pytest.raises(ValueError):
- pad(img, pad_size=(10, 20, 30))
+ pad_size: Any = (10, 20, 30)
+ pad(img, pad_size=pad_size)
def test_imcropbox_custom_box():
@@ -170,7 +218,8 @@ def test_imcropbox_invalid_input():
# 測試不支援的裁剪區域
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
with pytest.raises(TypeError):
- imcropbox(img, "invalid_box")
+ invalid_box: Any = "invalid_box"
+ imcropbox(img, invalid_box)
# 測試不合法的裁剪區域
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
@@ -240,7 +289,8 @@ def test_imcropboxes_invalid_input():
# 測試不支援的裁剪區域
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
with pytest.raises(TypeError):
- imcropboxes(img, "invalid_boxes")
+ invalid_boxes: Any = "invalid_boxes"
+ imcropboxes(img, invalid_boxes)
# 測試不合法的裁剪區域
img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
@@ -286,18 +336,84 @@ def test_imbinarize_invalid_input():
def test_imresize_and_pad_if_need():
- img = np.ones((150, 120, 3), dtype='uint8')
+ img = np.ones((150, 120, 3), dtype="uint8")
processed = imresize_and_pad_if_need(img, 150, 150)
- np.testing.assert_allclose(processed[:, 120:], np.zeros((150, 30, 3), dtype='uint8'))
-
- img = np.ones((151, 119, 3), dtype='uint8')
+ np.testing.assert_allclose(
+ processed[:, 120:], np.zeros((150, 30, 3), dtype="uint8")
+ )
+
+ img = np.ones((151, 119, 3), dtype="uint8")
processed = imresize_and_pad_if_need(img, 150, 150)
- np.testing.assert_allclose(processed[:, 120:], np.zeros((150, 30, 3), dtype='uint8'))
+ np.testing.assert_allclose(
+ processed[:, 120:], np.zeros((150, 30, 3), dtype="uint8")
+ )
- img = np.ones((200, 100, 3), dtype='uint8')
+ img = np.ones((200, 100, 3), dtype="uint8")
processed = imresize_and_pad_if_need(img, 100, 100)
- np.testing.assert_allclose(processed[:, 50:], np.zeros((100, 50, 3), dtype='uint8'))
+ np.testing.assert_allclose(
+ processed[:, 50:], np.zeros((100, 50, 3), dtype="uint8")
+ )
- img = np.ones((20, 20, 3), dtype='uint8')
+ img = np.ones((20, 20, 3), dtype="uint8")
processed = imresize_and_pad_if_need(img, 100, 100)
- np.testing.assert_allclose(processed, np.ones((100, 100, 3), dtype='uint8'))
+ np.testing.assert_allclose(processed, np.ones((100, 100, 3), dtype="uint8"))
+
+
+def test_blur_accepts_numpy_scalar_ksize_and_rejects_invalid_ksize():
+ img = np.random.randint(0, 256, size=(20, 20, 3), dtype=np.uint8)
+ out = meanblur(img, ksize=np.array(3))
+ assert out.shape == img.shape
+ out2 = gaussianblur(img, ksize=np.array(5))
+ assert out2.shape == img.shape
+
+ with pytest.raises(TypeError, match="ksize"):
+ meanblur(img, ksize=(1, 2, 3)) # type: ignore[arg-type]
+
+
+def test_pad_accepts_none_pad_value_and_validates_pad_value_types():
+ img = np.zeros((3, 3, 3), dtype=np.uint8)
+ padded = pad(img, pad_size=1, pad_value=None)
+ assert padded.shape == (5, 5, 3)
+ assert np.all(padded[0] == 0)
+
+ with pytest.raises(ValueError, match="pad_value"):
+ invalid_pad_value: Any = (1, 2)
+ pad(img, pad_size=1, pad_value=invalid_pad_value)
+
+ gray = np.zeros((3, 3), dtype=np.uint8)
+ with pytest.raises(ValueError, match="must be an int"):
+ pad(gray, pad_size=1, pad_value=(1, 2, 3))
+
+
+def test_centercrop_produces_square_crop():
+ img = np.zeros((10, 20, 3), dtype=np.uint8)
+ cropped = centercrop(img)
+ assert cropped.shape == (10, 10, 3)
+
+
+def test_imadjust_returns_input_when_bounds_degenerate():
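+    # An all-zero image has no intensity range to stretch, so imadjust should be a no-op.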
+ img = np.zeros((20, 20), dtype=np.uint8)
+ out = imadjust(img)
+ assert np.array_equal(out, img)
+
+
+def test_imadjust_stretches_grayscale_and_color_images():
+ gray = np.tile(np.arange(256, dtype=np.uint8), (4, 1))
+ out = imadjust(gray)
+ assert out.shape == gray.shape
+ assert out.dtype == np.uint8
+ assert out.min() == 0
+ assert out.max() == 255
+ assert not np.array_equal(out, gray)
+
+ bgr = np.stack([gray] * 3, axis=-1)
+ out2 = imadjust(bgr)
+ assert out2.shape == bgr.shape
+ assert out2.dtype == np.uint8
+
+
+def test_imresize_and_pad_if_need_can_return_scale():
+ img = np.ones((200, 100, 3), dtype=np.uint8)
+ out, scale = imresize_and_pad_if_need(img, 100, 100, return_scale=True)
+ assert out.shape == (100, 100, 3)
+ assert scale == pytest.approx(0.5)
diff --git a/tests/vision/test_geometric.py b/tests/vision/test_geometric.py
index 9131110..e8fad30 100644
--- a/tests/vision/test_geometric.py
+++ b/tests/vision/test_geometric.py
@@ -1,9 +1,20 @@
+from typing import Any
+
import numpy as np
import pytest
-from capybara import (BORDER, INTER, ROTATE, Polygon, Polygons, imresize,
- imrotate, imrotate90, imwarp_quadrangle,
- imwarp_quadrangles)
+from capybara import (
+ BORDER,
+ INTER,
+ ROTATE,
+ Polygon,
+ Polygons,
+ imresize,
+ imrotate,
+ imrotate90,
+ imwarp_quadrangle,
+ imwarp_quadrangles,
+)
@pytest.fixture
@@ -15,14 +26,10 @@ def random_img():
@pytest.fixture
def small_gray_img():
"""
- 建立 3x3 的灰階小影像(數值小,方便檢查旋轉結果)。
- 為了測試方便,這裡只有一個通道。
+    Create a small 3x3 grayscale image (small values make rotation results easy to verify).
+    For convenience, the image has only a single channel.
"""
- return np.array([
- [1, 2, 3],
- [4, 5, 6],
- [7, 8, 9]
- ], dtype=np.uint8)
+ return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8)
def test_imresize_both_dims(random_img):
@@ -32,7 +39,7 @@ def test_imresize_both_dims(random_img):
def test_imresize_single_dim(random_img):
- """只給定單一維度時,檢查是否能等比例縮放。"""
+ """只給定單一維度時, 檢查是否能等比例縮放。"""
orig_h, orig_w = random_img.shape[:2]
# 測試只給定寬度
@@ -48,6 +55,22 @@ def test_imresize_single_dim(random_img):
assert resized_img.shape[1] == expected_w
+def test_imresize_rejects_missing_dimensions(random_img):
+ with pytest.raises(ValueError, match="at least one dimension"):
+ imresize(random_img, (None, None), INTER.BILINEAR)
+
+
+def test_imresize_return_scale_when_only_one_dim_provided(random_img):
+ _orig_h, orig_w = random_img.shape[:2]
+ resized_img, w_scale, h_scale = imresize(
+ random_img, (None, 50), INTER.BILINEAR, return_scale=True
+ )
+
+ assert resized_img.shape[1] == 50
+ assert w_scale == pytest.approx(50 / orig_w)
+ assert h_scale == pytest.approx(50 / orig_w)
+
+
def test_imresize_return_scale(random_img):
"""測試回傳縮放比例。"""
orig_h, orig_w = random_img.shape[:2]
@@ -78,29 +101,17 @@ def test_imresize_different_interpolation(random_img):
[
(
ROTATE.ROTATE_90,
- np.array([
- [7, 4, 1],
- [8, 5, 2],
- [9, 6, 3]
- ], dtype=np.uint8)
+ np.array([[7, 4, 1], [8, 5, 2], [9, 6, 3]], dtype=np.uint8),
),
(
ROTATE.ROTATE_180,
- np.array([
- [9, 8, 7],
- [6, 5, 4],
- [3, 2, 1]
- ], dtype=np.uint8)
+ np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]], dtype=np.uint8),
),
(
ROTATE.ROTATE_270,
- np.array([
- [3, 6, 9],
- [2, 5, 8],
- [1, 4, 7]
- ], dtype=np.uint8)
- )
- ]
+ np.array([[3, 6, 9], [2, 5, 8], [1, 4, 7]], dtype=np.uint8),
+ ),
+ ],
)
def test_imrotate90(small_gray_img, rotate_code, expected):
"""測試 90 度基底旋轉。"""
@@ -110,14 +121,14 @@ def test_imrotate90(small_gray_img, rotate_code, expected):
@pytest.mark.parametrize("angle, expand", [(90, False), (45, False)])
def test_imrotate_no_expand(random_img, angle, expand):
- """測試不擴展的旋轉,輸出大小應與原圖相同。"""
+ """測試不擴展的旋轉, 輸出大小應與原圖相同。"""
rotated_img = imrotate(random_img, angle=angle, expand=expand)
assert rotated_img.shape == random_img.shape
@pytest.mark.parametrize("angle, expand", [(90, True), (45, True)])
def test_imrotate_expand(random_img, angle, expand):
- """測試擴展旋轉,輸出大小應大於或等於原圖。"""
+ """測試擴展旋轉, 輸出大小應大於或等於原圖。"""
h, w = random_img.shape[:2]
rotated_img = imrotate(random_img, angle=angle, expand=expand)
assert rotated_img.shape[0] >= h
@@ -125,7 +136,7 @@ def test_imrotate_expand(random_img, angle, expand):
def test_imrotate_with_center(random_img):
- """指定旋轉中心,檢查旋轉結果維度是否符合預期。"""
+ """指定旋轉中心, 檢查旋轉結果維度是否符合預期。"""
h, w = random_img.shape[:2]
center = (w // 4, h // 4)
angle = 30
@@ -140,8 +151,13 @@ def test_imrotate_scale_border(random_img):
測試帶有 scale 與 bordervalue 的旋轉。
例如提供 (255, 0, 0) 以便檢查邊界是否填上紅色。
"""
- scaled_img = imrotate(random_img, angle=45, scale=1.5,
- bordertype=BORDER.REFLECT, bordervalue=(255, 0, 0))
+ scaled_img = imrotate(
+ random_img,
+ angle=45,
+ scale=1.5,
+ bordertype=BORDER.REFLECT,
+ bordervalue=(255, 0, 0),
+ )
# 只要確定沒有拋出錯誤並且輸出維度放大即可
assert scaled_img.shape[0] > random_img.shape[0]
assert scaled_img.shape[1] > random_img.shape[1]
@@ -156,6 +172,31 @@ def test_imrotate_invalid_input(random_img):
imrotate(random_img, angle=90, interpolation="invalid_interpolation")
+def test_imrotate_validates_image_ndim_and_border_value_tuple_sizes(
+ random_img, small_gray_img
+):
+ rotated_gray = imrotate(
+ small_gray_img,
+ angle=15,
+ expand=False,
+ bordertype=BORDER.CONSTANT,
+ bordervalue=(7,),
+ )
+ assert rotated_gray.shape == small_gray_img.shape
+
+ with pytest.raises(ValueError, match="2D or 3D"):
+ imrotate(np.zeros((1, 2, 3, 4), dtype=np.uint8), angle=0)
+
+ with pytest.raises(ValueError, match="bordervalue"):
+ imrotate(
+ random_img,
+ angle=10,
+ expand=False,
+ bordertype=BORDER.CONSTANT,
+ bordervalue=(1, 2),
+ )
+
+
@pytest.fixture
def default_polygon():
"""產生含有 4 個點的基本 Polygon。"""
@@ -164,7 +205,7 @@ def default_polygon():
def test_imwarp_quadrangle_default(random_img, default_polygon):
- """不指定 dst_size 時,自動使用 min_area_rectangle 所得長寬。"""
+ """不指定 dst_size 時, 自動使用 min_area_rectangle 所得長寬。"""
warped = imwarp_quadrangle(random_img, default_polygon)
# 大致上應該會有 80x80 以上的圖
assert warped.shape[0] >= 80
@@ -172,18 +213,19 @@ def test_imwarp_quadrangle_default(random_img, default_polygon):
def test_imwarp_quadrangle_with_dstsize(random_img, default_polygon):
- """指定 dst_size,檢查變換後的輸出是否符合設定大小。"""
+ """指定 dst_size, 檢查變換後的輸出是否符合設定大小。"""
warped = imwarp_quadrangle(random_img, default_polygon, dst_size=(100, 50))
assert warped.shape == (50, 100, 3)
def test_imwarp_quadrangle_no_order_points(random_img, default_polygon):
"""
- 當 do_order_points = False 時,檢查程式能否正常執行。
- 有些情況下,使用者已確保點位順序正確,就可以省略排序。
+    With do_order_points=False the function should still run correctly.
+    In some cases the caller has already ensured the point order is correct, so sorting can be skipped.
"""
warped = imwarp_quadrangle(
- random_img, default_polygon, do_order_points=False)
+ random_img, default_polygon, do_order_points=False
+ )
assert warped.shape[0] >= 80
assert warped.shape[1] >= 80
@@ -192,7 +234,8 @@ def test_imwarp_quadrangle_invalid_polygon(random_img):
"""檢查多種不合法 polygon 之行為。"""
# 傳入不支援的 polygon 類型
with pytest.raises(TypeError):
- imwarp_quadrangle(random_img, "invalid_polygon")
+ invalid_polygon: Any = "invalid_polygon"
+ imwarp_quadrangle(random_img, invalid_polygon)
# 傳入長度不是 4 的 polygon
with pytest.raises(ValueError):
@@ -200,17 +243,23 @@ def test_imwarp_quadrangle_invalid_polygon(random_img):
imwarp_quadrangle(random_img, bad_pts)
+def test_imwarp_quadrangle_swaps_width_height_when_needed(random_img):
+ pts = np.array([[10, 10], [30, 10], [30, 90], [10, 90]], dtype=np.float32)
+ polygon = Polygon(pts)
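+    # The source quad is tall (20 wide x 80 high), so the warp is expected to come out landscape (height < width).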
+ warped = imwarp_quadrangle(random_img, polygon)
+ assert warped.shape[0] < warped.shape[1]
+
+
@pytest.fixture
def polygons_list():
- """產生 2 個四邊形,合併成 Polygons。"""
+ """產生 2 個四邊形, 合併成 Polygons。"""
src_pts_1 = np.array(
- [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32)
+ [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32
+ )
src_pts_2 = np.array(
- [[20, 20], [80, 20], [80, 80], [20, 80]], dtype=np.float32)
- return Polygons([
- Polygon(src_pts_1),
- Polygon(src_pts_2)
- ])
+ [[20, 20], [80, 20], [80, 80], [20, 80]], dtype=np.float32
+ )
+ return Polygons([Polygon(src_pts_1), Polygon(src_pts_2)])
def test_imwarp_quadrangles(random_img, polygons_list):
@@ -227,12 +276,16 @@ def test_imwarp_quadrangles(random_img, polygons_list):
def test_imwarp_quadrangles_invalid_type(random_img):
"""檢查不合法的 polygons 輸入。"""
with pytest.raises(TypeError):
- imwarp_quadrangles(random_img, "invalid_polygons")
+ invalid_polygons: Any = "invalid_polygons"
+ imwarp_quadrangles(random_img, invalid_polygons)
with pytest.raises(TypeError):
- invalid_polygons = [
+ invalid_polygons: Any = [
Polygon(
- np.array([[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32)),
- "invalid_polygon"
+ np.array(
+ [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32
+ )
+ ),
+ "invalid_polygon",
]
imwarp_quadrangles(random_img, invalid_polygons)
diff --git a/tests/vision/test_improc.py b/tests/vision/test_improc.py
index 4ed3581..ca98eef 100644
--- a/tests/vision/test_improc.py
+++ b/tests/vision/test_improc.py
@@ -22,6 +22,11 @@ def test_imread():
assert isinstance(img_gray, np.ndarray)
assert len(img_gray.shape) == 2 # 灰階圖片的channel數為1
+ # color_base should be case-insensitive
+ img_gray2 = imread(image_path, color_base="gray")
+ assert isinstance(img_gray2, np.ndarray)
+ assert len(img_gray2.shape) == 2
+
# 測試heif格式的圖片讀取
img_heif = imread(DIR.parent / "resources" / "lena.heic", color_base="BGR")
assert isinstance(img_heif, np.ndarray)
@@ -32,14 +37,24 @@ def test_imread():
imread("non_existent_image.jpg")
-def test_imwrite():
+def test_imwrite(tmp_path):
# 測試用的圖片
img = np.zeros((100, 100, 3), dtype=np.uint8) # 建立一個全黑的BGR圖片
# 測試BGR格式的圖片寫入
- temp_file_path = DIR / "temp_image.jpg"
+ temp_file_path = tmp_path / "temp_image.jpg"
assert imwrite(img, path=temp_file_path, color_base="BGR")
assert Path(temp_file_path).exists()
- # 測試不指定路徑時的圖片寫入
- assert imwrite(img, color_base="BGR") # 將會寫入一個暫時的檔案
+    # Writing without an explicit path should not pollute the repo working directory
+ assert imwrite(img, color_base="BGR")
+
+
+def test_imwrite_without_path_does_not_create_tmp_file_in_cwd(
+ tmp_path, monkeypatch
+):
+ """Regression: historical implementation wrote `tmp{suffix}` into CWD."""
+ monkeypatch.chdir(tmp_path)
+ img = np.zeros((10, 10, 3), dtype=np.uint8)
+ assert imwrite(img, color_base="BGR", suffix=".jpg") is True
+ assert not (tmp_path / "tmp.jpg").exists()
diff --git a/tests/vision/test_improc_extra.py b/tests/vision/test_improc_extra.py
new file mode 100644
index 0000000..d0346d8
--- /dev/null
+++ b/tests/vision/test_improc_extra.py
@@ -0,0 +1,294 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pytest
+
+import capybara.vision.improc as improc
+from capybara import IMGTYP, ROTATE
+
+
+def test_is_numpy_img_accepts_2d_and_3d_and_rejects_other_shapes():
+ assert improc.is_numpy_img(np.zeros((10, 10), dtype=np.uint8))
+ assert improc.is_numpy_img(np.zeros((10, 10, 1), dtype=np.uint8))
+ assert improc.is_numpy_img(np.zeros((10, 10, 3), dtype=np.uint8))
+ assert not improc.is_numpy_img(np.zeros((10, 10, 4), dtype=np.uint8))
+ assert not improc.is_numpy_img("not-an-array") # type: ignore[arg-type]
+
+
+def test_get_orientation_code_maps_known_values(monkeypatch):
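+    # EXIF orientation 6 maps to a 90-degree rotation, 3 to 180, 8 to 270, and 1 (normal) to no rotation.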
+ def fake_load(_: Any):
+ return {"0th": {improc.piexif.ImageIFD.Orientation: 6}}
+
+ monkeypatch.setattr(improc.piexif, "load", fake_load)
+ assert improc.get_orientation_code(b"") == ROTATE.ROTATE_90
+
+ def fake_load_3(_: Any):
+ return {"0th": {improc.piexif.ImageIFD.Orientation: 3}}
+
+ monkeypatch.setattr(improc.piexif, "load", fake_load_3)
+ assert improc.get_orientation_code(b"") == ROTATE.ROTATE_180
+
+ def fake_load_8(_: Any):
+ return {"0th": {improc.piexif.ImageIFD.Orientation: 8}}
+
+ monkeypatch.setattr(improc.piexif, "load", fake_load_8)
+ assert improc.get_orientation_code(b"") == ROTATE.ROTATE_270
+
+ def fake_load_other(_: Any):
+ return {"0th": {improc.piexif.ImageIFD.Orientation: 1}}
+
+ monkeypatch.setattr(improc.piexif, "load", fake_load_other)
+ assert improc.get_orientation_code(b"") is None
+
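+    # The generator-throw lambda below is just a way to raise from inside a lambda, simulating piexif.load failing.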
+ monkeypatch.setattr(
+ improc.piexif,
+ "load",
+ lambda _: (_ for _ in ()).throw(Exception("boom")),
+ )
+ assert improc.get_orientation_code(b"") is None
+
+
+def test_jpgencode_handles_tuple_return_and_failures(monkeypatch):
+ class FakeJPEG:
+ def __init__(self, *, raise_encode: bool = False) -> None:
+ self.raise_encode = raise_encode
+
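+        # Mimics an encoder whose encode() returns a (payload, extra) tuple; jpgencode should keep only the payload.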
+ def encode(self, img: np.ndarray, quality: int = 90):
+ if self.raise_encode:
+ raise RuntimeError("boom")
+ return (b"xx", b"ignored")
+
+ monkeypatch.setattr(improc, "jpeg", FakeJPEG())
+ img = np.zeros((8, 8, 3), dtype=np.uint8)
+ assert improc.jpgencode(img) == b"xx"
+
+ monkeypatch.setattr(improc, "jpeg", FakeJPEG(raise_encode=True))
+ assert improc.jpgencode(img) is None
+
+ assert improc.jpgencode(np.zeros((8, 8, 4), dtype=np.uint8)) is None
+
+
+def test_jpgdecode_rotates_based_on_orientation(monkeypatch):
+ img = np.zeros((5, 5, 3), dtype=np.uint8)
+
+ class FakeJPEG:
+ def decode(self, _: bytes) -> np.ndarray:
+ return img
+
+ monkeypatch.setattr(improc, "jpeg", FakeJPEG())
+ monkeypatch.setattr(
+ improc, "get_orientation_code", lambda _: ROTATE.ROTATE_90
+ )
+ monkeypatch.setattr(improc, "imrotate90", lambda arr, code: arr + 1)
+ out = improc.jpgdecode(b"bytes")
+ assert isinstance(out, np.ndarray)
+ assert out.sum() > 0
+
+ class BadJPEG:
+ def decode(self, _: bytes) -> np.ndarray:
+ raise RuntimeError("bad")
+
+ monkeypatch.setattr(improc, "jpeg", BadJPEG())
+ assert improc.jpgdecode(b"bytes") is None
+
+
+def test_pngencode_pngdecode_roundtrip_and_invalid_inputs():
+ img = np.zeros((10, 10, 3), dtype=np.uint8)
+ img[2:4, 2:4] = 255
+
+ enc = improc.pngencode(img)
+ assert isinstance(enc, bytes)
+ dec = improc.pngdecode(enc)
+ assert isinstance(dec, np.ndarray)
+ assert dec.shape == img.shape
+
+ assert improc.pngencode(np.zeros((10, 10, 4), dtype=np.uint8)) is None
+ assert improc.pngdecode(b"not-a-real-image") is None
+
+
+def test_imencode_selects_format_and_validates_kwargs(monkeypatch):
+ monkeypatch.setattr(improc, "jpgencode", lambda _: b"jpg")
+ monkeypatch.setattr(improc, "pngencode", lambda _: b"png")
+
+ dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+ assert improc.imencode(dummy, IMGTYP.JPEG) == b"jpg"
+ assert improc.imencode(dummy, IMGTYP.PNG) == b"png"
+ assert improc.imencode(dummy, IMGTYP="PNG") == b"png"
+
+ with pytest.raises(TypeError, match="both provided"):
+ improc.imencode(dummy, "png", IMGTYP="jpeg")
+
+ with pytest.raises(TypeError, match="Unexpected keyword"):
+ improc.imencode(dummy, unexpected=1) # type: ignore[arg-type]
+
+
+def test_imdecode_falls_back_to_png_when_jpeg_decode_fails(monkeypatch):
+ monkeypatch.setattr(improc, "jpgdecode", lambda _: None)
+ monkeypatch.setattr(
+ improc, "pngdecode", lambda _: np.zeros((1, 1, 3), dtype=np.uint8)
+ )
+ out = improc.imdecode(b"blob")
+ assert isinstance(out, np.ndarray)
+
+ monkeypatch.setattr(
+ improc, "jpgdecode", lambda _: (_ for _ in ()).throw(RuntimeError("x"))
+ )
+ assert improc.imdecode(b"blob") is None
+
+
+def test_img_to_b64_and_back(monkeypatch):
+ dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+ monkeypatch.setattr(improc, "jpgencode", lambda _: b"\x00\x01")
+ b64 = improc.img_to_b64(dummy, IMGTYP.JPEG)
+ assert isinstance(b64, bytes)
+
+ b64str = improc.img_to_b64str(dummy, IMGTYP.JPEG)
+ assert isinstance(b64str, str)
+
+ monkeypatch.setattr(
+ improc, "imdecode", lambda _: np.ones((1, 1, 3), dtype=np.uint8)
+ )
+ out = improc.b64_to_img(b64)
+ assert isinstance(out, np.ndarray)
+
+
+def test_img_to_b64_accepts_imgtyp_via_imgtyp_kwarg(monkeypatch):
+ dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+ monkeypatch.setattr(improc, "pngencode", lambda _: b"\x00\x01")
+ b64 = improc.img_to_b64(dummy, IMGTYP="PNG")
+ assert isinstance(b64, bytes)
+
+
+def test_b64str_to_img_validates_and_warns(monkeypatch):
+ monkeypatch.setattr(
+ improc, "b64_to_img", lambda _: np.zeros((1, 1, 3), dtype=np.uint8)
+ )
+ assert improc.b64str_to_img("AA==") is not None
+
+ with pytest.warns(UserWarning, match="b64str is None"):
+ assert improc.b64str_to_img(None) is None
+
+ with pytest.raises(ValueError, match="not a string"):
+ improc.b64str_to_img(123) # type: ignore[arg-type]
+
+
+def test_npy_b64_roundtrip_and_npyread(tmp_path: Path):
+ arr = np.array([1.0, 2.0], dtype=np.float32)
+ b64 = improc.npy_to_b64(arr)
+ out = improc.b64_to_npy(b64)
+ np.testing.assert_allclose(out, arr)
+
+ b64s = improc.npy_to_b64str(arr)
+ out2 = improc.b64str_to_npy(b64s)
+ np.testing.assert_allclose(out2, arr)
+
+ file = tmp_path / "x.npy"
+ np.save(file, arr)
+ loaded = improc.npyread(file)
+ assert loaded is not None
+ np.testing.assert_allclose(loaded, arr)
+
+ assert improc.npyread(tmp_path / "missing.npy") is None
+
+
+def test_pdf2imgs_handles_bytes_and_paths_and_failures(
+ monkeypatch, tmp_path: Path
+):
+ from PIL import Image
+
+ pil = Image.fromarray(np.zeros((2, 3, 3), dtype=np.uint8))
+
+ monkeypatch.setattr(improc, "convert_from_bytes", lambda _: [pil])
+ monkeypatch.setattr(improc, "convert_from_path", lambda _: [pil])
+ monkeypatch.setattr(improc, "imcvtcolor", lambda arr, cvt_mode: arr)
+
+ out = improc.pdf2imgs(b"%PDF-1.0")
+ assert out is not None
+ assert len(out) == 1
+ assert isinstance(out[0], np.ndarray)
+
+ pdf_path = tmp_path / "a.pdf"
+ pdf_path.write_bytes(b"%PDF-1.0")
+ out2 = improc.pdf2imgs(pdf_path)
+ assert out2 is not None
+ assert len(out2) == 1
+
+ monkeypatch.setattr(
+ improc,
+ "convert_from_path",
+ lambda _: (_ for _ in ()).throw(RuntimeError("boom")),
+ )
+ assert improc.pdf2imgs(pdf_path) is None
+
+
+def test_pngdecode_returns_none_when_cv2_imdecode_raises(monkeypatch):
+ monkeypatch.setattr(
+ improc.cv2,
+ "imdecode",
+ lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("x")),
+ )
+ assert improc.pngdecode(b"bytes") is None
+
+
+def test_img_to_b64_validates_kwargs_and_handles_encode_none_and_exceptions(
+ monkeypatch,
+):
+ dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+ with pytest.raises(TypeError, match="both provided"):
+ improc.img_to_b64(dummy, "png", IMGTYP="jpeg")
+
+ with pytest.raises(TypeError, match="Unexpected keyword"):
+ improc.img_to_b64(dummy, unexpected=1) # type: ignore[arg-type]
+
+ monkeypatch.setattr(improc, "jpgencode", lambda *_args, **_kwargs: None)
+ assert improc.img_to_b64(dummy, IMGTYP.JPEG) is None
+
+ monkeypatch.setattr(improc, "jpgencode", lambda *_args, **_kwargs: b"x")
+ monkeypatch.setattr(
+ improc.pybase64,
+ "b64encode",
+ lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")),
+ )
+ assert improc.img_to_b64(dummy, IMGTYP.JPEG) is None
+
+
+def test_b64_to_img_returns_none_when_b64decode_raises(monkeypatch):
+ monkeypatch.setattr(
+ improc.pybase64,
+ "b64decode",
+ lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")),
+ )
+ assert improc.b64_to_img(b"@@@") is None
+
+
+def test_imread_warns_when_image_is_none(monkeypatch, tmp_path: Path):
+ path = tmp_path / "x.jpg"
+ path.write_bytes(b"not-a-real-jpg")
+
+ monkeypatch.setattr(improc, "jpgread", lambda *_args, **_kwargs: None)
+ monkeypatch.setattr(improc.cv2, "imread", lambda *_args, **_kwargs: None)
+
+ with pytest.warns(UserWarning, match="None type image"):
+ assert improc.imread(path, verbose=True) is None
+
+
+def test_imwrite_converts_color_base(monkeypatch, tmp_path: Path):
+ calls: list[str] = []
+
+ def fake_imcvtcolor(img: np.ndarray, *, cvt_mode: str) -> np.ndarray:
+ calls.append(cvt_mode)
+ return img
+
+ monkeypatch.setattr(improc, "imcvtcolor", fake_imcvtcolor)
+ monkeypatch.setattr(improc.cv2, "imwrite", lambda *_args, **_kwargs: True)
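+    # cv2.imwrite is stubbed to report success so no file is written; only the conversion mode is checked.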
+
+ img = np.zeros((2, 2, 3), dtype=np.uint8)
+ out_path = tmp_path / "out.jpg"
+ assert improc.imwrite(img, path=out_path, color_base="RGB")
+ assert calls == ["RGB2BGR"]
diff --git a/tests/vision/test_morphology.py b/tests/vision/test_morphology.py
new file mode 100644
index 0000000..9f17452
--- /dev/null
+++ b/tests/vision/test_morphology.py
@@ -0,0 +1,56 @@
+import numpy as np
+import pytest
+
+from capybara.enums import MORPH
+from capybara.vision import morphology as morph
+
+
+@pytest.mark.parametrize(
+ "fn",
+ [
+ morph.imerode,
+ morph.imdilate,
+ morph.imopen,
+ morph.imclose,
+ morph.imgradient,
+ morph.imtophat,
+ morph.imblackhat,
+ ],
+)
+def test_morphology_ops_preserve_shape_and_dtype(fn):
+ img = np.zeros((20, 20), dtype=np.uint8)
+ img[8:12, 8:12] = 255
+
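+    # Exercise each op with an int ksize + MORPH enum, a tuple ksize + string name, and the raw enum value.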
+ out = fn(img, ksize=3, kstruct=MORPH.RECT)
+ assert out.shape == img.shape
+ assert out.dtype == img.dtype
+
+ out2 = fn(img, ksize=(5, 3), kstruct="CROSS")
+ assert out2.shape == img.shape
+
+ out3 = fn(img, ksize=3, kstruct=MORPH.RECT.value)
+ assert out3.shape == img.shape
+
+
+@pytest.mark.parametrize(
+ "fn",
+ [
+ morph.imerode,
+ morph.imdilate,
+ morph.imopen,
+ morph.imclose,
+ morph.imgradient,
+ morph.imtophat,
+ morph.imblackhat,
+ ],
+)
+def test_morphology_invalid_ksize_raises(fn):
+ img = np.zeros((5, 5), dtype=np.uint8)
+ with pytest.raises(TypeError):
+ fn(img, ksize=(1, 2, 3)) # type: ignore[arg-type]
+
+
+def test_morphology_invalid_kstruct_raises():
+ img = np.zeros((5, 5), dtype=np.uint8)
+ with pytest.raises(ValueError):
+ morph.imerode(img, kstruct="cross")
diff --git a/tests/vision/videotools/test_video2frames.py b/tests/vision/videotools/test_video2frames.py
index 6e84414..6ae3651 100644
--- a/tests/vision/videotools/test_video2frames.py
+++ b/tests/vision/videotools/test_video2frames.py
@@ -21,6 +21,18 @@ def test_video2frames_with_fps():
assert len(frames) == 18
+def test_video2frames_with_fps_greater_than_video_fps():
+    # When the requested FPS exceeds the video's FPS, fall back to extracting all frames.
+ frames_all = video2frames(video_path)
+ frames = video2frames(video_path, frame_per_sec=60)
+ assert len(frames) == len(frames_all)
+
+
+def test_video2frames_rejects_non_positive_fps():
+ with pytest.raises(ValueError, match="frame_per_sec must be > 0"):
+ video2frames(video_path, frame_per_sec=0)
+
+
def test_video2frames_invalid_input():
# 測試不支援的影片類型
with pytest.raises(TypeError):
@@ -29,3 +41,49 @@ def test_video2frames_invalid_input():
# 測試不存在的影片路徑
with pytest.raises(TypeError):
video2frames("non_existent_video.mp4")
+
+
+def test_video2frames_returns_empty_list_when_capture_cannot_open(tmp_path):
+ # A file with a valid suffix but invalid contents should fail to open.
+ bad = tmp_path / "bad.mp4"
+ bad.write_bytes(b"")
+ assert video2frames(bad) == []
+
+
+def test_video2frames_falls_back_to_all_frames_when_fps_is_invalid(
+ monkeypatch, tmp_path
+):
+ import importlib
+
+ v_mod = importlib.import_module("capybara.vision.videotools.video2frames")
+
+ dummy = tmp_path / "x.mp4"
+ dummy.write_bytes(b"0")
+
+ class _FakeCapture:
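+        # Minimal cv2.VideoCapture stand-in: reports an FPS of 0 and yields three dummy frames.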
+ def __init__(self) -> None:
+ self._idx = 0
+
+ def isOpened(self) -> bool: # noqa: N802
+ return True
+
+ def get(self, prop):
+ if prop == v_mod.cv2.CAP_PROP_FPS:
+ return 0
+ return 0
+
+ def read(self):
+ if self._idx >= 3:
+ return False, None
+ self._idx += 1
+ return True, np.zeros((2, 2, 3), dtype=np.uint8)
+
+ def release(self) -> None:
+ return None
+
+ monkeypatch.setattr(
+ v_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+
+ frames = v_mod.video2frames(dummy, frame_per_sec=2)
+ assert len(frames) == 3
diff --git a/tests/vision/videotools/test_video2frames_v2.py b/tests/vision/videotools/test_video2frames_v2.py
index cbbfc58..bc91951 100644
--- a/tests/vision/videotools/test_video2frames_v2.py
+++ b/tests/vision/videotools/test_video2frames_v2.py
@@ -17,7 +17,9 @@ def test_video2frames_v2():
def test_video2frames_v2_with_fps():
# 測試指定提取幀的速度
- frames = video2frames_v2(video_path, frame_per_sec=2, start_sec=0, end_sec=2, n_threads=2)
+ frames = video2frames_v2(
+ video_path, frame_per_sec=2, start_sec=0, end_sec=2, n_threads=2
+ )
assert len(frames) == 4
@@ -29,3 +31,262 @@ def test_video2frames_v2_invalid_input():
# 測試不存在的影片路徑
with pytest.raises(TypeError):
video2frames_v2("non_existent_video.mp4")
+
+
+def test_video2frames_v2_helpers_and_internal_branches(monkeypatch, tmp_path):
+ import importlib
+
+ v2_mod = importlib.import_module(
+ "capybara.vision.videotools.video2frames_v2"
+ )
+
+ assert v2_mod.is_numpy_img(np.zeros((10, 10), dtype=np.uint8))
+ assert v2_mod.flatten_list([[1], [2, 3]]) == [1, 2, 3]
+ assert v2_mod.flatten_list([[[1], [2]]]) == [1, 2]
+
+ with pytest.raises(ValueError, match="larger than"):
+ v2_mod.get_step_inds(0, 1, 5)
+
+ with pytest.raises(TypeError, match="inappropriate"):
+ v2_mod._extract_frames([0, 1], "missing.mp4")
+
+ dummy_video = tmp_path / "x.mp4"
+ dummy_video.write_bytes(b"")
+
+ resized_calls: list[dict[str, object]] = []
+ cvt_calls: list[str] = []
+
+ def fake_imresize(frame: np.ndarray, dsize, **kwargs):
+ resized_calls.append({"dsize": dsize, **kwargs})
+ h, w = dsize
+ return np.zeros((h, w, frame.shape[-1]), dtype=frame.dtype)
+
+ def fake_imcvtcolor(frame: np.ndarray, *, cvt_mode: str) -> np.ndarray:
+ cvt_calls.append(cvt_mode)
+ return frame
+
+ monkeypatch.setattr(v2_mod, "imresize", fake_imresize)
+ monkeypatch.setattr(v2_mod, "imcvtcolor", fake_imcvtcolor)
+
+ class _FakeCapture:
+ def __init__(self, frames) -> None:
+ self._frames = list(frames)
+ self._idx = 0
+
+ def set(self, *_args, **_kwargs) -> None:
+ return None
+
+ def read(self):
+ if self._idx >= len(self._frames):
+ return False, None
+ frame = self._frames[self._idx]
+ self._idx += 1
+ return True, frame
+
+ def release(self) -> None:
+ return None
+
+ # scale < 1 branch + color conversion + skip None frame
+ monkeypatch.setattr(
+ v2_mod.cv2,
+ "VideoCapture",
+ lambda *_args, **_kwargs: _FakeCapture(
+ [None, np.zeros((20, 20, 3), dtype=np.uint8)]
+ ),
+ )
+ frames, _ = v2_mod._extract_frames(
+ inds=[0, 1],
+ video_path=dummy_video,
+ max_size=10,
+ color_base="RGB",
+ )
+ assert len(frames) == 1
+ assert resized_calls
+ assert cvt_calls == ["BGR2RGB"]
+
+ resized_calls.clear()
+ cvt_calls.clear()
+
+ # scale > 1 branch
+ monkeypatch.setattr(
+ v2_mod.cv2,
+ "VideoCapture",
+ lambda *_args, **_kwargs: _FakeCapture(
+ [np.zeros((5, 5, 3), dtype=np.uint8)]
+ ),
+ )
+ frames2, _ = v2_mod._extract_frames(
+ inds=[0],
+ video_path=dummy_video,
+ max_size=10,
+ color_base="BGR",
+ )
+ assert len(frames2) == 1
+ assert any(call.get("interpolation") is not None for call in resized_calls)
+
+
+def test_video2frames_v2_returns_empty_for_zero_frames_or_fps(
+ monkeypatch, tmp_path
+):
+ import importlib
+
+ v2_mod = importlib.import_module(
+ "capybara.vision.videotools.video2frames_v2"
+ )
+
+ dummy_video = tmp_path / "x.mp4"
+ dummy_video.write_bytes(b"")
+
+ class _FakeCapture:
+ def get(self, *_args, **_kwargs):
+ return 0
+
+ def release(self) -> None:
+ return None
+
+ monkeypatch.setattr(
+ v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+ assert video2frames_v2(dummy_video) == []
+
+
+def test_video2frames_v2_validates_start_end_and_handles_worker_exceptions(
+ monkeypatch, tmp_path
+):
+ import importlib
+
+ v2_mod = importlib.import_module(
+ "capybara.vision.videotools.video2frames_v2"
+ )
+
+ dummy_video = tmp_path / "x.mp4"
+ dummy_video.write_bytes(b"")
+
+ class _FakeCapture:
+ def __init__(self) -> None:
+ self.calls = 0
+
+ def get(self, prop):
+ if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+ return 10
+ if prop == v2_mod.cv2.CAP_PROP_FPS:
+ return 10
+ return 0
+
+ def release(self) -> None:
+ return None
+
+ monkeypatch.setattr(
+ v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+
+ with pytest.raises(ValueError, match="start_sec"):
+ video2frames_v2(dummy_video, start_sec=2.0, end_sec=1.0)
+
+ monkeypatch.setattr(
+ v2_mod,
+ "_extract_frames",
+ lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
+ )
+ assert (
+ video2frames_v2(dummy_video, start_sec=0.0, end_sec=1.0, n_threads=1)
+ == []
+ )
+
+
+def test_video2frames_v2_skips_empty_worker_chunks(monkeypatch, tmp_path):
+ import importlib
+
+ v2_mod = importlib.import_module(
+ "capybara.vision.videotools.video2frames_v2"
+ )
+
+ dummy_video = tmp_path / "x.mp4"
+ dummy_video.write_bytes(b"")
+
+ class _FakeCapture:
+ def get(self, prop):
+ if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+ return 10
+ if prop == v2_mod.cv2.CAP_PROP_FPS:
+ return 10
+ return 0
+
+ def release(self) -> None:
+ return None
+
+ monkeypatch.setattr(
+ v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+
+ calls: list[list[int]] = []
+
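+    # A 0.1 s window at 10 FPS yields a single frame index, so splitting work across 8 threads must not schedule empty chunks.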
+ def fake_extract_frames(
+ inds,
+ video_path,
+ max_size=1920,
+ color_base="BGR",
+ global_ind=0,
+ ):
+ assert inds, "Should never schedule empty index chunks."
+ calls.append(list(inds))
+ return [], global_ind
+
+ monkeypatch.setattr(v2_mod, "_extract_frames", fake_extract_frames)
+
+ assert (
+ v2_mod.video2frames_v2(
+ dummy_video,
+ frame_per_sec=10,
+ start_sec=0.0,
+ end_sec=0.1,
+ n_threads=8,
+ )
+ == []
+ )
+ assert calls == [[0]]
+
+
+def test_video2frames_v2_validates_threads_fps_and_zero_duration(
+ monkeypatch, tmp_path
+):
+ import importlib
+
+ v2_mod = importlib.import_module(
+ "capybara.vision.videotools.video2frames_v2"
+ )
+
+ dummy_video = tmp_path / "x.mp4"
+ dummy_video.write_bytes(b"")
+
+ class _FakeCapture:
+ def get(self, prop):
+ if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+ return 10
+ if prop == v2_mod.cv2.CAP_PROP_FPS:
+ return 10
+ return 0
+
+ def release(self) -> None:
+ return None
+
+ monkeypatch.setattr(
+ v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+ )
+
+ with pytest.raises(ValueError, match="n_threads must be >= 1"):
+ v2_mod.video2frames_v2(dummy_video, n_threads=0)
+
+ with pytest.raises(ValueError, match="frame_per_sec must be > 0"):
+ v2_mod.video2frames_v2(dummy_video, frame_per_sec=0, n_threads=1)
+
+ assert (
+ v2_mod.video2frames_v2(
+ dummy_video,
+ frame_per_sec=10,
+ start_sec=1.0,
+ end_sec=1.0,
+ n_threads=1,
+ )
+ == []
+ )
diff --git a/tests/vision/visualization/test_draw.py b/tests/vision/visualization/test_draw.py
new file mode 100644
index 0000000..4a0b71d
--- /dev/null
+++ b/tests/vision/visualization/test_draw.py
@@ -0,0 +1,247 @@
+import numpy as np
+import pytest
+
+from capybara import Box, Boxes, Keypoints, KeypointsList, Polygon, Polygons
+from capybara.vision.visualization import draw as draw_mod
+
+
+def test_draw_box_denormalizes_normalized_box():
+ img = np.zeros((60, 80, 3), dtype=np.uint8)
+ box = Box([0.1, 0.2, 0.9, 0.8], box_mode="XYXY", is_normalized=True)
+ out = draw_mod.draw_box(img.copy(), box, color=(0, 255, 0), thickness=1)
+ assert out.shape == img.shape
+ assert out.dtype == img.dtype
+ # (x1, y1) should be denormalized to a non-zero coordinate.
+ assert out[12, 8].sum() > 0
+
+
+def test_draw_boxes_supports_per_box_colors_and_thicknesses():
+ img = np.zeros((40, 40, 3), dtype=np.uint8)
+ boxes = Boxes([[0, 0, 10, 10], [20, 20, 35, 35]], box_mode="XYXY")
+ out = draw_mod.draw_boxes(
+ img.copy(),
+ boxes,
+ colors=[(255, 0, 0), (0, 255, 0)],
+ thicknesses=[1, 2],
+ )
+ assert out.shape == img.shape
+ assert out.sum() > 0
+
+
+def test_draw_polygon_and_polygons_support_fillup_and_normalized_coords():
+ img = np.zeros((50, 50, 3), dtype=np.uint8)
+ poly = Polygon(
+ [(0.1, 0.1), (0.9, 0.1), (0.9, 0.9), (0.1, 0.9)],
+ is_normalized=True,
+ )
+ out_edges = draw_mod.draw_polygon(img.copy(), poly, fillup=False)
+ out_fill = draw_mod.draw_polygon(img.copy(), poly, fillup=True)
+ assert out_edges.shape == img.shape
+ assert out_fill.shape == img.shape
+ assert out_edges.sum() > 0
+ assert out_fill.sum() > 0
+
+ polys = Polygons([poly, poly.shift(0.0, -0.1)], is_normalized=True)
+ out_multi = draw_mod.draw_polygons(
+ img.copy(),
+ polys,
+ colors=[(0, 0, 255), (255, 255, 0)],
+ thicknesses=[1, 1],
+ fillup=False,
+ )
+ assert out_multi.shape == img.shape
+ assert out_multi.sum() > 0
+
+
+def test_draw_text_draws_pixels():
+ img = np.full((60, 200, 3), 255, dtype=np.uint8)
+ out = draw_mod.draw_text(
+ img.copy(),
+ "hello",
+ location=(5, 5),
+ color=(0, 0, 255),
+ text_size=18,
+ )
+ assert out.shape == img.shape
+ assert out.dtype == img.dtype
+ assert not np.array_equal(out, img)
+
+
+def test_draw_text_handles_fonts_without_getbbox(monkeypatch):
+ from PIL import ImageFont
+
+ class _NoBBoxFont:
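+        # Delegates to the default PIL font but fails in getbbox(), forcing draw_text onto its fallback path.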
+ def __init__(self):
+ self._font = ImageFont.load_default()
+
+ def getbbox(self, _text):
+ raise RuntimeError("boom")
+
+ def __getattr__(self, name: str):
+ return getattr(self._font, name)
+
+ monkeypatch.setattr(
+ draw_mod, "_load_font", lambda *_args, **_kwargs: _NoBBoxFont()
+ )
+
+ img = np.full((40, 160, 3), 255, dtype=np.uint8)
+ out = draw_mod.draw_text(img.copy(), "hello", location=(5, 5))
+ assert out.shape == img.shape
+
+
+@pytest.mark.parametrize("style", ["dotted", "line"])
+def test_draw_line_supports_styles(style: str):
+ img = np.zeros((40, 40, 3), dtype=np.uint8)
+ out = draw_mod.draw_line(
+ img,
+ pt1=(0, 0),
+ pt2=(39, 39),
+ color=(0, 255, 0),
+ thickness=2,
+ style=style,
+ gap=8,
+ inplace=False,
+ )
+ assert out.shape == img.shape
+ assert not np.array_equal(out, img)
+
+
+def test_draw_line_inplace_and_invalid_style():
+ img = np.zeros((20, 20, 3), dtype=np.uint8)
+ out = draw_mod.draw_line(
+ img,
+ pt1=(0, 0),
+ pt2=(19, 0),
+ color=(255, 0, 0),
+ thickness=1,
+ style="line",
+ gap=5,
+ inplace=True,
+ )
+ assert out is img
+ assert out.sum() > 0
+
+ with pytest.raises(ValueError, match="Unknown style"):
+ draw_mod.draw_line(img, (0, 0), (10, 10), style="invalid")
+
+
+def test_draw_point_and_draw_points_preserve_grayscale_shape():
+ gray = np.zeros((30, 30), dtype=np.uint8)
+ out = draw_mod.draw_point(
+ gray.copy(),
+ (15, 15),
+ scale=1.0,
+ color=(255, 0, 0),
+ thickness=-1,
+ )
+ assert out.shape == gray.shape
+ assert out.ndim == 2
+ assert out.sum() > 0
+
+ out2 = draw_mod.draw_points(
+ gray.copy(),
+ points=[(5, 5), (25, 25)],
+ scales=[1.0, 2.0],
+ colors=[(0, 255, 0), (0, 0, 255)],
+ thicknesses=[-1, -1],
+ )
+ assert out2.shape == (*gray.shape, 3)
+ assert out2.ndim == 3
+ assert out2.sum() > 0
+
+
+def test_draw_keypoints_and_list_support_normalized_inputs():
+ img = np.zeros((60, 60, 3), dtype=np.uint8)
+ kpts = Keypoints([(0.25, 0.25), (0.75, 0.75)], is_normalized=True)
+ out = draw_mod.draw_keypoints(img.copy(), kpts, scale=1.0, thickness=-1)
+ assert out.shape == img.shape
+ assert out.sum() > 0
+
+ kpts_list = KeypointsList([kpts, kpts.shift(0.0, -0.1)], is_normalized=True)
+ out2 = draw_mod.draw_keypoints_list(
+ img.copy(), kpts_list, scales=[1.0, 1.5], thicknesses=[-1, -1]
+ )
+ assert out2.shape == img.shape
+ assert out2.sum() > 0
+
+
+def test_generate_colors_and_distinct_color():
+ np.random.seed(0)
+ tri = draw_mod.generate_colors(3, scheme="triadic")
+ assert len(tri) == 3
+ assert all(isinstance(c, tuple) and len(c) == 3 for c in tri)
+
+ hsv = draw_mod.generate_colors(3, scheme="hsv")
+ assert len(hsv) == 3
+
+ analogous = draw_mod.generate_colors(3, scheme="analogous")
+ assert len(analogous) == 3
+
+ square = draw_mod.generate_colors(3, scheme="square")
+ assert len(square) == 3
+
+ unknown = draw_mod.generate_colors(2, scheme="not-a-scheme")
+ assert unknown == []
+
+ assert 0 <= draw_mod._vdc(1, base=2) < 1
+ c0 = draw_mod.distinct_color(0)
+ c1 = draw_mod.distinct_color(1)
+ assert isinstance(c0, tuple) and len(c0) == 3
+ assert c0 != c1
+
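+    # A numeric label resolves to its integer value; the same textual label always resolves to the same index.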
+ assert draw_mod._label_to_index("123") == 123
+ assert draw_mod._label_to_index("cat") == draw_mod._label_to_index("cat")
+
+
+def test_draw_mask_normalization_and_shape_checks():
+ img = np.zeros((20, 30, 3), dtype=np.uint8)
+ mask = np.arange(20 * 30, dtype=np.uint8).reshape(20, 30)
+ out = draw_mod.draw_mask(img, mask, min_max_normalize=True)
+ assert out.shape == img.shape
+
+ mask_bgr = np.stack([mask] * 3, axis=-1)
+ out2 = draw_mod.draw_mask(img, mask_bgr, min_max_normalize=False)
+ assert out2.shape == img.shape
+
+ bad_mask = np.zeros((20, 30, 2), dtype=np.uint8)
+ with pytest.raises(ValueError, match="Mask must be either 2D"):
+ draw_mod.draw_mask(img, bad_mask)
+
+
+def test_draw_detection_and_draw_detections_end_to_end():
+ img = np.zeros((80, 120, 3), dtype=np.uint8)
+ box = Box([0.05, 0.0, 0.5, 0.2], box_mode="XYXY", is_normalized=True)
+
+ out = draw_mod.draw_detection(
+ img.copy(),
+ box,
+ label="cat",
+ score=0.9,
+ color=None,
+ thickness=None,
+ box_alpha=0.5,
+ text_bg_alpha=0.5,
+ )
+ assert out.shape == img.shape
+ assert out.sum() > 0
+
+ boxes = Boxes([[0, 0, 20, 20], [30, 30, 60, 60]], box_mode="XYXY")
+ with pytest.raises(ValueError, match="Number of boxes must match"):
+ draw_mod.draw_detections(img, boxes, labels=["only-one"])
+ with pytest.raises(ValueError, match="Number of scores must match"):
+ draw_mod.draw_detections(img, boxes, labels=["a", "b"], scores=[0.1])
+
+ out2 = draw_mod.draw_detections(
+ img.copy(),
+ boxes,
+ labels=["a", "b"],
+ scores=[0.1, 0.2],
+ colors=[(0, 255, 0), (255, 0, 0)],
+ thicknesses=[1, 2],
+ text_colors=[(255, 255, 255), (0, 0, 0)],
+ text_sizes=[12, 13],
+ box_alpha=1.0,
+ text_bg_alpha=0.5,
+ )
+ assert out2.shape == img.shape
+ assert out2.sum() > 0
diff --git a/tests/vision/visualization/test_vis_utils.py b/tests/vision/visualization/test_vis_utils.py
index 81b9af4..8524737 100644
--- a/tests/vision/visualization/test_vis_utils.py
+++ b/tests/vision/visualization/test_vis_utils.py
@@ -1,23 +1,47 @@
+from typing import Any
+
import numpy as np
import pytest
-from capybara.structures import (Box, Boxes, Keypoints, KeypointsList, Polygon,
- Polygons)
-from capybara.vision.visualization.utils import *
+from capybara.structures import (
+ Box,
+ Boxes,
+ Keypoints,
+ KeypointsList,
+ Polygon,
+ Polygons,
+)
+from capybara.vision.visualization.utils import (
+ is_numpy_img,
+ prepare_box,
+ prepare_boxes,
+ prepare_color,
+ prepare_colors,
+ prepare_img,
+ prepare_keypoints,
+ prepare_keypoints_list,
+ prepare_point,
+ prepare_polygon,
+ prepare_polygons,
+ prepare_scale,
+ prepare_scales,
+ prepare_thickness,
+ prepare_thicknesses,
+)
def test_is_numpy_img():
img = np.random.random((100, 100, 3))
- assert is_numpy_img(img) == True
+ assert is_numpy_img(img) is True
img = np.random.random((100, 100))
- assert is_numpy_img(img) == True
+ assert is_numpy_img(img) is True
- img = np.random.random((100, ))
- assert is_numpy_img(img) == False
+ img = np.random.random((100,))
+ assert is_numpy_img(img) is False
img = "not an image"
- assert is_numpy_img(img) == False
+ assert is_numpy_img(img) is False
def test_prepare_color():
@@ -33,12 +57,14 @@ def test_prepare_color():
color = 0
assert prepare_color(color) == (0, 0, 0)
- color = 'black'
+ color: Any = "black"
with pytest.raises(TypeError):
prepare_color(color)
- color = (0.1, 0.1, 0.1)
- with pytest.raises(TypeError, match=r'[0-9a-zA-Z=,. ]+colors\[2\][0-9a-zA-Z=,. ()[\]]+'):
+ color: Any = (0.1, 0.1, 0.1)
+ with pytest.raises(
+ TypeError, match=r"[0-9a-zA-Z=,. ]+colors\[2\][0-9a-zA-Z=,. ()[\]]+"
+ ):
prepare_color(color, 2)
@@ -59,29 +85,38 @@ def test_prepare_colors():
length = 3
assert prepare_colors(colors, length) == [(0, 0, 0), (0, 0, 0), (0, 0, 0)]
- colors = 'black'
+ colors: Any = "black"
length = 3
with pytest.raises(TypeError):
prepare_colors(colors, length)
colors = [(0, 0, 0), (1, 1, 1), (2, 2, 2)]
length = 2
- with pytest.raises(ValueError, match=r'The length of colors = 3 is not equal to the length = 2.'):
+ with pytest.raises(
+ ValueError,
+ match=r"The length of colors = 3 is not equal to the length = 2.",
+ ):
prepare_colors(colors, length)
colors = [(0, 0, 0), (1.1, 1.1, 1.1), (2, 2, 2)]
length = 3
- with pytest.raises(TypeError, match=r'[0-9a-zA-Z = , . ]+colors\[1\][0-9a-zA-Z = , . ()[\]]+'):
+ with pytest.raises(
+ TypeError,
+ match=r"[0-9a-zA-Z = , . ]+colors\[1\][0-9a-zA-Z = , . ()[\]]+",
+ ):
prepare_colors(colors, length)
- colors = 0.1
+ colors: Any = 0.1
length = 3
- with pytest.raises(TypeError, match=r'[0-9a-zA-Z = , . ]+colors\[0\][0-9a-zA-Z = , . ()[\]]+'):
+ with pytest.raises(
+ TypeError,
+ match=r"[0-9a-zA-Z = , . ]+colors\[0\][0-9a-zA-Z = , . ()[\]]+",
+ ):
prepare_colors(colors, length)
def test_prepare_img():
- tgt_img1 = np.random.randint(0, 255, (100, 100, 3), dtype='uint8')
+ tgt_img1 = np.random.randint(0, 255, (100, 100, 3), dtype="uint8")
img = tgt_img1.copy()
np.testing.assert_allclose(prepare_img(img), tgt_img1)
@@ -93,7 +128,12 @@ def test_prepare_img():
with pytest.raises(ValueError):
prepare_img(img)
- img = "not an image"
+ img = np.random.randint(0, 255, (10, 20, 1), dtype="uint8")
+ out = prepare_img(img)
+ assert out.shape == (10, 20, 3)
+ np.testing.assert_allclose(out[..., 0], img[..., 0])
+
+ img: Any = "not an image"
with pytest.raises(ValueError):
prepare_img(img)
@@ -111,16 +151,23 @@ def test_prepare_box():
box = np.array([0, 0, 100, 100])
assert prepare_box(box) == tgt_box
- box = 0
- with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+0[0-9a-zA-Z=,. ()[\]]+"):
+ box: Any = 0
+ with pytest.raises(
+ ValueError, match=r"[0-9a-zA-Z=,. ]+0[0-9a-zA-Z=,. ()[\]]+"
+ ):
prepare_box(box)
box = (0, 0, 100)
- with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100\)[0-9a-zA-Z=,. ()[\]']+"):
+ with pytest.raises(
+ ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100\)[0-9a-zA-Z=,. ()[\]']+"
+ ):
prepare_box(box)
box = (0, 0, 100, 100, 100)
- with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100, 100, 100\)[0-9a-zA-Z=,. ()[\]']+"):
+ with pytest.raises(
+ ValueError,
+ match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100, 100, 100\)[0-9a-zA-Z=,. ()[\]']+",
+ ):
prepare_box(box)
@@ -134,8 +181,13 @@ def test_prepare_boxes():
np_boxes = np.array(boxes_list)
assert prepare_boxes(np_boxes) == boxes
- boxes = [(0, 1), ]
- with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+boxes\[0\][0-9a-zA-Z=,. ]+\(0, 1\)[0-9a-zA-Z=,. ()[\]']+"):
+ boxes = [
+ (0, 1),
+ ]
+ with pytest.raises(
+ ValueError,
+ match=r"[0-9a-zA-Z=,. ]+boxes\[0\][0-9a-zA-Z=,. ]+\(0, 1\)[0-9a-zA-Z=,. ()[\]']+",
+ ):
prepare_boxes(boxes)
@@ -150,8 +202,11 @@ def test_prepare_keypoints():
assert prepare_keypoints(tgt_keypoints) == tgt_keypoints
- keypoints = [[0, 1], [2, 1], [3, 1]]
- with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[2, 1\], \[3, 1\]\][0-9a-zA-Z=,. ()[\]]+"):
+ keypoints: Any = [[0, 1], [2, 1], [3, 1]]
+ with pytest.raises(
+ TypeError,
+ match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[2, 1\], \[3, 1\]\][0-9a-zA-Z=,. ()[\]]+",
+ ):
prepare_keypoints(keypoints)
@@ -177,7 +232,10 @@ def test_prepare_keypoints_list():
[[0, 1], [1, 2], [3, 4]],
[(0, 1), (1, 2), (3, 3)],
]
- with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+keypoints_list\[0\][0-9a-zA-Z=,. ()[\]]+"):
+ with pytest.raises(
+ TypeError,
+ match=r"[0-9a-zA-Z=,. ]+keypoints_list\[0\][0-9a-zA-Z=,. ()[\]]+",
+ ):
prepare_keypoints_list(keypoints_list)
@@ -191,7 +249,10 @@ def test_prepare_polygon():
assert prepare_polygon(np_polygon) == tgt_polygon
polygon = [[0, 1], [1, 2, 3]]
- with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[1, 2, 3\]\][0-9a-zA-Z=,. ()[\]]+"):
+ with pytest.raises(
+ TypeError,
+ match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[1, 2, 3\]\][0-9a-zA-Z=,. ()[\]]+",
+ ):
prepare_polygon(polygon)
@@ -199,7 +260,14 @@ def test_prepare_polygons():
tgt_polygons = Polygons(
[
[[0, 1], [1, 2], [3, 4]],
- [[1, 2,], [3, 3], [5, 5]],
+            [[1, 2], [3, 3], [5, 5]],
]
)
polygons = [
@@ -215,7 +283,9 @@ def test_prepare_polygons():
[[0, 1], [1, 2], [3, 4]],
[[1, 2], [3, 3], [5, 5, 3]],
]
- with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+polygons\[1\][0-9a-zA-Z=,. ()[\]]+"):
+ with pytest.raises(
+ TypeError, match=r"[0-9a-zA-Z=,. ]+polygons\[1\][0-9a-zA-Z=,. ()[\]]+"
+ ):
prepare_polygons(polygons)
@@ -263,3 +333,15 @@ def test_prepare_scales():
scales = 1.2
assert prepare_scales(scales, 2) == [1.2, 1.2]
+
+
+def test_prepare_point_and_numeric_validations_cover_error_paths():
+ point: Any = (1,)
+ with pytest.raises(TypeError, match=r"points\[0\]"):
+ prepare_point(point, ind=0)
+
+ with pytest.raises(ValueError, match=r"thickness\[s\[1\]\]"):
+ prepare_thickness(-2, ind=1)
+
+ with pytest.raises(ValueError, match=r"scale\[s\[1\]\]"):
+ prepare_scale(-2, ind=1)