diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b0d4dbf --- /dev/null +++ b/.coveragerc @@ -0,0 +1,15 @@ +[run] +omit = + */config.py + */config-*.py + capybara/cpuinfo.py + capybara/utils/system_info.py + capybara/vision/ipcam/app.py + +[report] +omit = + */config.py + */config-*.py + capybara/cpuinfo.py + capybara/utils/system_info.py + capybara/vision/ipcam/app.py diff --git a/.github/scripts/coverage_gate.py b/.github/scripts/coverage_gate.py new file mode 100644 index 0000000..3c2d32e --- /dev/null +++ b/.github/scripts/coverage_gate.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from xml.etree import ElementTree + + +def _as_float(value: str | None, default: float | None = None) -> float | None: + if value is None: + return default + text = value.strip() + if not text: + return default + try: + return float(text) + except ValueError: + return default + + +def _as_bool(value: str | None) -> bool: + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on", "y"} + + +def _percent(value: float | None) -> str: + if value is None or value < 0: + return "N/A" + pct = value * 100.0 + if pct.is_integer(): + return f"{int(pct)}%" + return f"{pct:.2f}%" + + +def _load_coverage_root(path: Path) -> ElementTree.Element | None: + if not path.exists(): + return None + try: + return ElementTree.parse(path).getroot() + except ElementTree.ParseError: + return None + + +def evaluate_coverage( + coverage_path: Path, + min_line: float | None, + min_branch: float | None, +) -> tuple[bool, list[str], float | None, float | None]: + """Return a tuple describing coverage gate result.""" + root = _load_coverage_root(coverage_path) + if root is None: + return False, [f"coverage XML '{coverage_path}' 無法讀取."], None, None + + line_rate = _as_float(root.attrib.get("line-rate")) + branch_rate = _as_float(root.attrib.get("branch-rate")) + + messages: list[str] = [] + passed = True + + if min_line is not None: + if line_rate is None: + passed = False + messages.append("行覆蓋率資料不存在.") + elif line_rate + 1e-9 < min_line: + passed = False + messages.append( + f"行覆蓋率 {_percent(line_rate)} 低於門檻 {_percent(min_line)}." + ) + else: + messages.append( + f"行覆蓋率 {_percent(line_rate)} >= 門檻 {_percent(min_line)}." + ) + + if min_branch is not None: + if branch_rate is None: + passed = False + messages.append("分支覆蓋率資料不存在.") + elif branch_rate + 1e-9 < min_branch: + passed = False + messages.append( + f"分支覆蓋率 {_percent(branch_rate)} 低於門檻 {_percent(min_branch)}." + ) + else: + messages.append( + f"分支覆蓋率 {_percent(branch_rate)} >= 門檻 {_percent(min_branch)}." 
+ ) + + if min_line is None and min_branch is None: + messages.append("未設定覆蓋率門檻, 跳過檢查.") + + return passed, messages, line_rate, branch_rate + + +def parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Coverage gate checker") + parser.add_argument( + "--file", + required=True, + help="Path to coverage XML report generated by coverage.py", + ) + parser.add_argument( + "--min-line", + dest="min_line", + default=None, + help="Minimum line coverage required (0.0 - 1.0)", + ) + parser.add_argument( + "--min-branch", + dest="min_branch", + default=None, + help="Minimum branch coverage required (0.0 - 1.0)", + ) + parser.add_argument( + "--enforce", + dest="enforce", + default="0", + help="Set to 1/true to enforce the gate (exit with non-zero on failure)", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv or sys.argv[1:]) + + coverage_path = Path(args.file) + min_line = _as_float(args.min_line) + min_branch = _as_float(args.min_branch) + enforce_gate = _as_bool(args.enforce) + + if not coverage_path.exists(): + message = f"找不到覆蓋率報告: {coverage_path}" + if enforce_gate: + print(f"::error::{message}") + print("::error::Coverage Gate Result: FAIL (N/A)") + return 2 + print(f"::warning::{message}") + print("::warning::Coverage Gate Result: SKIP (N/A)") + return 0 + + passed, messages, line_rate, branch_rate = evaluate_coverage( + coverage_path, min_line, min_branch + ) + prefix = "::notice::" if passed else "::error::" + for line in messages: + print(f"{prefix}{line}") + + summary_parts: list[str] = [] + if line_rate is not None: + summary_parts.append(f"行 {_percent(line_rate)}") + if branch_rate is not None: + summary_parts.append(f"分支 {_percent(branch_rate)}") + summary_text = ", ".join(summary_parts) if summary_parts else "N/A" + summary_prefix = ( + "::notice::" + if passed + else ("::warning::" if not enforce_gate else "::error::") + ) + result_text = "PASS" if passed else "FAIL" + print( + f"{summary_prefix}Coverage Gate Result: {result_text} ({summary_text})" + ) + + if passed or not enforce_gate: + if not passed: + print("::warning::覆蓋率未達門檻, 但目前未強制執行 (enforce=0).") + return 0 + + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/scripts/generate_pytest_summary.py b/.github/scripts/generate_pytest_summary.py new file mode 100644 index 0000000..1a377da --- /dev/null +++ b/.github/scripts/generate_pytest_summary.py @@ -0,0 +1,548 @@ +import os +import re +import textwrap +from contextlib import suppress +from dataclasses import dataclass +from pathlib import Path +from xml.etree import ElementTree + +MARKER = "" +ARTIFACT_PLACEHOLDER = "" +DEFAULT_LOG_TAIL_LINES = 200 +DEFAULT_LOG_CHAR_LIMIT = 6000 +SNIPPET_CHAR_LIMIT = 4000 +TOP_SLOW_N = 20 + + +@dataclass +class CoverageSummary: + line_rate: float | None = None + branch_rate: float | None = None + line_pct: str | None = None + branch_pct: str | None = None + short_text: str = "覆蓋率資料不可用" + + +REPORT_DIR = Path(".ci-reports/pytest") +RAW_REPORT_DIR = Path(os.getenv("PYTEST_REPORT_DIR", ".ci-reports/pytest/raw")) +LEGACY_REPORT_DIR = Path(".pytest-reports") +COVERAGE_REPORT_FILE = "coverage_report.txt" + + +def _first_existing(paths: list[Path]) -> Path | None: + for candidate in paths: + if candidate.exists(): + return candidate + return paths[0] if paths else None + + +def _as_float(value: str | float | None) -> float | None: + if value is None: + return None + try: + return float(value) + except 
(TypeError, ValueError): + return None + + +def _threshold_value(value: str | None) -> float | None: + if value is None: + return None + try: + return float(str(value)) + except (TypeError, ValueError): + return None + + +def fmt_outcome(value: str) -> str: + mapping = { + "success": "✅ 通過", + "failure": "❌ 失敗", + "cancelled": "⚪ 取消", + "skipped": "⚪ 未執行", + } + value = (value or "").strip() + return mapping.get(value, f"⚪ {value or '未知'}") + + +def fmt_percent(raw: str | float | int | None) -> str | None: + if raw is None: + return None + if isinstance(raw, int | float): + value = float(raw) + else: + try: + value = float(raw) + except (TypeError, ValueError): + return None + pct = value * 100.0 + if pct.is_integer(): + return f"{int(pct)}%" + return f"{pct:.2f}%" + + +def cleanse_snippet(snippet: str, limit: int = SNIPPET_CHAR_LIMIT) -> str: + snippet = textwrap.dedent(snippet).strip() + if not snippet: + return "(無訊息)" + snippet = "\n".join(line.rstrip() for line in snippet.splitlines()) + if len(snippet) > limit: + snippet = f"{snippet[:limit]}\n... (截斷)" + return snippet + + +def load_log_tail( + path: Path, + *, + limit: int = DEFAULT_LOG_CHAR_LIMIT, + tail_lines: int = DEFAULT_LOG_TAIL_LINES, +) -> str | None: + if not path.exists(): + return None + try: + raw = path.read_text(encoding="utf-8", errors="ignore").strip() + except OSError: + return None + if not raw: + return None + lines = raw.splitlines() + tail = "\n".join(lines[-tail_lines:]) + if len(tail) > limit: + tail = tail[-limit:] + return tail + + +def _read_text_if_exists(path: Path) -> str | None: + if not path.exists(): + return None + try: + return path.read_text(encoding="utf-8", errors="ignore") + except OSError: + return None + + +def append_coverage_report( + lines: list[str], report_path: Path, max_lines: int = 80 +) -> None: + raw = _read_text_if_exists(report_path) + if raw is None: + return + content = [line.rstrip() for line in raw.splitlines() if line.strip()] + if not content: + return + lines.append("") + lines.append("### 覆蓋率缺口") + lines.append("") + lines.append("
未覆蓋的檔案列表") + lines.append("") + lines.append("```") + for line in content[:max_lines]: + lines.append(line) + if len(content) > max_lines: + lines.append("... (截斷)") + lines.append("```") + lines.append("
") + + +def _iter_testcases(root: ElementTree.Element) -> list[ElementTree.Element]: + return list(root.findall(".//testcase")) + + +def _parse_root(xml_path: Path) -> ElementTree.Element | None: + if not xml_path.exists(): + return None + try: + return ElementTree.parse(xml_path).getroot() + except ElementTree.ParseError: + return None + + +def append_pytest_summary(lines: list[str], junit_path: Path) -> None: + root = _parse_root(junit_path) + if root is None: + return + + attrs = root.attrib + tests = attrs.get("tests") + failures = attrs.get("failures") or attrs.get("failed") + errors = attrs.get("errors") + skipped = attrs.get("skipped") or attrs.get("disabled") + time_spent = attrs.get("time") + + def _safe_int(x: str | int | float | None, default: int | str = 0) -> int: + if x is None: + return int(default) + try: + return int(float(x)) + except (TypeError, ValueError): + try: + return int(x) + except (TypeError, ValueError): + return int(default) + + if tests is None: + tests = 0 + failures = 0 + errors = 0 + skipped = 0 + total_time = 0.0 + for suite in root.findall(".//testsuite"): + a = suite.attrib + tests += _safe_int(a.get("tests"), 0) + failures += _safe_int(a.get("failures") or a.get("failed"), 0) + errors += _safe_int(a.get("errors"), 0) + skipped += _safe_int(a.get("skipped") or a.get("disabled"), 0) + with suppress(Exception): + total_time += float(a.get("time") or 0.0) + time_spent = f"{total_time:.3f}" + else: + tests = _safe_int(tests, 0) + failures = _safe_int(failures, 0) + errors = _safe_int(errors, 0) + skipped = _safe_int(skipped, 0) + + passed = max(tests - failures - errors - skipped, 0) + + lines.append( + f"- 測試統計: 共 {tests}; ✅ 通過 {passed}; ❌ 失敗 {failures}; ⚠️ 錯誤 {errors}; ⏭️ 跳過 {skipped}; 耗時 {time_spent or 'N/A'} 秒" + ) + + failures_block: list[str] = [] + for case in _iter_testcases(root): + failure = case.find("failure") + error = case.find("error") + target = failure or error + if target is None: + continue + status = "failure" if failure is not None else "error" + classname = case.attrib.get("classname") or "" + name = case.attrib.get("name") or "" + elapsed = case.attrib.get("time") or "" + header = f"- `{classname}.{name}` ({status}" + if elapsed: + header += f", {elapsed} 秒" + header += ")" + snippet = target.attrib.get("message", "") or "" + text_body = target.text or "" + combined = "\n".join(part for part in (snippet, text_body) if part) + failures_block.append(header) + formatted = cleanse_snippet(combined) + indented = "\n".join(f" {line}" for line in formatted.splitlines()) + failures_block.append(" ```") + failures_block.append(indented or " (無訊息)") + failures_block.append(" ```") + + if failures_block: + lines.append("") + lines.append("#### 失敗與錯誤測試") + lines.append("") + lines.extend(failures_block) + + +def append_top_slowest( + lines: list[str], junit_path: Path, top_n: int = TOP_SLOW_N +) -> None: + root = _parse_root(junit_path) + if root is None: + return + records: list[tuple[float, str]] = [] + for case in _iter_testcases(root): + try: + t = float(case.attrib.get("time") or 0.0) + except Exception: + t = 0.0 + if t <= 0.0: + continue + classname = case.attrib.get("classname") or "" + name = case.attrib.get("name") or "" + fqname = f"{classname}.{name}" if classname else name + records.append((t, fqname)) + if not records: + return + records.sort(key=lambda x: x[0], reverse=True) + lines.append("") + lines.append(f"#### 最慢測試 Top {min(top_n, len(records))}") + lines.append("") + lines.append("| 測試 | 耗時 (秒) |") + lines.append("| --- | 
---: |") + for t, fqname in records[:top_n]: + lines.append(f"| `{fqname}` | {t:.3f} |") + + +def append_warnings_summary( + lines: list[str], candidate_logs: list[Path] +) -> None: + text: str | None = None + for p in candidate_logs: + text = _read_text_if_exists(p) + if text: + break + if not text: + return + + start = None + end = None + all_lines = text.splitlines() + for i, line in enumerate(all_lines): + if re.search(r"=+\s*warnings summary\s*=+", line, re.IGNORECASE): + start = i + break + if start is None: + return + for j in range(start + 1, len(all_lines)): + if re.match(r"=+\s", all_lines[j]): + end = j + break + block = ( + "\n".join(all_lines[start:end]).strip() + if end is not None + else "\n".join(all_lines[start:]).strip() + ) + block = cleanse_snippet(block, limit=SNIPPET_CHAR_LIMIT) + + if block: + lines.append("") + lines.append("#### Warnings 摘要") + lines.append("") + lines.append("```") + lines.append(block) + lines.append("```") + + +def append_session_meta(lines: list[str], candidate_logs: list[Path]) -> None: + text: str | None = None + for p in candidate_logs: + text = _read_text_if_exists(p) + if text: + break + if not text: + return + platform_line = re.search(r"^platform .+$", text, re.MULTILINE) + rootdir_line = re.search(r"^rootdir: .+$", text, re.MULTILINE) + plugins_line = re.search(r"^plugins: .+$", text, re.MULTILINE) + + entries = [] + if platform_line: + entries.append(f"- {platform_line.group(0)}") + if rootdir_line: + entries.append(f"- {rootdir_line.group(0)}") + if plugins_line: + entries.append(f"- {plugins_line.group(0)}") + + if entries: + lines.append("") + lines.append("### 測試環境") + lines.append("") + lines.extend(entries) + + +def append_coverage_summary( + lines: list[str], coverage_path: Path +) -> CoverageSummary: + summary = CoverageSummary() + + lines.append("") + lines.append("### 覆蓋率") + lines.append("") + + root = _parse_root(coverage_path) + if root is None: + lines.append("- 未產生覆蓋率報告 (可能因測試失敗或覆蓋率未啟用)。") + summary.short_text = "覆蓋率資料不可用" + return summary + + line_rate = _as_float(root.attrib.get("line-rate")) + branch_rate = _as_float(root.attrib.get("branch-rate")) + overall_line = fmt_percent(line_rate) + overall_branch = fmt_percent(branch_rate) + summary.line_rate = line_rate + summary.branch_rate = branch_rate + summary.line_pct = overall_line + summary.branch_pct = overall_branch + + lines_valid = root.attrib.get("lines-valid") + lines_covered = root.attrib.get("lines-covered") + branches_valid = root.attrib.get("branches-valid") + branches_covered = root.attrib.get("branches-covered") + overall_parts: list[str] = [] + if overall_line: + detail = overall_line + if lines_covered and lines_valid: + detail += f" ({lines_covered}/{lines_valid})" + overall_parts.append(f"行 {detail}") + if overall_branch: + detail = overall_branch + if branches_covered and branches_valid: + detail += f" ({branches_covered}/{branches_valid})" + overall_parts.append(f"分支 {detail}") + if overall_parts: + lines.append(f"- 總覆蓋率: {', '.join(overall_parts)}") + else: + lines.append("- 總覆蓋率資料不可用。") + + records: list[tuple[float, str, str, str]] = [] + for cls in root.findall(".//class"): + filename = cls.attrib.get("filename") or cls.attrib.get("name") + if not filename: + continue + line_rate = cls.attrib.get("line-rate") + branch_rate = cls.attrib.get("branch-rate") + try: + line_value = float(line_rate) if line_rate is not None else 1.0 + except ValueError: + line_value = 1.0 + records.append( + ( + line_value, + filename, + fmt_percent(line_rate) or "N/A", 
+ fmt_percent(branch_rate) or "N/A", + ) + ) + + if records: + records.sort(key=lambda item: item[0]) + lines.append("") + lines.append("#### 覆蓋率最低的檔案 (前 10 名)") + lines.append("") + lines.append("| 檔案 | 行覆蓋 | 分支覆蓋 |") + lines.append("| --- | --- | --- |") + for _, filename, line_pct, branch_pct in records[:10]: + lines.append(f"| `{filename}` | {line_pct} | {branch_pct} |") + + summary_parts: list[str] = [] + if overall_line: + summary_parts.append(f"行 {overall_line}") + if overall_branch: + summary_parts.append(f"分支 {overall_branch}") + summary.short_text = ( + " / ".join(summary_parts) if summary_parts else "覆蓋率資料不可用" + ) + + return summary + + +def build_coverage_kpi( + info: CoverageSummary, + min_line: float | None, + min_branch: float | None, +) -> str: + line_pct = info.line_pct + branch_pct = info.branch_pct + line_rate = info.line_rate + branch_rate = info.branch_rate + + actual_parts: list[str] = [] + if line_pct: + actual_parts.append(f"行 {line_pct}") + if branch_pct: + actual_parts.append(f"分支 {branch_pct}") + actual_text = ( + " / ".join(actual_parts) if actual_parts else "覆蓋率資料不可用" + ) + + threshold_parts: list[str] = [] + if min_line is not None: + min_line_pct = fmt_percent(min_line) or f"{min_line * 100:.2f}%" + threshold_parts.append(f"行 {min_line_pct}") + if min_branch is not None: + min_branch_pct = fmt_percent(min_branch) or f"{min_branch * 100:.2f}%" + threshold_parts.append(f"分支 {min_branch_pct}") + threshold_text = "" + if threshold_parts: + threshold_text = f"(門檻 {' / '.join(threshold_parts)})" + + thresholds_provided = bool(threshold_parts) + status_ok = True + if min_line is not None and ( + line_rate is None or line_rate + 1e-9 < min_line + ): + status_ok = False + if min_branch is not None and ( + branch_rate is None or branch_rate + 1e-9 < min_branch + ): + status_ok = False + + if info.short_text == "覆蓋率資料不可用": + status_icon = "⚪" + elif thresholds_provided: + status_icon = "✅" if status_ok else "❌" + else: + status_icon = "✅" + + return f"{actual_text}{threshold_text} → {status_icon}" + + +def main() -> None: + report_dir = REPORT_DIR + report_dir.mkdir(parents=True, exist_ok=True) + summary_path = report_dir / "summary.md" + + outcome = os.getenv("PYTEST_OUTCOME", "") + exit_code = os.getenv("PYTEST_EXIT_CODE", "") + + lines: list[str] = [MARKER, "### Pytest 結果"] + status_line = f"- 狀態: {fmt_outcome(outcome)}" + if exit_code: + status_line += f" (exit={exit_code})" + lines.append(status_line) + + junit_candidates = [ + RAW_REPORT_DIR / "pytest.xml", + report_dir / "pytest.xml", + LEGACY_REPORT_DIR / "pytest.xml", + ] + coverage_candidates = [ + RAW_REPORT_DIR / "coverage.xml", + report_dir / "coverage.xml", + LEGACY_REPORT_DIR / "coverage.xml", + ] + junit_path = _first_existing(junit_candidates) or junit_candidates[0] + coverage_xml = ( + _first_existing(coverage_candidates) or coverage_candidates[0] + ) + log_candidates = [ + report_dir / "pytest.log", + report_dir / "run.log", + RAW_REPORT_DIR / "pytest.log", + LEGACY_REPORT_DIR / "pytest.log", + ] + + append_session_meta(lines, log_candidates) + append_pytest_summary(lines, junit_path) + append_top_slowest(lines, junit_path, top_n=TOP_SLOW_N) + append_warnings_summary(lines, log_candidates) + coverage_info = append_coverage_summary(lines, coverage_xml) + append_coverage_report(lines, report_dir / COVERAGE_REPORT_FILE) + + min_line = _threshold_value(os.getenv("COVERAGE_MIN_LINE")) + min_branch = _threshold_value(os.getenv("COVERAGE_MIN_BRANCH")) + coverage_kpi = build_coverage_kpi(coverage_info, 
min_line, min_branch) + + lines.append("") + lines.append(f"- 產物: {ARTIFACT_PLACEHOLDER}") + + if outcome.strip() == "failure": + tail = None + for p in log_candidates: + tail = load_log_tail(p) + if tail: + break + if tail: + lines.append("") + lines.append("
") + lines.append("Pytest 輸出 (最後 200 行)") + lines.append("") + lines.append("```") + lines.append(tail) + lines.append("```") + lines.append("
") + + summary_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + coverage_txt_path = report_dir / "coverage.txt" + coverage_txt_path.write_text(f"{coverage_kpi}\n", encoding="utf-8") + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..131b08c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,428 @@ +name: Capybara CI + +on: + pull_request: + branches: [main] + types: [opened, synchronize, reopened, ready_for_review] + push: + branches: [main] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + qa: + name: QA (ruff + pyright + pytest) + runs-on: [self-hosted, unicorn] + timeout-minutes: 45 + permissions: + actions: read + contents: read + packages: read + pull-requests: write + checks: write + env: + PYTHON_VERSION: "3.10" + REPORT_DIR: .ci-reports + PYTEST_REPORT_DIR: .ci-reports/pytest/raw + RUFF_REPORT_DIR: .ci-reports/ruff + PYRIGHT_REPORT_DIR: .ci-reports/pyright + COVERAGE_MIN_LINE: "0.99" + COVERAGE_MIN_BRANCH: "0.00" + COVERAGE_ENFORCE: "1" + ARTIFACT_RETENTION_DAYS: 14 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache-dependency-path: | + pyproject.toml + **/requirements*.txt + + - name: Install dependencies + run: | + PYTORCH_INDEX_URL="https://download.pytorch.org/whl/cpu" + python -m pip install --upgrade pip wheel + python -m pip install torch --index-url "$PYTORCH_INDEX_URL" + python -m pip install -e . + python -m pip install pytest pytest-cov pytest-html coverage ruff pyright onnx onnxslim + + - name: Prepare report directories + run: | + mkdir -p "$REPORT_DIR" + mkdir -p "$PYTEST_REPORT_DIR" + mkdir -p "$RUFF_REPORT_DIR" + mkdir -p "$PYRIGHT_REPORT_DIR" + + - name: Ruff - Lint + id: ruff_lint + continue-on-error: true + run: | + set -o pipefail + set +e + python -m ruff version + python -m ruff check capybara tests | tee "$RUFF_REPORT_DIR/lint.log" + status=${PIPESTATUS[0]} + set -e + echo "exit_code=$status" >> "$GITHUB_OUTPUT" + exit $status + + - name: Ruff - Format Check + id: ruff_format + continue-on-error: true + run: | + set -o pipefail + set +e + python -m ruff format --check capybara tests | tee "$RUFF_REPORT_DIR/format.log" + status=${PIPESTATUS[0]} + set -e + echo "exit_code=$status" >> "$GITHUB_OUTPUT" + exit $status + + - name: Pyright + id: pyright + continue-on-error: true + run: | + set -o pipefail + set +e + python -m pyright --version + python -m pyright | tee "$PYRIGHT_REPORT_DIR/pyright.log" + status=${PIPESTATUS[0]} + set -e + echo "exit_code=$status" >> "$GITHUB_OUTPUT" + exit $status + + - name: Pytest (XML + HTML + Coverage) + id: pytest + continue-on-error: true + env: + PYTEST_ADDOPTS: "-vv -ra" + run: | + set -o pipefail + raw_dir="$PYTEST_REPORT_DIR" + final_dir="$REPORT_DIR/pytest" + mkdir -p "$final_dir" "$raw_dir" + set +e + python -m pytest \ + --junitxml="$raw_dir/pytest.xml" \ + --cov=capybara \ + --cov-report=term-missing:skip-covered \ + --cov-report=xml:"$raw_dir/coverage.xml" \ + --cov-report=html:"$raw_dir/htmlcov" \ + --html="$raw_dir/pytest.html" \ + --self-contained-html \ + --durations=25 \ + tests 2>&1 | tee "$raw_dir/pytest.log" + status=${PIPESTATUS[0]} + set -e + if [[ -f .coverage ]]; then + python -m coverage report -m --skip-empty --skip-covered > "$raw_dir/coverage_report.txt" 
|| true
+          else
+            : > "$raw_dir/coverage_report.txt"
+          fi
+          for artifact in pytest.log pytest.xml coverage.xml pytest.html coverage_report.txt; do
+            if [[ -f "$raw_dir/$artifact" ]]; then
+              cp "$raw_dir/$artifact" "$final_dir/$artifact"
+            fi
+          done
+          if [[ -f "$raw_dir/pytest.log" ]]; then
+            cp "$raw_dir/pytest.log" "$final_dir/run.log"
+          fi
+          if [[ -d "$raw_dir/htmlcov" ]]; then
+            rm -rf "$final_dir/htmlcov"
+            cp -R "$raw_dir/htmlcov" "$final_dir/htmlcov"
+          fi
+          echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+          exit $status
+
+      - name: Generate Pytest summary
+        if: always()
+        env:
+          PYTEST_OUTCOME: ${{ steps.pytest.outcome }}
+          PYTEST_EXIT_CODE: ${{ steps.pytest.outputs.exit_code }}
+        run: |
+          python .github/scripts/generate_pytest_summary.py
+          summary_path="$REPORT_DIR/pytest/summary.md"
+          if [[ ! -s "$summary_path" ]]; then
+            echo "::error::Pytest summary was not generated"
+            exit 1
+          fi
+          {
+            echo "PYTEST_SUMMARY_BODY<<EOF"
+            cat "$summary_path"
+            echo "EOF"
+          } >> "$GITHUB_ENV"
+
+      - name: Coverage Gate (optional)
+        id: coverage_gate
+        if: always() && hashFiles('.ci-reports/pytest/raw/coverage.xml') != ''
+        continue-on-error: true
+        run: |
+          set -o pipefail
+          set +e
+          python .github/scripts/coverage_gate.py \
+            --file "${{ github.workspace }}/${{ env.PYTEST_REPORT_DIR }}/coverage.xml" \
+            --min-line "$COVERAGE_MIN_LINE" \
+            --min-branch "$COVERAGE_MIN_BRANCH" \
+            --enforce "$COVERAGE_ENFORCE" | tee "$REPORT_DIR/pytest/coverage_gate.log"
+          status=${PIPESTATUS[0]}
+          set -e
+          echo "exit_code=$status" >> "$GITHUB_OUTPUT"
+          exit $status
+
+      - name: Upload CI report artifact
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: ci-report-${{ github.run_id }}
+          if-no-files-found: warn
+          retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }}
+          compression-level: 6
+          path: .ci-reports/
+
+      - name: Publish CI summary (artifact + comment)
+        if: always()
+        uses: actions/github-script@v7
+        with:
+          script: |
+            const fs = require('fs');
+            const path = require('path');
+            const marker = '';
+            const baseDir = path.join(process.cwd(), '.ci-reports');
+            const pytestDir = path.join(baseDir, 'pytest');
+            const ruffDir = path.join(baseDir, 'ruff');
+            const artifactName = `ci-report-${context.runId}`;
+            fs.mkdirSync(baseDir, { recursive: true });
+
+            const fmtOutcome = (value) => {
+              const mapping = {
+                success: '✅ 通過',
+                failure: '❌ 失敗',
+                cancelled: '⚪ 取消',
+                skipped: '⚪ 未執行',
+              };
+              return mapping[value] || `⚪ ${value || '未知'}`;
+            };
+
+            const readIfExists = (file) => {
+              try {
+                return fs.readFileSync(file, 'utf8').trim();
+              } catch {
+                return null;
+              }
+            };
+
+            const readTail = (file) => {
+              try {
+                const raw = fs.readFileSync(file, 'utf8').trim();
+                if (!raw) return null;
+                const lines = raw.split(/\r?\n/);
+                const tail = lines.slice(-200).join('\n');
+                return tail.length > 6000 ?
tail.slice(-6000) : tail; + } catch { + return null; + } + }; + + const artifactPlaceholder = ''; + const coverageSummary = readIfExists(path.join(pytestDir, 'coverage.txt')) || 'N/A'; + const pytestSummaryRaw = readIfExists(path.join(pytestDir, 'summary.md')); + const pytestSummary = (() => { + if (!pytestSummaryRaw) return null; + const marker = ''; + if (pytestSummaryRaw.includes(marker)) { + return pytestSummaryRaw.split(marker).pop().trim(); + } + return pytestSummaryRaw; + })(); + const lintTail = readTail(path.join(ruffDir, 'lint.log')); + const formatTail = readTail(path.join(ruffDir, 'format.log')); + + const ruffLintOutcome = '${{ steps.ruff_lint.outcome }}'; + const ruffFormatOutcome = '${{ steps.ruff_format.outcome }}'; + const pytestOutcome = '${{ steps.pytest.outcome }}'; + const pytestExitCode = '${{ steps.pytest.outputs.exit_code || '' }}'; + const coverageGateOutcome = '${{ steps.coverage_gate.outcome || '' }}'; + const coverageGateExit = '${{ steps.coverage_gate.outputs.exit_code || '' }}'; + const coverageGateEnforce = '${{ env.COVERAGE_ENFORCE }}'; + + const coverageGateStatus = (() => { + if (!coverageGateOutcome) { + return '⚪ 未執行'; + } + const suffix = coverageGateExit && coverageGateExit !== '0' && coverageGateEnforce === '0' + ? '(未強制)' + : ''; + return `${fmtOutcome(coverageGateOutcome)}${suffix}`; + })(); + + const pytestStatus = (() => { + let status = fmtOutcome(pytestOutcome); + if (pytestExitCode) { + status += ` (exit=${pytestExitCode})`; + } + return status; + })(); + + const kpiTable = [ + '| 項目 | 狀態 |', + '| --- | --- |', + `| Ruff Lint | ${fmtOutcome(ruffLintOutcome)} |`, + `| Ruff Format | ${fmtOutcome(ruffFormatOutcome)} |`, + `| Pytest | ${pytestStatus} |`, + `| Coverage | ${coverageSummary} |`, + `| Coverage Gate | ${coverageGateStatus} |`, + ].join('\n'); + + let artifactLink = null; + try { + const { owner, repo } = context.repo; + const { data } = await github.rest.actions.listWorkflowRunArtifacts({ + owner, + repo, + run_id: context.runId, + per_page: 100, + }); + const target = (data.artifacts || []).find((item) => item.name === artifactName); + if (target) { + artifactLink = `[${target.name}](${target.archive_download_url})`; + } + } catch (error) { + core.warning(`無法取得 artifacts: ${error.message}`); + } + + const quickLinks = []; + const runSummaryUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${context.runId}`; + quickLinks.push(`- [Workflow 執行](${runSummaryUrl})`); + if (artifactLink) { + quickLinks.push(`- 測試 HTML:${artifactLink}(內含 \`pytest/pytest.html\`)`); + quickLinks.push(`- JUnit XML:${artifactLink}(內含 \`pytest/pytest.xml\`)`); + quickLinks.push(`- Coverage HTML:${artifactLink}(內含 \`pytest/htmlcov/index.html\`)`); + } else { + quickLinks.push('- 測試 HTML:不可用'); + quickLinks.push('- JUnit XML:不可用'); + quickLinks.push('- Coverage HTML:不可用'); + } + + const lines = [ + marker, + '## CI 概覽', + '', + kpiTable, + '', + '### 快速連結', + '', + ...quickLinks, + '', + ]; + + if (pytestSummary) { + const artifactText = artifactLink || '不可用'; + const processedSummary = pytestSummary.replace(new RegExp(artifactPlaceholder, 'g'), artifactText); + lines.push(processedSummary); + lines.push(''); + } else { + lines.push('_(Pytest 摘要不可用)_', ''); + } + + if (ruffLintOutcome === 'failure' && lintTail) { + lines.push('### Ruff Lint(最後 200 行)', '', '
<details><summary>展開</summary>', '', '```');
+              lines.push(lintTail);
+              lines.push('```', '</details>
', ''); + } + + if (ruffFormatOutcome === 'failure' && formatTail) { + lines.push('### Ruff Format(最後 200 行)', '', '
<details><summary>展開</summary>', '', '```');
+              lines.push(formatTail);
+              lines.push('```', '</details>
', ''); + } + + const body = lines.join('\n'); + const indexPath = path.join(baseDir, 'index.md'); + fs.writeFileSync(indexPath, `${body}\n`, { encoding: 'utf8' }); + + if (context.eventName !== 'pull_request') { + return; + } + + const { owner, repo } = context.repo; + const issue_number = context.payload.pull_request.number; + const { data: comments } = await github.rest.issues.listComments({ + owner, + repo, + issue_number, + per_page: 100, + }); + + const existing = comments.find((comment) => comment.body && comment.body.includes(marker)); + if (existing) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existing.id, + body, + }); + } else { + await github.rest.issues.createComment({ + owner, + repo, + issue_number, + body, + }); + } + + - name: Append CI summary to Job Summary + if: always() + run: | + if [[ -f "$REPORT_DIR/index.md" ]]; then + cat "$REPORT_DIR/index.md" >> "$GITHUB_STEP_SUMMARY" + else + echo "## CI 概覽" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + echo "_(summary 未產生)_" >> "$GITHUB_STEP_SUMMARY" + fi + + - name: Finalize CI status + if: always() + run: | + failed=0 + ruff_lint_exit="${{ steps.ruff_lint.outputs.exit_code || '0' }}" + ruff_format_exit="${{ steps.ruff_format.outputs.exit_code || '0' }}" + pyright_exit="${{ steps.pyright.outputs.exit_code || '0' }}" + pytest_exit="${{ steps.pytest.outputs.exit_code || '0' }}" + coverage_gate_exit="${{ steps.coverage_gate.outputs.exit_code || '0' }}" + + if [[ "$ruff_lint_exit" -ne 0 ]]; then + echo "::error::Ruff Lint failed" + failed=1 + fi + if [[ "$ruff_format_exit" -ne 0 ]]; then + echo "::error::Ruff Format failed" + failed=1 + fi + if [[ "$pyright_exit" -ne 0 ]]; then + echo "::error::Pyright failed" + failed=1 + fi + if [[ "$pytest_exit" -ne 0 ]]; then + echo "::error::Pytest failed" + failed=1 + fi + if [[ "$coverage_gate_exit" -ne 0 ]]; then + if [[ "$COVERAGE_ENFORCE" != "0" ]]; then + echo "::error::Coverage gate failed" + failed=1 + else + echo "::warning::Coverage gate reported failure but enforcement is disabled (COVERAGE_ENFORCE=$COVERAGE_ENFORCE)." + fi + fi + + if [[ $failed -ne 0 ]]; then + exit 1 + fi diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/deprecated/cpu-ci.yml similarity index 100% rename from .github/workflows/cpu-ci.yml rename to .github/workflows/deprecated/cpu-ci.yml diff --git a/.github/workflows/cuda-ci.yml b/.github/workflows/deprecated/cuda-ci.yml similarity index 100% rename from .github/workflows/cuda-ci.yml rename to .github/workflows/deprecated/cuda-ci.yml diff --git a/.gitignore b/.gitignore index 79f632a..b651d68 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,8 @@ temp_image.jpg .DS_Store .python-version tmp.jpg +tmp* +.venv_min +.vscode +.ruff_cache +.benchmarks \ No newline at end of file diff --git a/README.md b/README.md index ceb36d1..fc7ef05 100644 --- a/README.md +++ b/README.md @@ -1,170 +1,472 @@ -[**English**](./README.md) | [中文](./README_tw.md) +**[English](./README.md)** | [Chinese](./README_tw.md) # Capybara -**An Integrated Python Package for Image Processing and Deep Learning.** -

- - + +

-## Introduction - ![title](https://raw.githubusercontent.com/DocsaidLab/Capybara/refs/heads/main/docs/title.webp) -This project is an image processing and deep learning toolkit, mainly consisting of the following parts: +--- -- **Vision**: Provides functionalities related to computer vision, such as image and video processing. -- **Structures**: Modules for handling structured data, such as BoundingBox and Polygon. -- **ONNXEngine**: Provides ONNX inference functionalities, supporting ONNX format models. -- **Utils**: Contains utility functions that do not belong to other modules. -- **Tests**: Includes test code for various functions to verify their correctness. +## Introduction -## Technical Documentation +Capybara is designed with three goals: -For more detailed information on installation and usage, please refer to the [**Capybara Documents**](https://docsaid.org/en/docs/capybara). +1. **Lightweight default install**: `pip install capybara-docsaid` installs only the core `utils/structures/vision` modules, without forcing heavy inference dependencies. +2. **Inference backends as opt-in extras**: install ONNX Runtime / OpenVINO / TorchScript only when you need them via extras. +3. **Lower risk**: enforce quality gates with ruff/pyright/pytest and target **90%** line coverage for the core codebase. -The document provides a detailed explanation of this project and answers to frequently asked questions. +What you get: -## Prerequisites +- **Image tools** (`capybara.vision`): I/O, color conversion, resize/rotate/pad/crop, and video frame extraction. +- **Geometry structures** (`capybara.structures`): `Box/Boxes`, `Polygon/Polygons`, `Keypoints`, plus helper functions like IoU. +- **Inference wrappers (optional)**: `capybara.onnxengine` / `capybara.openvinoengine` / `capybara.torchengine`. +- **Feature extras (optional)**: `visualization` (drawing tools), `ipcam` (simple web demo), `system` (system info tools). +- **Utilities** (`capybara.utils`): `PowerDict`, `Timer`, `make_batch`, `download_from_google`, and other common helpers. -Before the installation of Capybara, ensure that your system meets the following requirements: +## Quick Start -### Python Version +### Install and verify -3.10+ +```bash +pip install capybara-docsaid +python -c "import capybara; print(capybara.__version__)" +``` -### Dependency Packages +## Documentation -Please install the necessary system packages according to your operating system: +To learn more about installation and usage, see [**Capybara Documents**](https://docsaid.org/docs/capybara). -#### Ubuntu +The documentation includes detailed guides and common FAQs for this project. + +## Installation + +### Core install (lightweight) ```bash -sudo apt install libturbojpeg exiftool ffmpeg libheif-dev poppler-utils +pip install capybara-docsaid ``` -##### GPU Dependencies +### Enable inference backends (optional) + +```bash +# ONNX Runtime (CPU) +pip install "capybara-docsaid[onnxruntime]" + +# ONNX Runtime (GPU) +pip install "capybara-docsaid[onnxruntime-gpu]" + +# OpenVINO runtime +pip install "capybara-docsaid[openvino]" -To use ONNX Runtime with GPU acceleration, ensure that you install a compatible version, which can be found on the official ONNX Runtime CUDA Execution Provider requirements page. 
+# TorchScript runtime +pip install "capybara-docsaid[torchscript]" + +# Install everything +pip install "capybara-docsaid[all]" +``` -Here's an example to install cuda-12.8: +### Feature extras (optional) ```bash -wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb -sudo dpkg -i cuda-keyring_1.1-1_all.deb -sudo apt-get update -sudo apt-get -y install cuda-toolkit-12-8 -# Post installation, add cuda path to .bashrc or .zshrc -export shellrc="~/.zshrc" -echo 'export PATH=/usr/local/cuda-12.8/bin${PATH:+:${PATH}}' >> $shellrc -echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}' >> $shellrc +# Visualization (matplotlib/pillow) +pip install "capybara-docsaid[visualization]" + +# IPCam app (flask) +pip install "capybara-docsaid[ipcam]" + +# System info (psutil) +pip install "capybara-docsaid[system]" ``` -For more details, please see [Nvidia CUDA](https://developer.nvidia.com/cuda-toolkit). +### Combine multiple extras -#### MacOS +If you want OpenVINO inference and the IPCam features, install: ```bash -brew install jpeg-turbo exiftool ffmpeg libheif poppler +# OpenVINO + IPCam +pip install "capybara-docsaid[openvino,ipcam]" ``` -## Installation +### Install from Git + +```bash +pip install git+https://github.com/DocsaidLab/Capybara.git +``` + +## System Dependencies (Install as needed) -### PyPI +Some features require OS-level codecs / image I/O / PDF tools (install as needed): + +- `PyTurboJPEG` (faster JPEG I/O): requires the TurboJPEG library. +- `pillow-heif` (HEIC/HEIF support): requires libheif. +- `pdf2image` (PDF to images): requires Poppler. +- Video frame extraction: installing `ffmpeg` is recommended (more stable OpenCV video decoding). + +### Ubuntu ```bash -pip install capybara-docsaid +sudo apt install ffmpeg libturbojpeg libheif-dev poppler-utils ``` -### Git +### macOS ```bash -pip install git+https://github.com/DocsaidLab/Capybara.git +brew install jpeg-turbo ffmpeg libheif poppler +``` + +### GPU Notes (ONNX Runtime CUDA) + +If you're using `onnxruntime-gpu`, install the compatible CUDA/cuDNN version for your ORT version: + +- See [**the ONNX Runtime documentation**](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements) + +## Usage + +### Image data conventions + +- Capybara images are represented as `numpy.ndarray`. By default, they follow OpenCV conventions: **BGR**, and shape is typically `(H, W, 3)`. +- If you prefer working in RGB, use `imread(..., color_base="RGB")` or convert with `imcvtcolor(img, "BGR2RGB")`. + +### Image I/O + +```python +from capybara import imread, imwrite + +img = imread("your_image.jpg") +if img is None: + raise RuntimeError("Failed to read image.") + +imwrite(img, "out.jpg") +``` + +Notes: + +- `imread` returns `None` when it fails to decode an image (if the path doesn't exist, it raises `FileExistsError`). +- `imread` also supports `.heic` (requires `pillow-heif` + OS-level libheif). + +### Resize / pad + +With `imresize`, you can pass `None` in `size` to keep the aspect ratio and have the other dimension inferred automatically. 
+ +```python +import numpy as np +from capybara import BORDER, imresize, pad + +img = np.zeros((480, 640, 3), dtype=np.uint8) +img = imresize(img, (320, None)) # (height, width) +img = pad(img, pad_size=(8, 8), pad_mode=BORDER.REPLICATE) +``` + +### Color conversion + +```python +import numpy as np +from capybara import imcvtcolor + +img = np.zeros((240, 320, 3), dtype=np.uint8) # BGR +gray = imcvtcolor(img, "BGR2GRAY") # grayscale +rgb = imcvtcolor(img, "BGR2RGB") # RGB +``` + +### Rotation / perspective correction + +```python +import numpy as np +from capybara import Polygon, imrotate, imwarp_quadrangle + +img = np.zeros((240, 320, 3), dtype=np.uint8) +rot = imrotate(img, angle=15, expand=True) # Angle definition matches OpenCV: positive values rotate counterclockwise + +poly = Polygon([[10, 10], [200, 20], [190, 120], [20, 110]]) +patch = imwarp_quadrangle(img, poly) # 4-point perspective warp +``` + +### Cropping (Box / Boxes) + +```python +import numpy as np +from capybara import Box, Boxes, imcropbox, imcropboxes + +img = np.zeros((240, 320, 3), dtype=np.uint8) +crop1 = imcropbox(img, Box([10, 20, 110, 120]), use_pad=True) +crop_list = imcropboxes( + img, + Boxes([[0, 0, 10, 10], [100, 100, 400, 300]]), + use_pad=True, +) +``` + +### Binarization + morphology + +Morphology operators live in `capybara.vision.morphology` (not in the top-level `capybara` namespace). + +```python +import numpy as np +from capybara import imbinarize +from capybara.vision.morphology import imopen + +img = np.zeros((240, 320, 3), dtype=np.uint8) +mask = imbinarize(img) # OTSU + binary +mask = imopen(mask, ksize=3) # Opening to remove small noise +``` + +### Boxes / IoU + +```python +import numpy as np +from capybara import Box, Boxes, pairwise_iou + +boxes_a = Boxes([[10, 10, 20, 20], [30, 30, 60, 60]]) +boxes_b = Boxes(np.array([[12, 12, 18, 18]], dtype=np.float32)) +print(pairwise_iou(boxes_a, boxes_b)) + +box = Box([0.1, 0.2, 0.9, 0.8], is_normalized=True).convert("XYWH") +print(box.numpy()) +``` + +### Polygons / IoU + +```python +from capybara import Polygon, polygon_iou + +p1 = Polygon([[0, 0], [10, 0], [10, 10], [0, 10]]) +p2 = Polygon([[5, 5], [15, 5], [15, 15], [5, 15]]) +print(polygon_iou(p1, p2)) +``` + +### Base64 (image / ndarray) + +```python +import numpy as np +from capybara import img_to_b64str, npy_to_b64str +from capybara.vision.improc import b64str_to_img, b64str_to_npy + +img = np.zeros((32, 32, 3), dtype=np.uint8) +b64_img = img_to_b64str(img) # JPEG bytes -> base64 string +if b64_img is None: + raise RuntimeError("Failed to encode image into base64.") +img2 = b64str_to_img(b64_img) # base64 string -> numpy image + +vec = np.arange(8, dtype=np.float32) +b64_vec = npy_to_b64str(vec) +vec2 = b64str_to_npy(b64_vec, dtype="float32") +``` + +### PDF to images + +```python +from capybara.vision.improc import pdf2imgs + +pages = pdf2imgs("file.pdf") # list[np.ndarray], each page is BGR image +if pages is None: + raise RuntimeError("Failed to decode PDF.") +print(len(pages)) +``` + +### Visualization (optional) + +Install first: `pip install "capybara-docsaid[visualization]"`. + +```python +import numpy as np +from capybara import Box +from capybara.vision.visualization.draw import draw_box + +img = np.zeros((240, 320, 3), dtype=np.uint8) +img = draw_box(img, Box([10, 20, 100, 120])) +``` + +### IPCam (optional) + +`IpcamCapture` itself does not depend on Flask; you only need the `ipcam` extra to use `WebDemo`. 
+ +```python +from capybara.vision.ipcam.camera import IpcamCapture + +cap = IpcamCapture(url=0, color_base="BGR") # or provide an RTSP/HTTP URL +frame = next(cap) +``` + +Web demo (install first: `pip install "capybara-docsaid[ipcam]"`): + +```python +from capybara.vision.ipcam.app import WebDemo + +WebDemo("rtsp://").run(port=5001) +``` + +### System info (optional) + +Install first: `pip install "capybara-docsaid[system]"`. + +```python +from capybara.utils.system_info import get_system_info + +print(get_system_info()) ``` -## Docker for Deployment +### Video frame extraction -We provide a Docker script for convenient deployment, ensuring a consistent environment. Below are the steps to build the image with Capybara installed. +```python +from capybara import video2frames_v2 -1. Clone this repository: +frames = video2frames_v2("demo.mp4", frame_per_sec=2, max_size=1280) +print(len(frames)) +``` - ```bash - git clone https://github.com/DocsaidLab/Capybara.git - ``` +## Inference Backends -2. Enter the project directory and run the build script: +Inference backends are optional; install the corresponding extras before importing the relevant engine modules. - ```bash - cd Capybara - bash docker/build.bash - ``` +### Runtime / backend matrix - This will build an image using the [**Dockerfile**](docker/Dockerfile) in the project. The image is based on `nvidia/cuda:12.8.1-cudnn-runtime-ubuntu24.04` by default, providing the CUDA environment required for ONNXRuntime inference. +Note: TorchScript runtime is named `Runtime.pt` in code (corresponding extra: `torchscript`). -3. After the build is complete, mount the working directory and run the program: +| Runtime (`capybara.runtime.Runtime`) | Backend name | Provider / device | +| ------------------------------------ | --------------- | ----------------------------------------------------------------------------------------------------------- | +| `onnx` | `cpu` | `["CPUExecutionProvider"]` | +| `onnx` | `cuda` | `["CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` | +| `onnx` | `tensorrt` | `["TensorrtExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` | +| `onnx` | `tensorrt_rtx` | `["NvTensorRTRTXExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]` | +| `openvino` | `cpu` | `device="CPU"` | +| `openvino` | `gpu` | `device="GPU"` | +| `openvino` | `npu` | `device="NPU"` | +| `pt` | `cpu` | `torch.device("cpu")` | +| `pt` | `cuda` | `torch.device("cuda")` | - ```bash - docker run --gpus all -it --rm capybara_docsaid:latest bash - ``` +### Runtime registry (auto backend selection) -**PS: If you want to compile cuda or cudnn for developing, please change the base image to `nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04`.** +```python +from capybara.runtime import Runtime -### gosu Permissions Issues +print(Runtime.onnx.auto_backend_name()) # Priority: cuda -> tensorrt_rtx -> tensorrt -> cpu +print(Runtime.openvino.auto_backend_name()) # Priority: gpu -> npu -> cpu +print(Runtime.pt.auto_backend_name()) # Priority: cuda -> cpu +``` -If you encounter issues with file ownership as root when running scripts inside the container, causing permission problems, you can use `gosu` to switch users in the Dockerfile. Specify `USER_ID` and `GROUP_ID` when starting the container to avoid frequent permission adjustments in collaborative development. 
+### ONNX Runtime (`capybara.onnxengine`) -For details, refer to the technical documentation: [**Integrating gosu Configuration**](https://docsaid.org/en/docs/capybara/advance/#integrating-gosu-configuration) +```python +import numpy as np +from capybara.onnxengine import EngineConfig, ONNXEngine -1. Install `gosu`: +engine = ONNXEngine( + "model.onnx", + backend="cpu", + config=EngineConfig(enable_io_binding=False), +) +outputs = engine.run({"input": np.ones((1, 3, 224, 224), dtype=np.float32)}) +print(outputs.keys()) +print(engine.summary()) +``` + +### OpenVINO (`capybara.openvinoengine`) - ```dockerfile - RUN apt-get update && apt-get install -y gosu - ``` +```python +import numpy as np +from capybara.openvinoengine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine + +engine = OpenVINOEngine( + "model.xml", + device=OpenVINODevice.cpu, + config=OpenVINOConfig(num_requests=2), +) +outputs = engine.run({"input": np.ones((1, 3), dtype=np.float32)}) +print(outputs.keys()) +``` -2. Use `gosu` in the container start command to switch to a non-root user for file read/write operations. +### TorchScript (`capybara.torchengine`) + +```python +import numpy as np +from capybara.torchengine import TorchEngine + +engine = TorchEngine("model.pt", device="cpu") +outputs = engine.run({"image": np.zeros((1, 3, 224, 224), dtype=np.float32)}) +print(outputs.keys()) +``` + +### Benchmark (depends on hardware) + +All engines provide `benchmark(...)` for quick throughput/latency measurements. + +```python +import numpy as np +from capybara.onnxengine import ONNXEngine + +engine = ONNXEngine("model.onnx", backend="cpu") +dummy = np.zeros((1, 3, 224, 224), dtype=np.float32) +print(engine.benchmark({"input": dummy}, repeat=50, warmup=5)) +``` - ```dockerfile - # Create the entrypoint script - RUN printf '#!/bin/bash\n\ - if [ ! -z "$USER_ID" ] && [ ! -z "$GROUP_ID" ]; then\n\ - groupadd -g "$GROUP_ID" -o usergroup\n\ - useradd --shell /bin/bash -u "$USER_ID" -g "$GROUP_ID" -o -c "" -m user\n\ - export HOME=/home/user\n\ - chown -R "$USER_ID":"$GROUP_ID" /home/user\n\ - chown -R "$USER_ID":"$GROUP_ID" /code\n\ - fi\n\ - \n\ - # Check for parameters\n\ - if [ $# -gt 0 ]; then\n\ - exec gosu ${USER_ID:-0}:${GROUP_ID:-0} python "$@"\n\ - else\n\ - exec gosu ${USER_ID:-0}:${GROUP_ID:-0} bash\n\ - fi' > "$ENTRYPOINT_SCRIPT" +### Advanced: Custom options (optional) - RUN chmod +x "$ENTRYPOINT_SCRIPT" +`EngineConfig` / `OpenVINOConfig` / `TorchEngineConfig` are passed through to the underlying runtime as-is. - ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] - ``` +```python +from capybara.onnxengine import EngineConfig, ONNXEngine -For more advanced configuration, refer to [**NVIDIA Container Toolkit**](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and the official [**docker**](https://docs.docker.com/) documentation. +engine = ONNXEngine( + "model.onnx", + backend="cuda", + config=EngineConfig( + provider_options={ + "CUDAExecutionProvider": { + "enable_cuda_graph": True, + }, + }, + ), +) +``` + +## Quality Gates (Contributors) + +Before merging, this project requires: + +```bash +ruff check . +ruff format --check . +pyright +python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=term +``` + +Notes: + +- Coverage gate is **90% line coverage** (rules defined in `.coveragerc`). +- Heavy / environment-dependent modules are excluded from the default coverage gate to keep CI reproducible and maintainable. 
+ +## Docker (optional) + +```bash +git clone https://github.com/DocsaidLab/Capybara.git +cd Capybara +bash docker/build.bash +``` + +Run: + +```bash +docker run --rm -it capybara_docsaid bash +``` -## Testing +If you need GPU access inside the container, use the NVIDIA container runtime (e.g. `--gpus all`). -This project uses `pytest` for unit testing, and users can run the tests themselves to verify the correctness of the functionalities. To install and run the tests, use the following commands: +## Testing (local) ```bash -pip install pytest -python -m pytest -vv tests +python -m pytest -vv ``` -Once completed, you can check if all modules are functioning properly. If any issues arise, first check the environment settings and package versions. +## License -If the problem persists, please report it in the Issue section. +Apache-2.0, see `LICENSE`. ## Citation @@ -174,7 +476,7 @@ If the problem persists, please report it in the Issue section. title = {Capybara: An Integrated Python Package for Image Processing and Deep Learning.}, year = {2025}, publisher = {GitHub}, - howpublished = {\url{https://github.com/DocsaidLab/Capybara}}, + howpublished = {\\url{https://github.com/DocsaidLab/Capybara}}, note = {* equal contribution} } ``` diff --git a/README_tw.md b/README_tw.md index c5cf4cd..9e56622 100644 --- a/README_tw.md +++ b/README_tw.md @@ -6,21 +6,38 @@ - - + +

+![title](https://raw.githubusercontent.com/DocsaidLab/Capybara/refs/heads/main/docs/title.webp) + +--- + ## 介紹 -![title](https://raw.githubusercontent.com/DocsaidLab/Capybara/refs/heads/main/docs/title.webp) +Capybara 的設計目標聚焦三個方向: -本專案是一個影像處理與深度學習的工具箱,主要包括以下幾個部分: +1. **預設安裝輕量化**:`pip install capybara-docsaid` 僅安裝核心 utils/structures/vision,不強迫安裝重型推論依賴。 +2. **推論後端改為 opt-in extras**:需要 ONNX Runtime / OpenVINO / TorchScript 時再用 extras 安裝。 +3. **降低風險**:導入 ruff/pyright/pytest 品質門檻,並以核心程式碼 **90%** 行覆蓋率為維護目標。 -- **Vision**:提供與電腦視覺相關的功能,例如圖像和影片處理。 -- **Structures**:用於處理結構化數據的模組,例如 BoundingBox 和 Polygon。 -- **ONNXEngine**:提供 ONNX 推理功能,支援 ONNX 格式的模型。 -- **Utils**:放置無法歸類到其他模組的工具函式。 -- **Tests**:包含各類功能的測試程式碼,用於驗證函式的正確性。 +你會得到: + +- **影像工具**(`capybara.vision`):讀寫、轉色、縮放/旋轉/補邊/裁切,以及影片抽幀工具。 +- **幾何結構**(`capybara.structures`):`Box/Boxes`、`Polygon/Polygons`、`Keypoints`,以及 IoU 等輔助函數。 +- **推論封裝(可選)**:`capybara.onnxengine` / `capybara.openvinoengine` / `capybara.torchengine`。 +- **功能 extras(可選)**:`visualization`(繪圖工具)、`ipcam`(簡易 Web demo)、`system`(系統資訊工具)。 +- **小工具**(`capybara.utils`):`PowerDict`、`Timer`、`make_batch`、`download_from_google` 等常用 helper。 + +## 快速開始 + +### 安裝與驗證 + +```bash +pip install capybara-docsaid +python -c "import capybara; print(capybara.__version__)" +``` ## 技術文件 @@ -30,185 +47,436 @@ ## 安裝 -在開始安裝 Capybara 之前,請先確保系統符合以下需求: +### 核心安裝(輕量) -### Python 版本 +```bash +pip install capybara-docsaid +``` -- 需要 Python 3.10 或以上版本。 +### 啟用推論後端(可選) -### 依賴套件 +```bash +# ONNXRuntime(CPU) +pip install "capybara-docsaid[onnxruntime]" -請依照作業系統,安裝下列必要的系統套件: +# ONNXRuntime(GPU) +pip install "capybara-docsaid[onnxruntime-gpu]" -- **Ubuntu** +# OpenVINO runtime +pip install "capybara-docsaid[openvino]" - ```bash - sudo apt install libturbojpeg exiftool ffmpeg libheif-dev - ``` +# TorchScript runtime +pip install "capybara-docsaid[torchscript]" -- **MacOS** +# 全部一起裝 +pip install "capybara-docsaid[all]" +``` - ```bash - brew install jpeg-turbo exiftool ffmpeg - ``` +### 選用功能 extras(可選) - - **特別注意**:經過測試,在 macOS 上使用 libheif 時,存在一些已知問題,主要包括: +```bash +# 視覺化(matplotlib/pillow) +pip install "capybara-docsaid[visualization]" - 1. **生成的 HEIC 檔案無法打開**:在 macOS 上,libheif 生成的 HEIC 檔案可能無法被某些程式打開。這可能與圖像尺寸有關,特別是當圖像的寬度或高度為奇數時,可能會導致相容性問題。 +# IPCam app(flask) +pip install "capybara-docsaid[ipcam]" - 2. **編譯錯誤**:在 macOS 上編譯 libheif 時,可能會遇到與 ffmpeg 解碼器相關的未定義符號錯誤。這可能是由於編譯選項或相依性設定不正確所致。 +# 系統資訊(psutil) +pip install "capybara-docsaid[system]" +``` - 3. 
**範例程式無法執行**:在 macOS Sonoma 上,libheif 的範例程式可能無法正常運行,出現動態鏈接錯誤,提示找不到 `libheif.1.dylib`,這可能與動態庫的路徑設定有關。 +### 挑選多個功能 - 由於問題不少,因此我們目前只在 Ubuntu 才會運行 libheif,至於 macOS 的部分則留給未來的版本。 +假設你想使用 openvino 推論,並搭配 ipcam 相關的功能,可以這樣安裝: -### pdf2image 依賴套件 +```bash +# 選用 OpenVINO 和 IPCam +pip install "capybara-docsaid[openvino,ipcam]" +``` -pdf2image 是用於將 PDF 文件轉換成影像的 Python 模組,請確保系統已安裝下列工具: +### 從 Git 安裝 -- MacOS:需要安裝 poppler +```bash +pip install git+https://github.com/DocsaidLab/Capybara.git +``` + +## 系統相依套件(依功能需求安裝) - ```bash - brew install poppler - ``` +有些功能需要 OS 層級的 codec / image IO / PDF 工具(依功能需求安裝): -- Linux:大多數發行版已內建 `pdftoppm` 與 `pdftocairo`。如未安裝,請執行: +- `PyTurboJPEG`(JPEG 讀寫加速):需要 TurboJPEG library。 +- `pillow-heif`(HEIC/HEIF 支援):需要 libheif。 +- `pdf2image`(PDF 轉圖):需要 Poppler。 +- 影片抽幀:建議安裝 `ffmpeg`(讓 OpenCV 影片讀取更穩定)。 - ```bash - sudo apt install poppler-utils - ``` +### Ubuntu -### ONNXRuntime GPU 依賴 +```bash +sudo apt install ffmpeg libturbojpeg libheif-dev poppler-utils +``` -若需使用 ONNXRuntime 進行 GPU 加速推理,請確保已安裝相容版本的 CUDA,如下示範: +### macOS ```bash -sudo apt install cuda-12-4 -# 假設要加入至 .bashrc -echo 'export PATH=/usr/local/cuda-12.4/bin${PATH:+:${PATH}}' >> ~/.bashrc -echo 'export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}' >> ~/.bashrc +brew install jpeg-turbo ffmpeg libheif poppler +``` + +### GPU 注意事項(ONNXRuntime CUDA) + +若使用 `onnxruntime-gpu`,請依 ORT 的版本安裝相容的 CUDA/cuDNN: + +- 請參考 [**onnxruntime 官方網站**](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements) + +## 使用方式 + +### 影像資料格式約定 + +- Capybara 的影像以 `numpy.ndarray` 表示,預設遵循 OpenCV 慣例:**BGR**、shape 通常為 `(H, W, 3)`。 +- 若你希望以 RGB 工作,可用 `imread(..., color_base="RGB")` 或 `imcvtcolor(img, "BGR2RGB")` 轉換。 + +### 影像 I/O + +```python +from capybara import imread, imwrite + +img = imread("your_image.jpg") +if img is None: + raise RuntimeError("Failed to read image.") + +imwrite(img, "out.jpg") +``` + +補充: + +- `imread` 讀不到圖時會回傳 `None`(路徑不存在則會直接丟 `FileExistsError`)。 +- `imread` 也支援 `.heic`(需 `pillow-heif` + OS 層級 libheif)。 + +### Resize / pad + +`imresize` 支援在 `size` 中用 `None` 表示「維持長寬比自動推算另一邊」。 + +```python +import numpy as np +from capybara import BORDER, imresize, pad + +img = np.zeros((480, 640, 3), dtype=np.uint8) +img = imresize(img, (320, None)) # (height, width) +img = pad(img, pad_size=(8, 8), pad_mode=BORDER.REPLICATE) +``` + +### 轉色(Color Conversion) + +```python +import numpy as np +from capybara import imcvtcolor + +img = np.zeros((240, 320, 3), dtype=np.uint8) # BGR +gray = imcvtcolor(img, "BGR2GRAY") # grayscale +rgb = imcvtcolor(img, "BGR2RGB") # RGB +``` + +### 旋轉 / 透視校正 + +```python +import numpy as np +from capybara import Polygon, imrotate, imwarp_quadrangle + +img = np.zeros((240, 320, 3), dtype=np.uint8) +rot = imrotate(img, angle=15, expand=True) # 角度定義與 OpenCV 相同:正值為逆時針 + +poly = Polygon([[10, 10], [200, 20], [190, 120], [20, 110]]) +patch = imwarp_quadrangle(img, poly) # 4 點透視校正 +``` + +### 裁切(Box / Boxes) + +```python +import numpy as np +from capybara import Box, Boxes, imcropbox, imcropboxes + +img = np.zeros((240, 320, 3), dtype=np.uint8) +crop1 = imcropbox(img, Box([10, 20, 110, 120]), use_pad=True) +crop_list = imcropboxes( + img, + Boxes([[0, 0, 10, 10], [100, 100, 400, 300]]), + use_pad=True, +) +``` + +### 二值化 + 形態學(Morphology) + +形態學操作位於 `capybara.vision.morphology`(不在頂層 `capybara` namespace)。 + +```python +import numpy as np +from capybara import imbinarize +from capybara.vision.morphology import imopen + +img = np.zeros((240, 320, 3), 
+### Image I/O
+
+```python
+from capybara import imread, imwrite
+
+img = imread("your_image.jpg")
+if img is None:
+    raise RuntimeError("Failed to read image.")
+
+imwrite(img, "out.jpg")
+```
+
+Notes:
+
+- `imread` returns `None` when it cannot decode an image (a nonexistent path raises `FileExistsError` directly).
+- `imread` also supports `.heic` (requires `pillow-heif` plus the OS-level libheif).
+
+### Resize / pad
+
+`imresize` accepts `None` for one entry of `size`, meaning "keep the aspect ratio and infer the other side automatically".
+
+```python
+import numpy as np
+from capybara import BORDER, imresize, pad
+
+img = np.zeros((480, 640, 3), dtype=np.uint8)
+img = imresize(img, (320, None))  # (height, width)
+img = pad(img, pad_size=(8, 8), pad_mode=BORDER.REPLICATE)
+```
+
+### Color conversion
+
+```python
+import numpy as np
+from capybara import imcvtcolor
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)  # BGR
+gray = imcvtcolor(img, "BGR2GRAY")  # grayscale
+rgb = imcvtcolor(img, "BGR2RGB")  # RGB
+```
+
+### Rotation / perspective rectification
+
+```python
+import numpy as np
+from capybara import Polygon, imrotate, imwarp_quadrangle
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+rot = imrotate(img, angle=15, expand=True)  # same angle convention as OpenCV: positive is counter-clockwise
+
+poly = Polygon([[10, 10], [200, 20], [190, 120], [20, 110]])
+patch = imwarp_quadrangle(img, poly)  # 4-point perspective rectification
+```
+
+### Cropping (Box / Boxes)
+
+```python
+import numpy as np
+from capybara import Box, Boxes, imcropbox, imcropboxes
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+crop1 = imcropbox(img, Box([10, 20, 110, 120]), use_pad=True)
+crop_list = imcropboxes(
+    img,
+    Boxes([[0, 0, 10, 10], [100, 100, 400, 300]]),
+    use_pad=True,
+)
+```
+
+### Binarization + morphology
+
+Morphology operations live in `capybara.vision.morphology` (not in the top-level `capybara` namespace).
+
+```python
+import numpy as np
+from capybara import imbinarize
+from capybara.vision.morphology import imopen
+
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+mask = imbinarize(img)  # OTSU + binary
+mask = imopen(mask, ksize=3)  # opening removes small specks
+```
 
-### Install via PyPI
+### Boxes / IoU
 
-1. Install the package from PyPI:
+```python
+import numpy as np
+from capybara import Box, Boxes, pairwise_iou
 
-   ```bash
-   pip install capybara-docsaid
-   ```
+boxes_a = Boxes([[10, 10, 20, 20], [30, 30, 60, 60]])
+boxes_b = Boxes(np.array([[12, 12, 18, 18]], dtype=np.float32))
+print(pairwise_iou(boxes_a, boxes_b))
 
-2. Verify the installation:
+box = Box([0.1, 0.2, 0.9, 0.8], is_normalized=True).convert("XYWH")
+print(box.numpy())
+```
 
-   ```bash
-   python -c "import capybara; print(capybara.__version__)"
-   ```
+### Polygons / IoU
 
-3. If a version number is printed, the installation succeeded.
+```python
+from capybara import Polygon, polygon_iou
 
-### Install via git clone
+p1 = Polygon([[0, 0], [10, 0], [10, 10], [0, 10]])
+p2 = Polygon([[5, 5], [15, 5], [15, 15], [5, 15]])
+print(polygon_iou(p1, p2))
+```
 
-1. Clone this repository:
+### Base64 (image / ndarray)
 
-   ```bash
-   git clone https://github.com/DocsaidLab/Capybara.git
-   ```
+```python
+import numpy as np
+from capybara import img_to_b64str, npy_to_b64str
+from capybara.vision.improc import b64str_to_img, b64str_to_npy
 
-2. Install the wheel package:
+img = np.zeros((32, 32, 3), dtype=np.uint8)
+b64_img = img_to_b64str(img)  # JPEG bytes -> base64 string
+if b64_img is None:
+    raise RuntimeError("Failed to encode image into base64.")
+img2 = b64str_to_img(b64_img)  # base64 string -> numpy image
 
-   ```bash
-   pip install wheel
-   ```
+vec = np.arange(8, dtype=np.float32)
+b64_vec = npy_to_b64str(vec)
+vec2 = b64str_to_npy(b64_vec, dtype="float32")
+```
 
-3. Build the wheel file:
+### PDF to images
 
-   ```bash
-   cd Capybara
-   python setup.py bdist_wheel
-   ```
+```python
+from capybara.vision.improc import pdf2imgs
 
-4. Install the built wheel:
+pages = pdf2imgs("file.pdf")  # list[np.ndarray], each page is a BGR image
+if pages is None:
+    raise RuntimeError("Failed to decode PDF.")
+print(len(pages))
+```
 
-   ```bash
-   pip install dist/capybara_docsaid-*-py3-none-any.whl
-   ```
+### Visualization (optional)
 
-### Install via Docker (recommended)
+Install first: `pip install "capybara-docsaid[visualization]"`.
 
-To avoid environment conflicts during deployment or collaboration, Docker is recommended. A brief walkthrough:
+```python
+import numpy as np
+from capybara import Box
+from capybara.vision.visualization.draw import draw_box
 
-1. Clone this repository:
+img = np.zeros((240, 320, 3), dtype=np.uint8)
+img = draw_box(img, Box([10, 20, 100, 120]))
+```
 
-   ```bash
-   git clone https://github.com/DocsaidLab/Capybara.git
-   ```
+### IPCam (optional)
 
-2. Enter the project folder and run the build script:
+`IpcamCapture` itself does not depend on Flask; the `ipcam` extra is only needed for `WebDemo`.
 
-   ```bash
-   cd Capybara
-   bash docker/build.bash
-   ```
+```python
+from capybara.vision.ipcam.camera import IpcamCapture
 
-   This builds the image from the project's [**Dockerfile**](https://github.com/DocsaidLab/Capybara/blob/main/docker/Dockerfile); the image is based on `nvcr.io/nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04` by default, providing the CUDA environment needed for ONNXRuntime inference.
+cap = IpcamCapture(url=0, color_base="BGR")  # or pass an RTSP/HTTP URL
+frame = next(cap)
+```
 
-3. Once the build finishes, mount your working directory and run your script:
+Web demo (install first: `pip install "capybara-docsaid[ipcam]"`):
 
-   ```bash
-   docker run -v ${PWD}:/code -it capybara_infer_image your_scripts.py
-   ```
+```python
+from capybara.vision.ipcam.app import WebDemo
 
-   For GPU acceleration, add `--gpus all` when running.
+WebDemo("rtsp://").run(port=5001)
+```
+
+### System information (optional)
 
-#### gosu permission issues
+Install first: `pip install "capybara-docsaid[system]"`.
 
-If scripts executed inside the container produce files owned by root, causing inconvenient file permissions, you can add `gosu` to the Dockerfile to switch users, and pass `USER_ID` and `GROUP_ID` when the container starts.
-This avoids having to adjust file permissions repeatedly when several developers collaborate.
+```python
+from capybara.utils.system_info import get_system_info
+
+print(get_system_info())
+```
 
-See the technical documentation for details: [**Integrating gosu Configuration**](https://docsaid.org/docs/capybara/advance/#integrating-gosu-configuration)
+### Video frame extraction
+
+```python
+from capybara import video2frames_v2
+
+frames = video2frames_v2("demo.mp4", frame_per_sec=2, max_size=1280)
+print(len(frames))
+```
 
-1. Install `gosu`:
+## Inference Backends
 
-   ```dockerfile
-   RUN apt-get update && apt-get install -y gosu
-   ```
+Inference backends are optional; install the matching extra before importing the corresponding engine module.
 
-2. Use `gosu` in the container start command to switch to a non-root account inside the container, so files can be read and written normally.
+### Runtime / Backend matrix
 
-   ```dockerfile
-   # Create the entrypoint script
-   RUN printf '#!/bin/bash\n\
-       if [ ! -z "$USER_ID" ] && [ ! -z "$GROUP_ID" ]; then\n\
-           groupadd -g "$GROUP_ID" -o usergroup\n\
-           useradd --shell /bin/bash -u "$USER_ID" -g "$GROUP_ID" -o -c "" -m user\n\
-           export HOME=/home/user\n\
-           chown -R "$USER_ID":"$GROUP_ID" /home/user\n\
-           chown -R "$USER_ID":"$GROUP_ID" /code\n\
-       fi\n\
-       \n\
-       # Check for parameters\n\
-       if [ $# -gt 0 ]; then\n\
-           exec gosu ${USER_ID:-0}:${GROUP_ID:-0} python "$@"\n\
-       else\n\
-           exec gosu ${USER_ID:-0}:${GROUP_ID:-0} bash\n\
-       fi' > "$ENTRYPOINT_SCRIPT"
-
-   RUN chmod +x "$ENTRYPOINT_SCRIPT"
-
-   ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
-   ```
+Note: the TorchScript runtime is named `Runtime.pt` in code (its install extra is `torchscript`).
 
-For more advanced configuration, see the official [**NVIDIA Container Toolkit**](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and [**docker**](https://docs.docker.com/) documentation.
+| Runtime (`capybara.runtime.Runtime`) | Backend name   | Provider / device                                                                                             |
+| ------------------------------------ | -------------- | ------------------------------------------------------------------------------------------------------------ |
+| `onnx`                               | `cpu`          | `["CPUExecutionProvider"]`                                                                                     |
+| `onnx`                               | `cuda`         | `["CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]`                                                 |
+| `onnx`                               | `tensorrt`     | `["TensorrtExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]`         |
+| `onnx`                               | `tensorrt_rtx` | `["NvTensorRTRTXExecutionProvider"(device_id), "CUDAExecutionProvider"(device_id), "CPUExecutionProvider"]`    |
+| `openvino`                           | `cpu`          | `device="CPU"`                                                                                                 |
+| `openvino`                           | `gpu`          | `device="GPU"`                                                                                                 |
+| `openvino`                           | `npu`          | `device="NPU"`                                                                                                 |
+| `pt`                                 | `cpu`          | `torch.device("cpu")`                                                                                          |
+| `pt`                                 | `cuda`         | `torch.device("cuda")`                                                                                         |
+
+### Runtime registry (automatic backend selection)
+
+```python
+from capybara.runtime import Runtime
+
+print(Runtime.onnx.auto_backend_name())      # priority: cuda -> tensorrt_rtx -> tensorrt -> cpu
+print(Runtime.openvino.auto_backend_name())  # priority: gpu -> npu -> cpu
+print(Runtime.pt.auto_backend_name())        # priority: cuda -> cpu
+```
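+
+Because the backends are opt-in extras, importing an engine module raises `ImportError` when its runtime package is missing (the engine module in this PR raises it when `onnxruntime` is absent). A defensive-import sketch:
+
+```python
+# Probe for the optional ONNX backend without crashing at import time.
+try:
+    from capybara.onnxengine import ONNXEngine  # needs the "onnxruntime" extra
+except ImportError:
+    ONNXEngine = None  # fall back to another backend, or raise a clearer error
+```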
+### ONNX Runtime (`capybara.onnxengine`)
+
+```python
+import numpy as np
+from capybara.onnxengine import EngineConfig, ONNXEngine
+
+engine = ONNXEngine(
+    "model.onnx",
+    backend="cpu",
+    config=EngineConfig(enable_io_binding=False),
+)
+outputs = engine.run({"input": np.ones((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+print(engine.summary())
+```
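+
+The same constructor can target the GPU rows of the table above; `backend` and `gpu_id` are the parameters defined in `capybara/onnxengine/engine.py`. A sketch (requires the `onnxruntime-gpu` extra plus matching CUDA/cuDNN):
+
+```python
+from capybara.onnxengine import ONNXEngine
+
+# With EngineConfig.fallback_to_cpu left at its default (True), the engine
+# drops back to CPUExecutionProvider when the CUDA provider is unavailable.
+engine = ONNXEngine("model.onnx", backend="cuda", gpu_id=0)
+```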
+### OpenVINO (`capybara.openvinoengine`)
+
+```python
+import numpy as np
+from capybara.openvinoengine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine
+
+engine = OpenVINOEngine(
+    "model.xml",
+    device=OpenVINODevice.cpu,
+    config=OpenVINOConfig(num_requests=2),
+)
+outputs = engine.run({"input": np.ones((1, 3), dtype=np.float32)})
+print(outputs.keys())
+```
+
+### TorchScript (`capybara.torchengine`)
+
+```python
+import numpy as np
+from capybara.torchengine import TorchEngine
+
+engine = TorchEngine("model.pt", device="cpu")
+outputs = engine.run({"image": np.zeros((1, 3, 224, 224), dtype=np.float32)})
+print(outputs.keys())
+```
 
-## Tests
+### Benchmark (hardware-dependent)
 
-This project uses `pytest` for unit testing; you can run the tests yourself to verify that the features work correctly.
-Install and run the tests as follows:
+All engines provide `benchmark(...)` for quick throughput/latency measurements.
+
+```python
+import numpy as np
+from capybara.onnxengine import ONNXEngine
+
+engine = ONNXEngine("model.onnx", backend="cpu")
+dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)
+print(engine.benchmark({"input": dummy}, repeat=50, warmup=5))
+```
+
+### Advanced: custom parameters (optional)
+
+`EngineConfig` / `OpenVINOConfig` / `TorchEngineConfig` are passed through to the underlying runtime as-is.
+
+```python
+from capybara.onnxengine import EngineConfig, ONNXEngine
+
+engine = ONNXEngine(
+    "model.onnx",
+    backend="cuda",
+    config=EngineConfig(
+        provider_options={
+            "CUDAExecutionProvider": {
+                "enable_cuda_graph": True,
+            },
+        },
+    ),
+)
+```
+
+## Quality Gates (for developers)
+
+The following must pass before a merge:
+
+```bash
+ruff check .
+ruff format --check .
+pyright
+python -m pytest --cov=capybara --cov-config=.coveragerc --cov-report=term
+```
+
+Notes:
+
+- The line-coverage threshold is **90%** (the exclusion rules live in `.coveragerc`).
+- Heavyweight / environment-dependent modules are excluded from the default coverage gate to keep CI reproducible and maintainable.
+
+## Docker (optional)
+
+```bash
+git clone https://github.com/DocsaidLab/Capybara.git
+cd Capybara
+bash docker/build.bash
+```
+
+Run:
+
+```bash
+docker run --rm -it capybara_docsaid bash
+```
+
+If you need a GPU inside the container, use the NVIDIA container runtime (e.g. `--gpus all`).
+
+## Tests (local)
 
 ```bash
-pip install pytest
-python -m pytest -vv tests
+python -m pytest -vv
 ```
 
-This confirms that each module works as expected. If a feature misbehaves, first check your environment settings and package versions.
+## License
 
-If the problem persists, report it in the issue tracker.
+Apache-2.0; see `LICENSE`.
+
+## Citation
+
+```bibtex
+@misc{lin2025capybara,
+  author = {Kun-Hsiang Lin*, Ze Yuan*},
+  title = {Capybara: An Integrated Python Package for Image Processing and Deep Learning.},
+  year = {2025},
+  publisher = {GitHub},
+  howpublished = {\url{https://github.com/DocsaidLab/Capybara}},
+  note = {* equal contribution}
+}
+```
diff --git a/capybara/__init__.py b/capybara/__init__.py
index a304b45..f28e399 100644
--- a/capybara/__init__.py
+++ b/capybara/__init__.py
@@ -1,8 +1,100 @@
-from .enums import *
-from .mixins import *
-from .onnxengine import *
-from .structures import *
-from .utils import *
-from .vision import *
+from __future__ import annotations
+
+from .enums import BORDER, COLORSTR, FORMATSTR, IMGTYP, INTER, MORPH, ROTATE
+from .mixins import (
+    DataclassCopyMixin,
+    DataclassToJsonMixin,
+    EnumCheckMixin,
+    dict_to_jsonable,
+)
+from .structures.boxes import Box, Boxes, BoxMode
+from .structures.functionals import (
+    jaccard_index,
+    pairwise_ioa,
+    pairwise_iou,
+    polygon_iou,
+)
+from .structures.keypoints import Keypoints, KeypointsList
+from .structures.polygons import (
+    JOIN_STYLE,
+    Polygon,
+    Polygons,
+    order_points_clockwise,
+)
+from .utils.custom_path import get_curdir
+from .utils.powerdict import PowerDict
+from .utils.utils import colorstr, make_batch
+from .vision.functionals import (
+    gaussianblur,
+    imbinarize,
+    imcropbox,
+
imcropboxes, + imcvtcolor, + imresize_and_pad_if_need, + meanblur, + medianblur, + pad, +) +from .vision.geometric import ( + imresize, + imrotate, + imrotate90, + imwarp_quadrangle, + imwarp_quadrangles, +) +from .vision.improc import img_to_b64str, imread, imwrite, npy_to_b64str +from .vision.videotools.video2frames import video2frames +from .vision.videotools.video2frames_v2 import video2frames_v2 + +__all__ = [ + "BORDER", + "COLORSTR", + "FORMATSTR", + "IMGTYP", + "INTER", + "JOIN_STYLE", + "MORPH", + "ROTATE", + "Box", + "BoxMode", + "Boxes", + "DataclassCopyMixin", + "DataclassToJsonMixin", + "EnumCheckMixin", + "Keypoints", + "KeypointsList", + "Polygon", + "Polygons", + "PowerDict", + "colorstr", + "dict_to_jsonable", + "gaussianblur", + "get_curdir", + "imbinarize", + "imcropbox", + "imcropboxes", + "imcvtcolor", + "img_to_b64str", + "imread", + "imresize", + "imresize_and_pad_if_need", + "imrotate", + "imrotate90", + "imwarp_quadrangle", + "imwarp_quadrangles", + "imwrite", + "jaccard_index", + "make_batch", + "meanblur", + "medianblur", + "npy_to_b64str", + "order_points_clockwise", + "pad", + "pairwise_ioa", + "pairwise_iou", + "polygon_iou", + "video2frames", + "video2frames_v2", +] __version__ = "0.12.0" diff --git a/capybara/cpuinfo.py b/capybara/cpuinfo.py index 3c08fad..4d56b8f 100644 --- a/capybara/cpuinfo.py +++ b/capybara/cpuinfo.py @@ -21,7 +21,7 @@ Usage: >>> from cpuinfo import cpuinfo - >>> info = cpuinfo() # len(info) equals to num of cpus. + >>> info = cpuinfo() # len(info) equals to num of cpus. >>> print(list(info[0].keys())) >>> { 'processor', @@ -54,7 +54,12 @@ } """ -__all__ = ['cpuinfo'] +# ruff: noqa: N802, N815 +# +# This module is vendored from numexpr (see link above). It keeps upstream +# helper names that intentionally use non-PEP8 casing. 
+ +__all__ = ["cpuinfo"] import inspect import os @@ -64,7 +69,9 @@ import sys import warnings -is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE +is_cpu_amd_intel = ( + False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE +) def getoutput(cmd, successful_status=(0,), stacklevel=1): @@ -72,9 +79,9 @@ def getoutput(cmd, successful_status=(0,), stacklevel=1): p = subprocess.Popen(cmd, stdout=subprocess.PIPE) output, _ = p.communicate() status = p.returncode - except EnvironmentError as e: + except OSError as e: warnings.warn(str(e), UserWarning, stacklevel=stacklevel) - return False, '' + return False, "" if os.WIFEXITED(status) and os.WEXITSTATUS(status) in successful_status: return True, output return False, output @@ -83,38 +90,42 @@ def getoutput(cmd, successful_status=(0,), stacklevel=1): def command_info(successful_status=(0,), stacklevel=1, **kw): info = {} for key in kw: - ok, output = getoutput(kw[key], successful_status=successful_status, - stacklevel=stacklevel + 1) + ok, output = getoutput( + kw[key], + successful_status=successful_status, + stacklevel=stacklevel + 1, + ) if ok: info[key] = output.strip() return info def command_by_line(cmd, successful_status=(0,), stacklevel=1): - ok, output = getoutput(cmd, successful_status=successful_status, - stacklevel=stacklevel + 1) + ok, output = getoutput( + cmd, successful_status=successful_status, stacklevel=stacklevel + 1 + ) if not ok: return # XXX: check - output = output.decode('ascii') + output = output.decode("ascii") for line in output.splitlines(): yield line.strip() -def key_value_from_command(cmd, sep, successful_status=(0,), - stacklevel=1): +def key_value_from_command(cmd, sep, successful_status=(0,), stacklevel=1): d = {} - for line in command_by_line(cmd, successful_status=successful_status, - stacklevel=stacklevel + 1): - l = [s.strip() for s in line.split(sep, 1)] - if len(l) == 2: - d[l[0]] = l[1] + for line in command_by_line( + cmd, successful_status=successful_status, stacklevel=stacklevel + 1 + ): + parts = [s.strip() for s in line.split(sep, 1)] + if len(parts) == 2: + d[parts[0]] = parts[1] return d -class CPUInfoBase(object): +class CPUInfoBase: """Holds CPU information and provides methods for requiring the availability of various CPU features. 
""" @@ -122,13 +133,13 @@ class CPUInfoBase(object): def _try_call(self, func): try: return func() - except: - pass + except Exception: + return None def __getattr__(self, name): - if not name.startswith('_'): - if hasattr(self, '_' + name): - attr = getattr(self, '_' + name) + if not name.startswith("_"): + if hasattr(self, "_" + name): + attr = getattr(self, "_" + name) if inspect.ismethod(attr): return lambda func=self._try_call, attr=attr: func(attr) else: @@ -140,14 +151,14 @@ def _getNCPUs(self): def __get_nbits(self): abits = platform.architecture()[0] - nbits = re.compile(r'(\d+)bit').search(abits).group(1) + nbits = re.compile(r"(\d+)bit").search(abits).group(1) return nbits def _is_32bit(self): - return self.__get_nbits() == '32' + return self.__get_nbits() == "32" def _is_64bit(self): - return self.__get_nbits() == '64' + return self.__get_nbits() == "64" class LinuxCPUInfo(CPUInfoBase): @@ -157,23 +168,21 @@ def __init__(self): if self.info is not None: return info = [{}] - ok, output = getoutput(['uname', '-m']) + ok, output = getoutput(["uname", "-m"]) if ok: - info[0]['uname_m'] = output.strip() + info[0]["uname_m"] = output.strip() try: - fo = open('/proc/cpuinfo') - except EnvironmentError as e: - warnings.warn(str(e), UserWarning) - else: - for line in fo: - name_value = [s.strip() for s in line.split(':', 1)] - if len(name_value) != 2: - continue - name, value = name_value - if not info or name in info[-1]: # next processor - info.append({}) - info[-1][name] = value - fo.close() + with open("/proc/cpuinfo") as fo: + for line in fo: + name_value = [s.strip() for s in line.split(":", 1)] + if len(name_value) != 2: + continue + name, value = name_value + if not info or name in info[-1]: # next processor + info.append({}) + info[-1][name] = value + except OSError as e: + warnings.warn(str(e), UserWarning, stacklevel=2) self.__class__.info = info def _not_impl(self): @@ -182,59 +191,62 @@ def _not_impl(self): # Athlon def _is_AMD(self): - return self.info[0]['vendor_id'] == 'AuthenticAMD' + return self.info[0]["vendor_id"] == "AuthenticAMD" def _is_AthlonK6_2(self): - return self._is_AMD() and self.info[0]['model'] == '2' + return self._is_AMD() and self.info[0]["model"] == "2" def _is_AthlonK6_3(self): - return self._is_AMD() and self.info[0]['model'] == '3' + return self._is_AMD() and self.info[0]["model"] == "3" def _is_AthlonK6(self): - return re.match(r'.*?AMD-K6', self.info[0]['model name']) is not None + return re.match(r".*?AMD-K6", self.info[0]["model name"]) is not None def _is_AthlonK7(self): - return re.match(r'.*?AMD-K7', self.info[0]['model name']) is not None + return re.match(r".*?AMD-K7", self.info[0]["model name"]) is not None def _is_AthlonMP(self): - return re.match(r'.*?Athlon\(tm\) MP\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Athlon\(tm\) MP\b", self.info[0]["model name"]) + is not None + ) def _is_AMD64(self): - return self.is_AMD() and self.info[0]['family'] == '15' + return self.is_AMD() and self.info[0]["family"] == "15" def _is_Athlon64(self): - return re.match(r'.*?Athlon\(tm\) 64\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Athlon\(tm\) 64\b", self.info[0]["model name"]) + is not None + ) def _is_AthlonHX(self): - return re.match(r'.*?Athlon HX\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Athlon HX\b", self.info[0]["model name"]) is not None + ) def _is_Opteron(self): - return re.match(r'.*?Opteron\b', - self.info[0]['model name']) is not None + return 
re.match(r".*?Opteron\b", self.info[0]["model name"]) is not None def _is_Hammer(self): - return re.match(r'.*?Hammer\b', - self.info[0]['model name']) is not None + return re.match(r".*?Hammer\b", self.info[0]["model name"]) is not None # Alpha def _is_Alpha(self): - return self.info[0]['cpu'] == 'Alpha' + return self.info[0]["cpu"] == "Alpha" def _is_EV4(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV4' + return self.is_Alpha() and self.info[0]["cpu model"] == "EV4" def _is_EV5(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV5' + return self.is_Alpha() and self.info[0]["cpu model"] == "EV5" def _is_EV56(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV56' + return self.is_Alpha() and self.info[0]["cpu model"] == "EV56" def _is_PCA56(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'PCA56' + return self.is_Alpha() and self.info[0]["cpu model"] == "PCA56" # Intel @@ -242,94 +254,107 @@ def _is_PCA56(self): _is_i386 = _not_impl def _is_Intel(self): - return self.info[0]['vendor_id'] == 'GenuineIntel' + return self.info[0]["vendor_id"] == "GenuineIntel" def _is_i486(self): - return self.info[0]['cpu'] == 'i486' + return self.info[0]["cpu"] == "i486" def _is_i586(self): - return self.is_Intel() and self.info[0]['cpu family'] == '5' + return self.is_Intel() and self.info[0]["cpu family"] == "5" def _is_i686(self): - return self.is_Intel() and self.info[0]['cpu family'] == '6' + return self.is_Intel() and self.info[0]["cpu family"] == "6" def _is_Celeron(self): - return re.match(r'.*?Celeron', - self.info[0]['model name']) is not None + return re.match(r".*?Celeron", self.info[0]["model name"]) is not None def _is_Pentium(self): - return re.match(r'.*?Pentium', - self.info[0]['model name']) is not None + return re.match(r".*?Pentium", self.info[0]["model name"]) is not None def _is_PentiumII(self): - return re.match(r'.*?Pentium.*?II\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Pentium.*?II\b", self.info[0]["model name"]) + is not None + ) def _is_PentiumPro(self): - return re.match(r'.*?PentiumPro\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?PentiumPro\b", self.info[0]["model name"]) is not None + ) def _is_PentiumMMX(self): - return re.match(r'.*?Pentium.*?MMX\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Pentium.*?MMX\b", self.info[0]["model name"]) + is not None + ) def _is_PentiumIII(self): - return re.match(r'.*?Pentium.*?III\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Pentium.*?III\b", self.info[0]["model name"]) + is not None + ) def _is_PentiumIV(self): - return re.match(r'.*?Pentium.*?(IV|4)\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Pentium.*?(IV|4)\b", self.info[0]["model name"]) + is not None + ) def _is_PentiumM(self): - return re.match(r'.*?Pentium.*?M\b', - self.info[0]['model name']) is not None + return ( + re.match(r".*?Pentium.*?M\b", self.info[0]["model name"]) + is not None + ) def _is_Prescott(self): return self.is_PentiumIV() and self.has_sse3() def _is_Nocona(self): - return (self.is_Intel() and - self.info[0]['cpu family'] in ('6', '15') and - # two s sse3; three s ssse3 not the same thing, this is fine - (self.has_sse3() and not self.has_ssse3()) and - re.match(r'.*?\blm\b', self.info[0]['flags']) is not None) + return ( + self.is_Intel() + and self.info[0]["cpu family"] in ("6", "15") + and + # two s sse3; three s ssse3 not the same thing, this is fine 
+ (self.has_sse3() and not self.has_ssse3()) + and re.match(r".*?\blm\b", self.info[0]["flags"]) is not None + ) def _is_Core2(self): - return (self.is_64bit() and self.is_Intel() and - re.match(r'.*?Core\(TM\)2\b', - self.info[0]['model name']) is not None) + return ( + self.is_64bit() + and self.is_Intel() + and re.match(r".*?Core\(TM\)2\b", self.info[0]["model name"]) + is not None + ) def _is_Itanium(self): - return re.match(r'.*?Itanium\b', - self.info[0]['family']) is not None + return re.match(r".*?Itanium\b", self.info[0]["family"]) is not None def _is_XEON(self): - return re.match(r'.*?XEON\b', - self.info[0]['model name'], re.IGNORECASE) is not None + return ( + re.match(r".*?XEON\b", self.info[0]["model name"], re.IGNORECASE) + is not None + ) _is_Xeon = _is_XEON # Power def _is_Power(self): - return re.match(r'.*POWER.*', - self.info[0]['cpu']) is not None + return re.match(r".*POWER.*", self.info[0]["cpu"]) is not None def _is_Power7(self): - return re.match(r'.*POWER7.*', - self.info[0]['cpu']) is not None + return re.match(r".*POWER7.*", self.info[0]["cpu"]) is not None def _is_Power8(self): - return re.match(r'.*POWER8.*', - self.info[0]['cpu']) is not None + return re.match(r".*POWER8.*", self.info[0]["cpu"]) is not None def _is_Power9(self): - return re.match(r'.*POWER9.*', - self.info[0]['cpu']) is not None + return re.match(r".*POWER9.*", self.info[0]["cpu"]) is not None def _has_Altivec(self): - return re.match(r'.*altivec\ supported.*', - self.info[0]['cpu']) is not None + return ( + re.match(r".*altivec\ supported.*", self.info[0]["cpu"]) is not None + ) # Varia @@ -340,31 +365,31 @@ def _getNCPUs(self): return len(self.info) def _has_fdiv_bug(self): - return self.info[0]['fdiv_bug'] == 'yes' + return self.info[0]["fdiv_bug"] == "yes" def _has_f00f_bug(self): - return self.info[0]['f00f_bug'] == 'yes' + return self.info[0]["f00f_bug"] == "yes" def _has_mmx(self): - return re.match(r'.*?\bmmx\b', self.info[0]['flags']) is not None + return re.match(r".*?\bmmx\b", self.info[0]["flags"]) is not None def _has_sse(self): - return re.match(r'.*?\bsse\b', self.info[0]['flags']) is not None + return re.match(r".*?\bsse\b", self.info[0]["flags"]) is not None def _has_sse2(self): - return re.match(r'.*?\bsse2\b', self.info[0]['flags']) is not None + return re.match(r".*?\bsse2\b", self.info[0]["flags"]) is not None def _has_sse3(self): - return re.match(r'.*?\bpni\b', self.info[0]['flags']) is not None + return re.match(r".*?\bpni\b", self.info[0]["flags"]) is not None def _has_ssse3(self): - return re.match(r'.*?\bssse3\b', self.info[0]['flags']) is not None + return re.match(r".*?\bssse3\b", self.info[0]["flags"]) is not None def _has_3dnow(self): - return re.match(r'.*?\b3dnow\b', self.info[0]['flags']) is not None + return re.match(r".*?\b3dnow\b", self.info[0]["flags"]) is not None def _has_3dnowext(self): - return re.match(r'.*?\b3dnowext\b', self.info[0]['flags']) is not None + return re.match(r".*?\b3dnowext\b", self.info[0]["flags"]) is not None class IRIXCPUInfo(CPUInfoBase): @@ -373,21 +398,22 @@ class IRIXCPUInfo(CPUInfoBase): def __init__(self): if self.info is not None: return - info = key_value_from_command('sysconf', sep=' ', - successful_status=(0, 1)) + info = key_value_from_command( + "sysconf", sep=" ", successful_status=(0, 1) + ) self.__class__.info = info def _not_impl(self): pass def _is_singleCPU(self): - return self.info.get('NUM_PROCESSORS') == '1' + return self.info.get("NUM_PROCESSORS") == "1" def _getNCPUs(self): - return 
int(self.info.get('NUM_PROCESSORS', 1)) + return int(self.info.get("NUM_PROCESSORS", 1)) def __cputype(self, n): - return self.info.get('PROCESSORS').split()[0].lower() == 'r%s' % (n) + return self.info.get("PROCESSORS").split()[0].lower() == f"r{n}" def _is_r2000(self): return self.__cputype(2000) @@ -432,16 +458,16 @@ def _is_r12000(self): return self.__cputype(12000) def _is_rorion(self): - return self.__cputype('orion') + return self.__cputype("orion") def get_ip(self): try: - return self.info.get('MACHINE') - except: - pass + return self.info.get("MACHINE") + except Exception: + return None def __machine(self, n): - return self.info.get('MACHINE').lower() == 'ip%s' % (n) + return self.info.get("MACHINE").lower() == f"ip{n}" def _is_IP19(self): return self.__machine(19) @@ -495,63 +521,81 @@ class DarwinCPUInfo(CPUInfoBase): def __init__(self): if self.info is not None: return - info = command_info(arch='arch', - machine='machine') - info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=') + info = command_info(arch="arch", machine="machine") + info["sysctl_hw"] = key_value_from_command(["sysctl", "hw"], sep="=") self.__class__.info = info - def _not_impl(self): pass + def _not_impl(self): + pass def _getNCPUs(self): - return int(self.info['sysctl_hw'].get('hw.ncpu', 1)) + return int(self.info["sysctl_hw"].get("hw.ncpu", 1)) def _is_Power_Macintosh(self): - return self.info['sysctl_hw']['hw.machine'] == 'Power Macintosh' + return self.info["sysctl_hw"]["hw.machine"] == "Power Macintosh" def _is_i386(self): - return self.info['arch'] == 'i386' + return self.info["arch"] == "i386" def _is_ppc(self): - return self.info['arch'] == 'ppc' + return self.info["arch"] == "ppc" def __machine(self, n): - return self.info['machine'] == 'ppc%s' % n + return self.info["machine"] == f"ppc{n}" - def _is_ppc601(self): return self.__machine(601) + def _is_ppc601(self): + return self.__machine(601) - def _is_ppc602(self): return self.__machine(602) + def _is_ppc602(self): + return self.__machine(602) - def _is_ppc603(self): return self.__machine(603) + def _is_ppc603(self): + return self.__machine(603) - def _is_ppc603e(self): return self.__machine('603e') + def _is_ppc603e(self): + return self.__machine("603e") - def _is_ppc604(self): return self.__machine(604) + def _is_ppc604(self): + return self.__machine(604) - def _is_ppc604e(self): return self.__machine('604e') + def _is_ppc604e(self): + return self.__machine("604e") - def _is_ppc620(self): return self.__machine(620) + def _is_ppc620(self): + return self.__machine(620) - def _is_ppc630(self): return self.__machine(630) + def _is_ppc630(self): + return self.__machine(630) - def _is_ppc740(self): return self.__machine(740) + def _is_ppc740(self): + return self.__machine(740) - def _is_ppc7400(self): return self.__machine(7400) + def _is_ppc7400(self): + return self.__machine(7400) - def _is_ppc7450(self): return self.__machine(7450) + def _is_ppc7450(self): + return self.__machine(7450) - def _is_ppc750(self): return self.__machine(750) + def _is_ppc750(self): + return self.__machine(750) - def _is_ppc403(self): return self.__machine(403) + def _is_ppc403(self): + return self.__machine(403) - def _is_ppc505(self): return self.__machine(505) + def _is_ppc505(self): + return self.__machine(505) - def _is_ppc801(self): return self.__machine(801) + def _is_ppc801(self): + return self.__machine(801) - def _is_ppc821(self): return self.__machine(821) + def _is_ppc821(self): + return self.__machine(821) - def _is_ppc823(self): return 
self.__machine(823) + def _is_ppc823(self): + return self.__machine(823) - def _is_ppc860(self): return self.__machine(860) + def _is_ppc860(self): + return self.__machine(860) class NetBSDCPUInfo(CPUInfoBase): @@ -561,25 +605,22 @@ def __init__(self): if self.info is not None: return info = {} - info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=') - info['arch'] = info['sysctl_hw'].get('hw.machine_arch', 1) - info['machine'] = info['sysctl_hw'].get('hw.machine', 1) + info["sysctl_hw"] = key_value_from_command(["sysctl", "hw"], sep="=") + info["arch"] = info["sysctl_hw"].get("hw.machine_arch", 1) + info["machine"] = info["sysctl_hw"].get("hw.machine", 1) self.__class__.info = info - def _not_impl(self): pass + def _not_impl(self): + pass def _getNCPUs(self): - return int(self.info['sysctl_hw'].get('hw.ncpu', 1)) + return int(self.info["sysctl_hw"].get("hw.ncpu", 1)) def _is_Intel(self): - if self.info['sysctl_hw'].get('hw.model', "")[0:5] == 'Intel': - return True - return False + return self.info["sysctl_hw"].get("hw.model", "")[0:5] == "Intel" def _is_AMD(self): - if self.info['sysctl_hw'].get('hw.model', "")[0:3] == 'AMD': - return True - return False + return self.info["sysctl_hw"].get("hw.model", "")[0:3] == "AMD" class SunOSCPUInfo(CPUInfoBase): @@ -588,17 +629,18 @@ class SunOSCPUInfo(CPUInfoBase): def __init__(self): if self.info is not None: return - info = command_info(arch='arch', - mach='mach', - uname_i=['uname', '-i'], - isainfo_b=['isainfo', '-b'], - isainfo_n=['isainfo', '-n'], - ) - info['uname_X'] = key_value_from_command(['uname', '-X'], sep='=') - for line in command_by_line(['psrinfo', '-v', '0']): - m = re.match(r'\s*The (?P
<p>
[\w\d]+) processor operates at', line) + info = command_info( + arch="arch", + mach="mach", + uname_i=["uname", "-i"], + isainfo_b=["isainfo", "-b"], + isainfo_n=["isainfo", "-n"], + ) + info["uname_X"] = key_value_from_command(["uname", "-X"], sep="=") + for line in command_by_line(["psrinfo", "-v", "0"]): + m = re.match(r"\s*The (?P
<p>
[\w\d]+) processor operates at", line) if m: - info['processor'] = m.group('p') + info["processor"] = m.group("p") break self.__class__.info = info @@ -606,73 +648,76 @@ def _not_impl(self): pass def _is_i386(self): - return self.info['isainfo_n'] == 'i386' + return self.info["isainfo_n"] == "i386" def _is_sparc(self): - return self.info['isainfo_n'] == 'sparc' + return self.info["isainfo_n"] == "sparc" def _is_sparcv9(self): - return self.info['isainfo_n'] == 'sparcv9' + return self.info["isainfo_n"] == "sparcv9" def _getNCPUs(self): - return int(self.info['uname_X'].get('NumCPU', 1)) + return int(self.info["uname_X"].get("NumCPU", 1)) def _is_sun4(self): - return self.info['arch'] == 'sun4' + return self.info["arch"] == "sun4" def _is_SUNW(self): - return re.match(r'SUNW', self.info['uname_i']) is not None + return re.match(r"SUNW", self.info["uname_i"]) is not None def _is_sparcstation5(self): - return re.match(r'.*SPARCstation-5', self.info['uname_i']) is not None + return re.match(r".*SPARCstation-5", self.info["uname_i"]) is not None def _is_ultra1(self): - return re.match(r'.*Ultra-1', self.info['uname_i']) is not None + return re.match(r".*Ultra-1", self.info["uname_i"]) is not None def _is_ultra250(self): - return re.match(r'.*Ultra-250', self.info['uname_i']) is not None + return re.match(r".*Ultra-250", self.info["uname_i"]) is not None def _is_ultra2(self): - return re.match(r'.*Ultra-2', self.info['uname_i']) is not None + return re.match(r".*Ultra-2", self.info["uname_i"]) is not None def _is_ultra30(self): - return re.match(r'.*Ultra-30', self.info['uname_i']) is not None + return re.match(r".*Ultra-30", self.info["uname_i"]) is not None def _is_ultra4(self): - return re.match(r'.*Ultra-4', self.info['uname_i']) is not None + return re.match(r".*Ultra-4", self.info["uname_i"]) is not None def _is_ultra5_10(self): - return re.match(r'.*Ultra-5_10', self.info['uname_i']) is not None + return re.match(r".*Ultra-5_10", self.info["uname_i"]) is not None def _is_ultra5(self): - return re.match(r'.*Ultra-5', self.info['uname_i']) is not None + return re.match(r".*Ultra-5", self.info["uname_i"]) is not None def _is_ultra60(self): - return re.match(r'.*Ultra-60', self.info['uname_i']) is not None + return re.match(r".*Ultra-60", self.info["uname_i"]) is not None def _is_ultra80(self): - return re.match(r'.*Ultra-80', self.info['uname_i']) is not None + return re.match(r".*Ultra-80", self.info["uname_i"]) is not None def _is_ultraenterprice(self): - return re.match(r'.*Ultra-Enterprise', self.info['uname_i']) is not None + return re.match(r".*Ultra-Enterprise", self.info["uname_i"]) is not None def _is_ultraenterprice10k(self): - return re.match(r'.*Ultra-Enterprise-10000', self.info['uname_i']) is not None + return ( + re.match(r".*Ultra-Enterprise-10000", self.info["uname_i"]) + is not None + ) def _is_sunfire(self): - return re.match(r'.*Sun-Fire', self.info['uname_i']) is not None + return re.match(r".*Sun-Fire", self.info["uname_i"]) is not None def _is_ultra(self): - return re.match(r'.*Ultra', self.info['uname_i']) is not None + return re.match(r".*Ultra", self.info["uname_i"]) is not None def _is_cpusparcv7(self): - return self.info['processor'] == 'sparcv7' + return self.info["processor"] == "sparcv7" def _is_cpusparcv8(self): - return self.info['processor'] == 'sparcv8' + return self.info["processor"] == "sparcv8" def _is_cpusparcv9(self): - return self.info['processor'] == 'sparcv9' + return self.info["processor"] == "sparcv9" class Win32CPUInfo(CPUInfoBase): @@ -694,8 
+739,11 @@ def __init__(self):
         try:
             # XXX: Bad style to use so long `try:...except:...`. Fix it!
-            prgx = re.compile(r"family\s+(?P<FML>\d+)\s+model\s+(?P<MDL>\d+)"
-                              r"\s+stepping\s+(?P<STP>\d+)", re.IGNORECASE)
+            prgx = re.compile(
+                r"family\s+(?P<FML>\d+)\s+model\s+(?P<MDL>\d+)"
+                r"\s+stepping\s+(?P<STP>\d+)",
+                re.IGNORECASE,
+            )
             chnd = _winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE, self.pkey)
             pnum = 0
             while 1:
@@ -721,9 +769,11 @@ def __init__(self):
                         if srch:
                             info[-1]["Family"] = int(srch.group("FML"))
                             info[-1]["Model"] = int(srch.group("MDL"))
-                            info[-1]["Stepping"] = int(srch.group("STP"))
-        except:
-            print(sys.exc_value, '(ignoring)')
+                            info[-1]["Stepping"] = int(
+                                srch.group("STP")
+                            )
+        except Exception as exc:
+            warnings.warn(f"{exc} (ignoring)", RuntimeWarning, stacklevel=2)
         self.__class__.info = info

     def _not_impl(self):
@@ -732,86 +782,116 @@ def _not_impl(self):
     # Athlon

     def _is_AMD(self):
-        return self.info[0]['VendorIdentifier'] == 'AuthenticAMD'
+        return self.info[0]["VendorIdentifier"] == "AuthenticAMD"

     def _is_Am486(self):
-        return self.is_AMD() and self.info[0]['Family'] == 4
+        return self.is_AMD() and self.info[0]["Family"] == 4

     def _is_Am5x86(self):
-        return self.is_AMD() and self.info[0]['Family'] == 4
+        return self.is_AMD() and self.info[0]["Family"] == 4

     def _is_AMDK5(self):
-        return (self.is_AMD() and self.info[0]['Family'] == 5 and
-                self.info[0]['Model'] in [0, 1, 2, 3])
+        return (
+            self.is_AMD()
+            and self.info[0]["Family"] == 5
+            and self.info[0]["Model"] in [0, 1, 2, 3]
+        )

     def _is_AMDK6(self):
-        return (self.is_AMD() and self.info[0]['Family'] == 5 and
-                self.info[0]['Model'] in [6, 7])
+        return (
+            self.is_AMD()
+            and self.info[0]["Family"] == 5
+            and self.info[0]["Model"] in [6, 7]
+        )

     def _is_AMDK6_2(self):
-        return (self.is_AMD() and self.info[0]['Family'] == 5 and
-                self.info[0]['Model'] == 8)
+        return (
+            self.is_AMD()
+            and self.info[0]["Family"] == 5
+            and self.info[0]["Model"] == 8
+        )

     def _is_AMDK6_3(self):
-        return (self.is_AMD() and self.info[0]['Family'] == 5 and
-                self.info[0]['Model'] == 9)
+        return (
+            self.is_AMD()
+            and self.info[0]["Family"] == 5
+            and self.info[0]["Model"] == 9
+        )

     def _is_AMDK7(self):
-        return self.is_AMD() and self.info[0]['Family'] == 6
+        return self.is_AMD() and self.info[0]["Family"] == 6

     # To reliably distinguish between the different types of AMD64 chips
     # (Athlon64, Operton, Athlon64 X2, Semperon, Turion 64, etc.)
would # require looking at the 'brand' from cpuid def _is_AMD64(self): - return self.is_AMD() and self.info[0]['Family'] == 15 + return self.is_AMD() and self.info[0]["Family"] == 15 # Intel def _is_Intel(self): - return self.info[0]['VendorIdentifier'] == 'GenuineIntel' + return self.info[0]["VendorIdentifier"] == "GenuineIntel" def _is_i386(self): - return self.info[0]['Family'] == 3 + return self.info[0]["Family"] == 3 def _is_i486(self): - return self.info[0]['Family'] == 4 + return self.info[0]["Family"] == 4 def _is_i586(self): - return self.is_Intel() and self.info[0]['Family'] == 5 + return self.is_Intel() and self.info[0]["Family"] == 5 def _is_i686(self): - return self.is_Intel() and self.info[0]['Family'] == 6 + return self.is_Intel() and self.info[0]["Family"] == 6 def _is_Pentium(self): - return self.is_Intel() and self.info[0]['Family'] == 5 + return self.is_Intel() and self.info[0]["Family"] == 5 def _is_PentiumMMX(self): - return (self.is_Intel() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 4) + return ( + self.is_Intel() + and self.info[0]["Family"] == 5 + and self.info[0]["Model"] == 4 + ) def _is_PentiumPro(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] == 1) + return ( + self.is_Intel() + and self.info[0]["Family"] == 6 + and self.info[0]["Model"] == 1 + ) def _is_PentiumII(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [3, 5, 6]) + return ( + self.is_Intel() + and self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [3, 5, 6] + ) def _is_PentiumIII(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [7, 8, 9, 10, 11]) + return ( + self.is_Intel() + and self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [7, 8, 9, 10, 11] + ) def _is_PentiumIV(self): - return self.is_Intel() and self.info[0]['Family'] == 15 + return self.is_Intel() and self.info[0]["Family"] == 15 def _is_PentiumM(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [9, 13, 14]) + return ( + self.is_Intel() + and self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [9, 13, 14] + ) def _is_Core2(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [15, 16, 17]) + return ( + self.is_Intel() + and self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [15, 16, 17] + ) # Varia @@ -823,23 +903,25 @@ def _getNCPUs(self): def _has_mmx(self): if self.is_Intel(): - return ((self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 4) or - (self.info[0]['Family'] in [6, 15])) + return ( + self.info[0]["Family"] == 5 and self.info[0]["Model"] == 4 + ) or (self.info[0]["Family"] in [6, 15]) elif self.is_AMD(): - return self.info[0]['Family'] in [5, 6, 15] + return self.info[0]["Family"] in [5, 6, 15] else: return False def _has_sse(self): if self.is_Intel(): - return ((self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [7, 8, 9, 10, 11]) or - self.info[0]['Family'] == 15) + return ( + self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [7, 8, 9, 10, 11] + ) or self.info[0]["Family"] == 15 elif self.is_AMD(): - return ((self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [6, 7, 8, 10]) or - self.info[0]['Family'] == 15) + return ( + self.info[0]["Family"] == 6 + and self.info[0]["Model"] in [6, 7, 8, 10] + ) or self.info[0]["Family"] == 15 else: return False @@ -852,25 +934,27 @@ def _has_sse2(self): return False def 
_has_3dnow(self): - return self.is_AMD() and self.info[0]['Family'] in [5, 6, 15] + return self.is_AMD() and self.info[0]["Family"] in [5, 6, 15] def _has_3dnowext(self): - return self.is_AMD() and self.info[0]['Family'] in [6, 15] + return self.is_AMD() and self.info[0]["Family"] in [6, 15] -if sys.platform.startswith('linux'): # variations: linux2,linux-i386 (any others?) +if sys.platform.startswith( + "linux" +): # variations: linux2,linux-i386 (any others?) cpuinfo = LinuxCPUInfo -elif sys.platform.startswith('irix'): +elif sys.platform.startswith("irix"): cpuinfo = IRIXCPUInfo -elif sys.platform == 'darwin': +elif sys.platform == "darwin": cpuinfo = DarwinCPUInfo -elif sys.platform[0:6] == 'netbsd': +elif sys.platform[0:6] == "netbsd": cpuinfo = NetBSDCPUInfo -elif sys.platform.startswith('sunos'): +elif sys.platform.startswith("sunos"): cpuinfo = SunOSCPUInfo -elif sys.platform.startswith('win32'): +elif sys.platform.startswith("win32"): cpuinfo = Win32CPUInfo -elif sys.platform.startswith('cygwin'): +elif sys.platform.startswith("cygwin"): cpuinfo = LinuxCPUInfo # XXX: other OS's. Eg. use _winreg on Win32. Or os.uname on unices. else: diff --git a/capybara/enums.py b/capybara/enums.py index d958ffa..15bc85e 100644 --- a/capybara/enums.py +++ b/capybara/enums.py @@ -5,7 +5,13 @@ from .mixins import EnumCheckMixin __all__ = [ - 'INTER', 'ROTATE', 'BORDER', 'MORPH', 'COLORSTR', 'FORMATSTR', 'IMGTYP' + "BORDER", + "COLORSTR", + "FORMATSTR", + "IMGTYP", + "INTER", + "MORPH", + "ROTATE", ] diff --git a/capybara/mixins.py b/capybara/mixins.py index 6116d61..a8a99d1 100644 --- a/capybara/mixins.py +++ b/capybara/mixins.py @@ -1,8 +1,9 @@ import json from collections import OrderedDict +from collections.abc import Callable, Mapping, MutableMapping from dataclasses import asdict from enum import Enum -from typing import Any, Callable, Dict, Mapping, Optional +from typing import Any, TypeVar, cast from warnings import warn import numpy as np @@ -11,32 +12,39 @@ from .structures import Box, Boxes, Keypoints, KeypointsList, Polygon, Polygons __all__ = [ - "EnumCheckMixin", "DataclassCopyMixin", "DataclassToJsonMixin", + "EnumCheckMixin", "dict_to_jsonable", ] def dict_to_jsonable( - d: Mapping, - jsonable_func: Optional[Dict[str, Callable]] = None, - dict_factory: Mapping = OrderedDict, -) -> Any: + d: Mapping[str, Any], + jsonable_func: Mapping[str, Callable[[Any], Any]] | None = None, + dict_factory: Callable[[], MutableMapping[str, Any]] = OrderedDict, +) -> MutableMapping[str, Any]: out = dict_factory() for k, v in d.items(): if jsonable_func is not None and k in jsonable_func: out[k] = jsonable_func[k](v) else: if isinstance(v, (Box, Boxes)): - out[k] = v.convert("XYXY").numpy().astype(float).round().tolist() + out[k] = ( + v.convert("XYXY").numpy().astype(float).round().tolist() + ) elif isinstance(v, (Keypoints, KeypointsList, Polygon, Polygons)): out[k] = v.numpy().astype(float).round().tolist() elif isinstance(v, (np.ndarray, np.generic)): # include array and scalar, if you want jsonable image please use jsonable_func out[k] = v.tolist() elif isinstance(v, (list, tuple)): - out[k] = [dict_to_jsonable(x, jsonable_func) if isinstance(x, dict) else x for x in v] + out[k] = [ + dict_to_jsonable(x, jsonable_func) + if isinstance(x, dict) + else x + for x in v + ] elif isinstance(v, Enum): out[k] = v.name elif isinstance(v, Mapping): @@ -47,24 +55,27 @@ def dict_to_jsonable( try: json.dumps(out) except Exception as e: - warn(e) + warn(str(e), stacklevel=2) return out +_EnumT = 
TypeVar("_EnumT", bound="EnumCheckMixin") + + class EnumCheckMixin: @classmethod - def obj_to_enum(cls: Enum, obj: Any): + def obj_to_enum(cls: type[_EnumT], obj: Any) -> _EnumT: if isinstance(obj, str): try: - return getattr(cls, obj) + return cast(_EnumT, getattr(cls, obj)) except AttributeError: pass elif isinstance(obj, cls): return obj elif isinstance(obj, int): try: - return cls(obj) + return cast(_EnumT, cast(Any, cls)(obj)) except ValueError: pass @@ -73,10 +84,18 @@ def obj_to_enum(cls: Enum, obj: Any): class DataclassCopyMixin: def __copy__(self): - return self.__class__(**{field: getattr(self, field) for field in self.__dataclass_fields__}) + dataclass_fields = getattr(self, "__dataclass_fields__", None) + if dataclass_fields is None: + raise TypeError( + f"{self.__class__.__name__} is not a dataclass instance." + ) + field_names = cast(dict[str, Any], dataclass_fields).keys() + return self.__class__( + **{field: getattr(self, field) for field in field_names} + ) def __deepcopy__(self, memo): - out = asdict(self, dict_factory=OrderedDict) + out = asdict(cast(Any, self), dict_factory=OrderedDict) return from_dict(data_class=self.__class__, data=out) @@ -84,5 +103,7 @@ class DataclassToJsonMixin: jsonable_func = None def be_jsonable(self, dict_factory=OrderedDict): - d = asdict(self, dict_factory=dict_factory) - return dict_to_jsonable(d, jsonable_func=self.jsonable_func, dict_factory=dict_factory) + d = asdict(cast(Any, self), dict_factory=dict_factory) + return dict_to_jsonable( + d, jsonable_func=self.jsonable_func, dict_factory=dict_factory + ) diff --git a/capybara/onnxengine/__init__.py b/capybara/onnxengine/__init__.py index 18c1434..0b07016 100644 --- a/capybara/onnxengine/__init__.py +++ b/capybara/onnxengine/__init__.py @@ -1,8 +1,24 @@ -from .engine import ONNXEngine -from .engine_io_binding import ONNXEngineIOBinding -from .enum import Backend -from .metadata import get_onnx_metadata, parse_metadata_from_onnx, write_metadata_into_onnx -from .tools import get_onnx_input_infos, get_onnx_output_infos, get_recommended_backend, make_onnx_dynamic_axes +from ..runtime import Backend +from .engine import EngineConfig, ONNXEngine +from .metadata import ( + get_onnx_metadata, + parse_metadata_from_onnx, + write_metadata_into_onnx, +) +from .utils import ( + get_onnx_input_infos, + get_onnx_output_infos, + make_onnx_dynamic_axes, +) -# 暫時無法使用 -# from .quantize import quantize, quantize_static +__all__ = [ + "Backend", + "EngineConfig", + "ONNXEngine", + "get_onnx_input_infos", + "get_onnx_metadata", + "get_onnx_output_infos", + "make_onnx_dynamic_axes", + "parse_metadata_from_onnx", + "write_metadata_into_onnx", +] diff --git a/capybara/onnxengine/engine.py b/capybara/onnxengine/engine.py index bbed318..4be5202 100644 --- a/capybara/onnxengine/engine.py +++ b/capybara/onnxengine/engine.py @@ -1,178 +1,432 @@ +from __future__ import annotations + +import json +import time +from collections.abc import Mapping +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Union +from typing import Any -import colored import numpy as np -import onnxruntime as ort -from .enum import Backend -from .metadata import parse_metadata_from_onnx -from .tools import get_onnx_input_infos, get_onnx_output_infos +try: + _BFLOAT16 = np.dtype("bfloat16") +except TypeError: # pragma: no cover - older numpy + _BFLOAT16 = np.float16 + +try: # pragma: no cover - optional dependency handled at runtime + import onnxruntime as ort # type: ignore +except Exception as exc: # 
pragma: no cover
+    raise ImportError(
+        "onnxruntime is required. Install 'onnxruntime' or 'onnxruntime-gpu'."
+    ) from exc
+
+from ..runtime import Backend, ProviderSpec
+
+
+@dataclass(slots=True)
+class EngineConfig:
+    """High level configuration knobs for the ONNX engine."""
+
+    graph_optimization: str | ort.GraphOptimizationLevel = "all"
+    execution_mode: str | ort.ExecutionMode | None = None
+    intra_op_num_threads: int | None = None
+    inter_op_num_threads: int | None = None
+    log_severity_level: int | None = None
+    session_config_entries: dict[str, str] | None = None
+    provider_options: dict[str, dict[str, Any]] | None = None
+    fallback_to_cpu: bool = True
+    enable_io_binding: bool = False
+    run_config_entries: dict[str, str] | None = None
+    enable_profiling: bool = False
+
+
+@dataclass(slots=True)
+class _InputSpec:
+    name: str
+    dtype: np.dtype[Any] | None
+    dtype_str: str
+    shape: tuple[int | None, ...]
+
+
+_GRAPH_OPT_MAP = {
+    "disable": ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+    "basic": ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
+    "extended": ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
+    "all": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
+}
+
+_EXECUTION_MODE_MAP = {
+    "sequential": ort.ExecutionMode.ORT_SEQUENTIAL,
+    "parallel": ort.ExecutionMode.ORT_PARALLEL,
+}
+
+_TYPE_MAP = {
+    "tensor(float)": np.float32,
+    "tensor(float16)": np.float16,
+    "tensor(double)": np.float64,
+    "tensor(bfloat16)": _BFLOAT16,
+    "tensor(int64)": np.int64,
+    "tensor(int32)": np.int32,
+    "tensor(int16)": np.int16,
+    "tensor(int8)": np.int8,
+    "tensor(uint64)": np.uint64,
+    "tensor(uint32)": np.uint32,
+    "tensor(uint16)": np.uint16,
+    "tensor(uint8)": np.uint8,
+    "tensor(bool)": np.bool_,
+}
+
+
+# NOTE: Benchmarks are intentionally strict about negative/zero loops to avoid
+# reporting misleading throughput/latency statistics.
+def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]:
+    repeat_i = int(repeat)
+    warmup_i = int(warmup)
+    if repeat_i < 1:
+        raise ValueError("repeat must be >= 1.")
+    if warmup_i < 0:
+        raise ValueError("warmup must be >= 0.")
+    return repeat_i, warmup_i
+
+
+# Provider-specific defaults that enable CUDA graph and TensorRT caching tweaks.
+_DEFAULT_PROVIDER_OPTIONS: dict[str, dict[str, Any]] = {
+    "CUDAExecutionProvider": {
+        "cudnn_conv_algo_search": "HEURISTIC",
+        "cudnn_conv_use_max_workspace": "1",
+        "do_copy_in_default_stream": True,
+        "arena_extend_strategy": "kSameAsRequested",
+        "tunable_op_enable": False,
+        "enable_cuda_graph": False,  # experimental feature; inference results can be unstable
+    },
+    "TensorrtExecutionProvider": {
+        "trt_max_workspace_size": 4 * 1024 * 1024 * 1024,
+        "trt_int8_enable": False,
+        "trt_fp16_enable": False,
+        "trt_engine_cache_enable": True,
+        "trt_engine_cache_path": "trt_engine_cache",
+        "trt_timing_cache_enable": True,
+    },
+}


 class ONNXEngine:
+    """A thin wrapper around onnxruntime.InferenceSession with ergonomic defaults."""
+
     def __init__(
         self,
-        model_path: Union[str, Path],
+        model_path: str | Path,
         gpu_id: int = 0,
-        backend: Union[str, int, Backend] = Backend.cpu,
-        session_option: Dict[str, Any] = {},
-        provider_option: Dict[str, Any] = {},
-    ):
-        """
-        Initialize an ONNX model inference engine.
-
-        Args:
-            model_path (Union[str, Path]):
-                Filename or serialized ONNX or ORT format model in a byte string.
-            gpu_id (int, optional):
-                GPU ID. Defaults to 0.
-            backend (Union[str, int, Backend], optional):
-                Backend. Defaults to Backend.cuda.
-            session_option (Dict[str, Any], optional):
-                Session options. Defaults to {}.
- provider_option (Dict[str, Any], optional): - Provider options. Defaults to {}. - """ - # setting device info - backend = Backend.obj_to_enum(backend) - self.device_id = 0 if backend.name == "cpu" else gpu_id - - # setting provider options - providers = self._get_providers(backend, provider_option) - - # setting session options - sess_options = self._get_session_info(session_option) - - # setting onnxruntime session - model_path = str(model_path) if isinstance(model_path, Path) else model_path - self.sess = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers) - - # setting onnxruntime session info - self.model_path = model_path - self.metadata = parse_metadata_from_onnx(model_path) - self.providers = self.sess.get_providers() - self.provider_options = self.sess.get_provider_options() - - self.input_infos = get_onnx_input_infos(model_path) - self.output_infos = get_onnx_output_infos(model_path) - - def __call__(self, **xs) -> Dict[str, np.ndarray]: - output_names = list(self.output_infos.keys()) - outs = self.sess.run(output_names, {k: v for k, v in xs.items()}) - outs = {k: v for k, v in zip(output_names, outs)} - return outs - - def _get_session_info(self, session_option: Dict[str, Any] = {}) -> ort.SessionOptions: - """ - Ref: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions - """ - sess_opt = ort.SessionOptions() - session_option_default = { - "graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL, - "log_severity_level": 2, + backend: str | Backend = Backend.cpu, + session_option: Mapping[str, Any] | None = None, + provider_option: Mapping[str, Any] | None = None, + config: EngineConfig | None = None, + ) -> None: + self.model_path = str(model_path) + self.backend = Backend.from_any(backend, runtime="onnx") + self.device_id = int(gpu_id) + self._session_overrides = dict(session_option or {}) + self._provider_override = dict(provider_option or {}) + self._cfg = config or EngineConfig() + + self._session = self._create_session() + self.providers = self._session.get_providers() + self.provider_options = self._session.get_provider_options() + self.metadata = self._extract_metadata() + + self._output_names = [node.name for node in self._session.get_outputs()] + self._input_specs = self._inspect_inputs() + self._binding = ( + self._session.io_binding() if self._cfg.enable_io_binding else None + ) + self._run_options = self._build_run_options() + + def __call__(self, **inputs: np.ndarray) -> dict[str, np.ndarray]: + if len(inputs) == 1 and isinstance( + next(iter(inputs.values())), Mapping + ): + feed_dict = dict(next(iter(inputs.values()))) + else: + feed_dict = inputs + feed = self._prepare_feed(feed_dict) + return self._run(feed) + + def run( + self, + feed: Mapping[str, np.ndarray], + ) -> dict[str, np.ndarray]: + return self._run(self._prepare_feed(feed)) + + def summary(self) -> dict[str, Any]: + inputs = [ + { + "name": spec.name, + "dtype": spec.dtype_str, + "shape": list(spec.shape), + } + for spec in self._input_specs + ] + outputs = [ + { + "name": out.name, + "dtype": getattr(out, "type", ""), + "shape": list(getattr(out, "shape", [])), + } + for out in self._session.get_outputs() + ] + return { + "model": self.model_path, + "providers": self.providers, + "inputs": inputs, + "outputs": outputs, } - session_option_default.update(session_option) - for k, v in session_option_default.items(): - setattr(sess_opt, k, v) - return sess_opt - - def _get_providers(self, backend: Union[str, int, Backend], provider_option: 
Dict[str, Any] = {}) -> Backend: - """ - Ref: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options - """ - if backend == Backend.cuda: - providers = [ - ( - "CUDAExecutionProvider", - { - "device_id": self.device_id, - "cudnn_conv_use_max_workspace": "1", - **provider_option, - }, - ) - ] - elif backend == Backend.coreml: - providers = [ - ( - "CoreMLExecutionProvider", - { - "ModelFormat": "MLProgram", - "MLComputeUnits": "ALL", - "RequireStaticInputShapes": "1", - **provider_option, - }, + + def benchmark( + self, + inputs: Mapping[str, np.ndarray], + *, + repeat: int = 100, + warmup: int = 10, + ) -> dict[str, Any]: + repeat, warmup = _validate_benchmark_params( + repeat=repeat, warmup=warmup + ) + feed = self._prepare_feed(inputs) + for _ in range(warmup): + self._session.run(self._output_names, feed) + + latencies: list[float] = [] + t0 = time.perf_counter() + for _ in range(repeat): + start = time.perf_counter() + self._session.run(self._output_names, feed) + latencies.append((time.perf_counter() - start) * 1e3) + total = time.perf_counter() - t0 + arr = np.asarray(latencies, dtype=np.float64) + + return { + "repeat": repeat, + "warmup": warmup, + "throughput_fps": repeat / total if total else None, + "latency_ms": { + "mean": float(arr.mean()) if arr.size else None, + "median": float(np.median(arr)) if arr.size else None, + "p90": float(np.percentile(arr, 90)) if arr.size else None, + "p95": float(np.percentile(arr, 95)) if arr.size else None, + "min": float(arr.min()) if arr.size else None, + "max": float(arr.max()) if arr.size else None, + }, + } + + def _run(self, feed: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]: + if self._binding is not None: + binding = self._binding + binding.clear_binding_inputs() + binding.clear_binding_outputs() + for name, array in feed.items(): + binding.bind_cpu_input(name, array) + for name in self._output_names: + binding.bind_output(name) + if self._run_options is not None: + self._session.run_with_iobinding(binding, self._run_options) + else: + self._session.run_with_iobinding(binding) + outputs = binding.copy_outputs_to_cpu() + else: + if self._run_options is not None: + outputs = self._session.run( + self._output_names, feed, self._run_options ) + else: + outputs = self._session.run(self._output_names, feed) + converted: list[np.ndarray] = [] + for out in outputs: + toarray = getattr(out, "toarray", None) + if callable(toarray): + out = toarray() + converted.append(np.asarray(out)) + return dict(zip(self._output_names, converted, strict=False)) + + def _prepare_feed(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]: + prepared: dict[str, np.ndarray] = {} + for spec in self._input_specs: + if spec.name not in feed: + raise KeyError(f"Missing required input '{spec.name}'.") + array = np.asarray(feed[spec.name]) + if spec.dtype is not None and array.dtype != spec.dtype: + array = array.astype(spec.dtype, copy=False) + prepared[spec.name] = array + return prepared + + def _create_session(self) -> ort.InferenceSession: + session_options = self._build_session_options() + provider_tuples = self._resolve_providers() + provider_names = [name for name, _ in provider_tuples] + provider_options = [opts for _, opts in provider_tuples] + + available = set(ort.get_available_providers()) + missing = [ + name + for name in provider_names + if name != "CPUExecutionProvider" and name not in available + ] + if missing and self._cfg.fallback_to_cpu: + provider_names = ["CPUExecutionProvider"] + 
provider_options = [ + self._build_provider_options("CPUExecutionProvider") ] - elif backend == Backend.cpu: - providers = [("CPUExecutionProvider", {})] - # "CPUExecutionProvider" is different from everything else. - # provider_option = None - else: - raise ValueError(f"backend={backend} is not supported.") - return providers - - def __repr__(self) -> str: - import re - - def strip_ansi_codes(text): - """Remove ANSI escape codes from a string.""" - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) - - def format_nested_dict(dict_data, indent=0): - """Recursively format nested dictionaries with indentation.""" - info = [] - prefix = " " * indent - for key, value in dict_data.items(): - if isinstance(value, dict): - info.append(f"{prefix}{key}:") - info.append(format_nested_dict(value, indent + 1)) - elif isinstance(value, str) and value.startswith("{") and value.endswith("}"): - try: - nested_dict = eval(value) - if isinstance(nested_dict, dict): - info.append(f"{prefix}{key}:") - info.append(format_nested_dict(nested_dict, indent + 1)) - else: - info.append(f"{prefix}{key}: {value}") - except Exception: - info.append(f"{prefix}{key}: {value}") + + return ort.InferenceSession( + self.model_path, + sess_options=session_options, + providers=provider_names, + provider_options=provider_options, + ) + + def _build_session_options(self) -> ort.SessionOptions: + opts = ort.SessionOptions() + cfg = self._cfg + + opts.enable_profiling = bool(cfg.enable_profiling) + opts.graph_optimization_level = self._resolve_graph_optimization( + cfg.graph_optimization + ) + if cfg.execution_mode is not None: + opts.execution_mode = self._resolve_execution_mode( + cfg.execution_mode + ) + if cfg.intra_op_num_threads is not None: + opts.intra_op_num_threads = int(cfg.intra_op_num_threads) + if cfg.inter_op_num_threads is not None: + opts.inter_op_num_threads = int(cfg.inter_op_num_threads) + if cfg.log_severity_level is not None: + opts.log_severity_level = int(cfg.log_severity_level) + + for key, value in (cfg.session_config_entries or {}).items(): + opts.add_session_config_entry(str(key), str(value)) + for key, value in self._session_overrides.items(): + if hasattr(opts, key): + setattr(opts, key, value) + else: + opts.add_session_config_entry(str(key), str(value)) + + return opts + + def _resolve_graph_optimization( + self, + option: str | ort.GraphOptimizationLevel, + ) -> ort.GraphOptimizationLevel: + if isinstance(option, ort.GraphOptimizationLevel): + return option + normalized = str(option).lower() + return _GRAPH_OPT_MAP.get( + normalized, ort.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + + def _resolve_execution_mode( + self, + option: str | ort.ExecutionMode, + ) -> ort.ExecutionMode: + if isinstance(option, ort.ExecutionMode): + return option + normalized = str(option).lower() + return _EXECUTION_MODE_MAP.get( + normalized, ort.ExecutionMode.ORT_SEQUENTIAL + ) + + def _resolve_providers(self) -> list[tuple[str, dict[str, str]]]: + cfg_providers = self._cfg.provider_options or {} + provider_specs = self.backend.providers or ( + ProviderSpec("CPUExecutionProvider"), + ) + merged: list[tuple[str, dict[str, str]]] = [] + for spec in provider_specs: + opts = self._build_provider_options( + spec.name, include_device=spec.include_device + ) + merged_opts = dict(opts) + for key, value in cfg_providers.get(spec.name, {}).items(): + merged_opts[str(key)] = self._provider_value(value) + merged.append((spec.name, merged_opts)) + return merged + + def 
_build_provider_options( + self, + name: str, + *, + include_device: bool = False, + ) -> dict[str, str]: + opts: dict[str, Any] = dict(_DEFAULT_PROVIDER_OPTIONS.get(name, {})) + if include_device: + opts["device_id"] = self.device_id + opts.update( + self._provider_override + if (include_device or name == "CPUExecutionProvider") + else {} + ) + return {str(k): self._provider_value(v) for k, v in opts.items()} + + def _provider_value(self, value: Any) -> str: + if isinstance(value, bool): + return "True" if value else "False" + return str(value) + + def _build_run_options(self) -> ort.RunOptions | None: + entries = { + **(self._cfg.run_config_entries or {}), + } + if not entries: + return None + opts = ort.RunOptions() + for key, value in entries.items(): + opts.add_run_config_entry(str(key), str(value)) + return opts + + def _extract_metadata(self) -> dict[str, Any] | None: + try: + model_meta = self._session.get_modelmeta() + except Exception: + return None + + custom = getattr(model_meta, "custom_metadata_map", None) + if not isinstance(custom, Mapping): + return None + + parsed: dict[str, Any] = {} + for key, value in custom.items(): + if isinstance(value, str): + try: + parsed[key] = json.loads(value) + except Exception: + parsed[key] = value + else: + parsed[key] = value + return parsed or None + + def _inspect_inputs(self) -> list[_InputSpec]: + specs: list[_InputSpec] = [] + for node in self._session.get_inputs(): + dtype = _TYPE_MAP.get(getattr(node, "type", "")) + raw_shape = list(getattr(node, "shape", [])) + shape: list[int | None] = [] + for dim in raw_shape: + if isinstance(dim, (int, np.integer)): + shape.append(int(dim)) else: - info.append(f"{prefix}{key}: {value}") - return "\n".join(info) - - title = "DOCSAID X ONNXRUNTIME" - divider_length = 50 - divider = f"+{'-' * divider_length}+" - styled_title = colored.stylize(title, [colored.fg("blue"), colored.attr("bold")]) - - def center_text(text, width): - """Center text within a fixed width, handling ANSI escape codes.""" - plain_text = strip_ansi_codes(text) - text_length = len(plain_text) - left_padding = (width - text_length) // 2 - right_padding = width - text_length - left_padding - return f"|{' ' * left_padding}{text}{' ' * right_padding}|" - - path = f"Model Path: {self.model_path}" - input_info = format_nested_dict(self.input_infos) - output_info = format_nested_dict(self.output_infos) - metadata = format_nested_dict({"metadata": self.metadata}) - providers = f"Provider: {', '.join(self.providers)}" - provider_options = format_nested_dict(self.provider_options) - - sections = [ - divider, - center_text(styled_title, divider_length), - divider, - path, - input_info, - output_info, - metadata, - providers, - provider_options, - divider, - ] + shape.append(None) + specs.append( + _InputSpec( + name=node.name, + dtype=dtype, + dtype_str=getattr(node, "type", ""), + shape=tuple(shape), + ) + ) + return specs - return "\n\n".join(sections) + def __repr__(self) -> str: # pragma: no cover - human readable summary + return ( + f"ONNXEngine(model='{self.model_path}', " + f"backend='{self.backend.name}', providers={self.providers})" + ) diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py deleted file mode 100644 index cc8f90c..0000000 --- a/capybara/onnxengine/engine_io_binding.py +++ /dev/null @@ -1,203 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Union - -import colored -import numpy as np -import onnxruntime as ort - -from .metadata import get_onnx_metadata 
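
# A minimal usage sketch for the new ONNXEngine API above. The model path
# "model.onnx", the input name "images", and the import path are assumptions
# for illustration, not part of this diff.
import numpy as np

from capybara.onnxengine.engine import ONNXEngine  # assumed module path

engine = ONNXEngine("model.onnx", backend="cpu")  # "model.onnx" is a placeholder
x = np.zeros((1, 3, 224, 224), dtype=np.float32)

outs = engine(images=x)             # kwargs feed; dtypes are coerced per input spec
outs = engine.run({"images": x})    # equivalently, an explicit feed mapping
print(engine.summary()["providers"])

stats = engine.benchmark({"images": x}, repeat=20, warmup=5)
print(stats["latency_ms"]["p95"])
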
-from .tools import get_onnx_input_infos, get_onnx_output_infos - - -class ONNXEngineIOBinding: - def __init__( - self, - model_path: Union[str, Path], - input_initializer: Dict[str, np.ndarray], - gpu_id: int = 0, - session_option: Dict[str, Any] = {}, - provider_option: Dict[str, Any] = {}, - ): - """ - Initialize an ONNX model inference engine. - - Args: - model_path (Union[str, Path]): - Filename or serialized ONNX or ORT format model in a byte string. - gpu_id (int, optional): - GPU ID. Defaults to 0. - session_option (Dict[str, Any], optional): - Session options. Defaults to {}. - provider_option (Dict[str, Any], optional): - Provider options. Defaults to {}. - """ - self.device_id = gpu_id - providers = ["CUDAExecutionProvider"] - provider_options = [ - { - "device_id": self.device_id, - "cudnn_conv_use_max_workspace": "1", - "enable_cuda_graph": "1", - **provider_option, - } - ] - - # setting session options - sess_options = self._get_session_info(session_option) - - # setting onnxruntime session - model_path = str(model_path) if isinstance(model_path, Path) else model_path - self.sess = ort.InferenceSession( - model_path, - sess_options=sess_options, - providers=providers, - provider_options=provider_options, - ) - self.device = "cuda" if "CUDAExecutionProvider" in self.sess.get_providers() else "cpu" - - # setting onnxruntime session info - self.model_path = model_path - self.metadata = get_onnx_metadata(model_path) - self.providers = self.sess.get_providers() - self.provider_options = self.sess.get_provider_options() - - input_infos, output_infos = self._init_io_infos(model_path, input_initializer) - - io_binding, x_ortvalues, y_ortvalues = self._setup_io_binding(input_infos, output_infos) - self.io_binding = io_binding - self.x_ortvalues = x_ortvalues - self.y_ortvalues = y_ortvalues - self.input_infos = input_infos - self.output_infos = output_infos - # # Pass gpu_graph_id to RunOptions through RunConfigs - # ro = ort.RunOptions() - # # gpu_graph_id is optional if the session uses only one cuda graph - # ro.add_run_config_entry("gpu_graph_id", "1") - # self.run_option = ro - - def __call__(self, **xs) -> Dict[str, np.ndarray]: - self._update_x_ortvalues(xs) - # self.sess.run_with_iobinding(self.io_binding, self.run_option) - self.sess.run_with_iobinding(self.io_binding) - return {k: v.numpy() for k, v in self.y_ortvalues.items()} - - def _get_session_info( - self, - session_option: Dict[str, Any] = {}, - ) -> ort.SessionOptions: - """ - Ref: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions - """ - sess_opt = ort.SessionOptions() - session_option_default = { - "graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL, - "log_severity_level": 2, - } - session_option_default.update(session_option) - for k, v in session_option_default.items(): - setattr(sess_opt, k, v) - return sess_opt - - def _init_io_infos(self, model_path, input_initializer: dict): - sess = ort.InferenceSession( - model_path, - providers=["CPUExecutionProvider"], - ) - outs = sess.run(None, input_initializer) - input_shapes = {k: v.shape for k, v in input_initializer.items()} - output_shapes = {x.name: o.shape for x, o in zip(sess.get_outputs(), outs)} - input_infos = get_onnx_input_infos(model_path) - output_infos = get_onnx_output_infos(model_path) - for k, v in input_infos.items(): - v["shape"] = input_shapes[k] - for k, v in output_infos.items(): - v["shape"] = output_shapes[k] - del sess - return input_infos, output_infos - - def _setup_io_binding(self, input_infos, 
output_infos): - x_ortvalues = {} - y_ortvalues = {} - for k, v in input_infos.items(): - m = np.zeros(**v) - x_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type=self.device, device_id=self.device_id) - for k, v in output_infos.items(): - m = np.zeros(**v) - y_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type=self.device, device_id=self.device_id) - - io_binding = self.sess.io_binding() - for k, v in x_ortvalues.items(): - io_binding.bind_ortvalue_input(k, v) - for k, v in y_ortvalues.items(): - io_binding.bind_ortvalue_output(k, v) - - return io_binding, x_ortvalues, y_ortvalues - - def _update_x_ortvalues(self, xs: dict): - for k, v in self.x_ortvalues.items(): - v.update_inplace(xs[k]) - - def __repr__(self) -> str: - import re - - def strip_ansi_codes(text): - """Remove ANSI escape codes from a string.""" - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) - - def format_nested_dict(dict_data, indent=0): - """Recursively format nested dictionaries with indentation.""" - info = [] - prefix = " " * indent - for key, value in dict_data.items(): - if isinstance(value, dict): - info.append(f"{prefix}{key}:") - info.append(format_nested_dict(value, indent + 1)) - elif isinstance(value, str) and value.startswith("{") and value.endswith("}"): - try: - nested_dict = eval(value) - if isinstance(nested_dict, dict): - info.append(f"{prefix}{key}:") - info.append(format_nested_dict(nested_dict, indent + 1)) - else: - info.append(f"{prefix}{key}: {value}") - except Exception: - info.append(f"{prefix}{key}: {value}") - else: - info.append(f"{prefix}{key}: {value}") - return "\n".join(info) - - title = "DOCSAID X ONNXRUNTIME" - divider_length = 50 - divider = f"+{'-' * divider_length}+" - styled_title = colored.stylize(title, [colored.fg("blue"), colored.attr("bold")]) - - def center_text(text, width): - """Center text within a fixed width, handling ANSI escape codes.""" - plain_text = strip_ansi_codes(text) - text_length = len(plain_text) - left_padding = (width - text_length) // 2 - right_padding = width - text_length - left_padding - return f"|{' ' * left_padding}{text}{' ' * right_padding}|" - - path = f"Model Path: {self.model_path}" - input_info = format_nested_dict(self.input_infos) - output_info = format_nested_dict(self.output_infos) - metadata = format_nested_dict({"metadata": self.metadata}) - providers = f"Provider: {', '.join(self.providers)}" - provider_options = format_nested_dict(self.provider_options) - - sections = [ - divider, - center_text(styled_title, divider_length), - divider, - path, - input_info, - output_info, - metadata, - providers, - provider_options, - divider, - ] - - return "\n\n".join(sections) diff --git a/capybara/onnxengine/enum.py b/capybara/onnxengine/enum.py deleted file mode 100644 index a697e12..0000000 --- a/capybara/onnxengine/enum.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - -from ..enums import EnumCheckMixin - - -class Backend(EnumCheckMixin, Enum): - cpu = 0 - cuda = 1 - coreml = 2 diff --git a/capybara/onnxengine/metadata.py b/capybara/onnxengine/metadata.py index a7d31dc..9e4f52f 100644 --- a/capybara/onnxengine/metadata.py +++ b/capybara/onnxengine/metadata.py @@ -1,48 +1,73 @@ +from __future__ import annotations + import json -from typing import Union +from pathlib import Path +from typing import Any import onnx -import onnxruntime as ort -from ..utils import Path, now +from ..utils.time import now + +try: # pragma: no cover - optional dependency + import 
onnxruntime as ort # type: ignore +except Exception: # pragma: no cover - handled lazily + ort = None # type: ignore[assignment] + +def _require_ort(): + if ort is None: # pragma: no cover - depends on optional dep + raise ImportError( + "onnxruntime is required to read/write ONNX metadata. " + "Install 'onnxruntime' or 'onnxruntime-gpu'." + ) + return ort -def get_onnx_metadata( - onnx_path: Union[str, Path], -) -> dict: - onnx_path = str(onnx_path) - sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) - metadata = sess.get_modelmeta().custom_metadata_map - del sess - return metadata + +def get_onnx_metadata(onnx_path: str | Path) -> dict[str, Any]: + ort_mod = _require_ort() + sess = ort_mod.InferenceSession( + str(onnx_path), providers=["CPUExecutionProvider"] + ) + try: + custom = sess.get_modelmeta().custom_metadata_map + return dict(custom) + finally: + del sess def write_metadata_into_onnx( - onnx_path: Union[str, Path], - out_path: Union[str, Path], + onnx_path: str | Path, + out_path: str | Path, drop_old_meta: bool = False, - **kwargs, -): - onnx_path = str(onnx_path) - onnx_model = onnx.load(onnx_path) - meta_data = parse_metadata_from_onnx(onnx_path) if not drop_old_meta else {} + **kwargs: Any, +) -> None: + onnx_model = onnx.load(str(onnx_path)) + meta_data: dict[str, Any] = ( + {} if drop_old_meta else parse_metadata_from_onnx(onnx_path) + ) meta_data.update({"Date": now(fmt="%Y-%m-%d %H:%M:%S"), **kwargs}) onnx.helper.set_model_props( onnx_model, - {k: json.dumps(v) for k, v in meta_data.items()}, + {str(k): json.dumps(v) for k, v in meta_data.items()}, + ) + onnx.save(onnx_model, str(out_path)) + + +def parse_metadata_from_onnx(onnx_path: str | Path) -> dict[str, Any]: + ort_mod = _require_ort() + sess = ort_mod.InferenceSession( + str(onnx_path), providers=["CPUExecutionProvider"] ) - onnx.save(onnx_model, out_path) - - -def parse_metadata_from_onnx( - onnx_path: Union[str, Path], -) -> dict: - onnx_path = str(onnx_path) - sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) - metadata = { - k: json.loads(v) for k, v in sess.get_modelmeta().custom_metadata_map.items() - } - del sess - return metadata + try: + metadata_map = sess.get_modelmeta().custom_metadata_map + parsed: dict[str, Any] = {} + for key, raw in metadata_map.items(): + if isinstance(raw, str): + parsed[key] = json.loads(raw) + else: + parsed[key] = raw + return parsed + finally: + del sess diff --git a/capybara/onnxengine/quantize.py b/capybara/onnxengine/quantize.py deleted file mode 100644 index 15e1aa0..0000000 --- a/capybara/onnxengine/quantize.py +++ /dev/null @@ -1,62 +0,0 @@ -# from enum import Enum -# from typing import List, Union - -# import onnx -# from onnxruntime.quantization import CalibrationDataReader, quantize_static -# from onnxruntime.quantization.calibrate import CalibrationMethod -# from onnxruntime.quantization.quant_utils import QuantFormat - -# from ..enums import Enum, EnumCheckMixin -# from ..utils import Path - - -# class DstDevice(EnumCheckMixin, Enum): -# mobile = 0 -# x86 = 1 - - -# def _get_exclude_names_from_op_type(onnx_fpath, op_type): -# onnx_model = onnx.load(str(onnx_fpath)) -# return [node.name for node in onnx_model.graph.node if node.op_type == op_type] - - -# def quantize( -# onnx_fpath: Union[str, Path], -# calibration_data_reader: CalibrationDataReader, -# dst_device: Union[str, DstDevice] = DstDevice.mobile, -# op_type_to_exclude: List[str] = [], -# **quant_cfg, -# ) -> str: -# dst_device = 
DstDevice.obj_to_enum(dst_device) -# onnx_fpath = Path(onnx_fpath) - -# print(f'\nStart to quantize onnx model to {dst_device.name}') - -# quant_fpath = str(onnx_fpath).replace('fp32', '').replace( -# '.onnx', f'_int8_{dst_device.name}.onnx') - -# quant_format = QuantFormat.QOperator \ -# if dst_device == DstDevice.mobile else QuantFormat.QDQ -# quant_format = quant_cfg.get('quant_format', quant_format) -# calibrate_method = quant_cfg.get( -# 'calibrate_method', CalibrationMethod.MinMax) -# per_channel = quant_cfg.get('per_channel', True) -# reduce_range = quant_cfg.get('reduce_range', True) -# nodes_to_exclude = quant_cfg.get('nodes_to_exclude', []) - -# for op_type in op_type_to_exclude: -# nodes_to_exclude.extend( -# _get_exclude_names_from_op_type(onnx_fpath, op_type)) - -# quantize_static( -# onnx_fpath, -# quant_fpath, -# calibration_data_reader, -# quant_format=quant_format, -# calibrate_method=calibrate_method, -# per_channel=per_channel, -# reduce_range=reduce_range, -# nodes_to_exclude=nodes_to_exclude -# ) - -# return quant_fpath diff --git a/capybara/onnxengine/tools.py b/capybara/onnxengine/tools.py deleted file mode 100644 index 73bffaa..0000000 --- a/capybara/onnxengine/tools.py +++ /dev/null @@ -1,94 +0,0 @@ -from pathlib import Path -from typing import Dict, List, Optional, Union - -import onnx -import onnxruntime as ort -import onnxslim -from onnx.helper import make_graph, make_model, make_opsetid, tensor_dtype_to_np_dtype - -from .enum import Backend - -__all__ = [ - "get_onnx_input_infos", - "get_onnx_output_infos", - "make_onnx_dynamic_axes", - "get_recommended_backend", -] - - -def get_onnx_input_infos(model: Union[str, Path, onnx.ModelProto]) -> Dict[str, List[int]]: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model) - return { - x.name: { - "shape": [d.dim_value if d.dim_value != 0 else -1 for d in x.type.tensor_type.shape.dim], - "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type), - } - for x in model.graph.input - } - - -def get_onnx_output_infos(model: Union[str, Path, onnx.ModelProto]) -> Dict[str, List[int]]: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model) - return { - x.name: { - "shape": [d.dim_value if d.dim_value != 0 else -1 for d in x.type.tensor_type.shape.dim], - "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type), - } - for x in model.graph.output - } - - -def make_onnx_dynamic_axes( - model_fpath: Union[str, Path], - output_fpath: Union[str, Path], - input_dims: Dict[str, Dict[int, str]], - output_dims: Dict[str, Dict[int, str]], - opset_version: Optional[int] = None, -) -> None: - onnx_model = onnx.load(model_fpath) - - new_graph = make_graph( - nodes=onnx_model.graph.node, - name=onnx_model.graph.name, - inputs=onnx_model.graph.input, - outputs=onnx_model.graph.output, - initializer=onnx_model.graph.initializer, - value_info=None, - ) - - if not any(opset.domain == "" for opset in onnx_model.opset_import): - onnx_model.opset_import.append(make_opsetid(domain="", version=opset_version)) - - new_model = make_model(new_graph, opset_imports=onnx_model.opset_import, ir_version=onnx_model.ir_version) - - for x in new_model.graph.input: - for name, v in input_dims.items(): - if x.name == name: - for k, d in v.items(): - x.type.tensor_type.shape.dim[k].dim_param = d - - for x in new_model.graph.output: - for name, v in output_dims.items(): - if x.name == name: - for k, d in v.items(): - x.type.tensor_type.shape.dim[k].dim_param = d - - for x in new_model.graph.node: - if x.op_type == 
"Reshape": - raise ValueError("Reshape cannot be trasformed to dynamic axes") - - new_model = onnxslim.slim(new_model) - onnx.save(new_model, output_fpath) - - -def get_recommended_backend() -> Backend: - providers = ort.get_available_providers() - device = ort.get_device() - if "CUDAExecutionProvider" in providers and device == "GPU": - return Backend.cuda - elif "CoreMLExecutionProvider" in providers: - return Backend.coreml - else: - return Backend.cpu diff --git a/capybara/onnxengine/utils.py b/capybara/onnxengine/utils.py new file mode 100644 index 0000000..717b7e3 --- /dev/null +++ b/capybara/onnxengine/utils.py @@ -0,0 +1,113 @@ +from pathlib import Path +from typing import Any, cast + +import onnx +import onnxslim +from onnx.helper import ( + make_graph, + make_model, + make_opsetid, + tensor_dtype_to_np_dtype, +) + +__all__ = [ + "get_onnx_input_infos", + "get_onnx_output_infos", +] + + +def get_onnx_input_infos( + model: str | Path | onnx.ModelProto, +) -> dict[str, dict[str, Any]]: + if not isinstance(model, onnx.ModelProto): + model = onnx.load(model) + + def _dim_to_value(dim: Any) -> int | str: + dim_param = getattr(dim, "dim_param", "") or "" + if dim_param: + return str(dim_param) + dim_value = int(getattr(dim, "dim_value", 0) or 0) + return dim_value if dim_value != 0 else -1 + + return { + x.name: { + "shape": [_dim_to_value(d) for d in x.type.tensor_type.shape.dim], + "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type), + } + for x in model.graph.input + } + + +def get_onnx_output_infos( + model: str | Path | onnx.ModelProto, +) -> dict[str, dict[str, Any]]: + if not isinstance(model, onnx.ModelProto): + model = onnx.load(model) + + def _dim_to_value(dim: Any) -> int | str: + dim_param = getattr(dim, "dim_param", "") or "" + if dim_param: + return str(dim_param) + dim_value = int(getattr(dim, "dim_value", 0) or 0) + return dim_value if dim_value != 0 else -1 + + return { + x.name: { + "shape": [_dim_to_value(d) for d in x.type.tensor_type.shape.dim], + "dtype": tensor_dtype_to_np_dtype(x.type.tensor_type.elem_type), + } + for x in model.graph.output + } + + +def make_onnx_dynamic_axes( + model_fpath: str | Path, + output_fpath: str | Path, + input_dims: dict[str, dict[int, str]], + output_dims: dict[str, dict[int, str]], + opset_version: int | None = None, +) -> None: + onnx_model = onnx.load(model_fpath) + + new_graph = make_graph( + nodes=onnx_model.graph.node, + name=onnx_model.graph.name, + inputs=onnx_model.graph.input, + outputs=onnx_model.graph.output, + initializer=onnx_model.graph.initializer, + value_info=None, + ) + + if not any(opset.domain == "" for opset in onnx_model.opset_import): + if opset_version is None: + opset_version = int(onnx.defs.onnx_opset_version()) + onnx_model.opset_import.append( + make_opsetid(domain="", version=opset_version) + ) + + new_model = make_model(new_graph, opset_imports=onnx_model.opset_import) + + for x in new_model.graph.input: + for name, v in input_dims.items(): + if x.name == name: + for k, d in v.items(): + x.type.tensor_type.shape.dim[k].dim_param = d + + for x in new_model.graph.output: + for name, v in output_dims.items(): + if x.name == name: + for k, d in v.items(): + x.type.tensor_type.shape.dim[k].dim_param = d + + for x in new_model.graph.node: + if x.op_type == "Reshape": + raise ValueError("Reshape cannot be trasformed to dynamic axes") + + simplify = getattr(onnxslim, "simplify", None) + if callable(simplify): + simplified = simplify(new_model) + if isinstance(simplified, tuple): + new_model = 
cast(onnx.ModelProto, simplified[0]) + else: + new_model = cast(onnx.ModelProto, simplified) + onnx.save(new_model, output_fpath) diff --git a/capybara/openvinoengine/__init__.py b/capybara/openvinoengine/__init__.py new file mode 100644 index 0000000..d2ca439 --- /dev/null +++ b/capybara/openvinoengine/__init__.py @@ -0,0 +1,3 @@ +from .engine import OpenVINOConfig, OpenVINODevice, OpenVINOEngine + +__all__ = ["OpenVINOConfig", "OpenVINODevice", "OpenVINOEngine"] diff --git a/capybara/openvinoengine/engine.py b/capybara/openvinoengine/engine.py new file mode 100644 index 0000000..febd359 --- /dev/null +++ b/capybara/openvinoengine/engine.py @@ -0,0 +1,626 @@ +from __future__ import annotations + +import contextlib +import queue +import threading +import time +import warnings +from collections import deque +from collections.abc import Mapping +from concurrent.futures import Future +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any + +import numpy as np + +__all__ = ["OpenVINOConfig", "OpenVINODevice", "OpenVINOEngine"] + +try: + _BFLOAT16 = np.dtype("bfloat16") +except TypeError: # pragma: no cover + _BFLOAT16 = np.float16 + + +class OpenVINODevice(str, Enum): + auto = "AUTO" + cpu = "CPU" + gpu = "GPU" + npu = "NPU" + hetero = "HETERO" + auto_batch = "AUTO_BATCH" + + @classmethod + def from_any(cls, value: Any) -> OpenVINODevice: + if isinstance(value, cls): + return value + if isinstance(value, str): + normalized = value.upper() + for member in cls: + if member.value == normalized: + return member + raise ValueError( + f"Unsupported OpenVINO device '{value}'. " + f"Available: {[member.value for member in cls]}" + ) + + +@dataclass(slots=True) +class OpenVINOConfig: + compile_properties: dict[str, Any] | None = None + core_properties: dict[str, Any] | None = None + cache_dir: str | Path | None = None + num_streams: int | None = None + num_threads: int | None = None + # None => use engine defaults + # 0 => OpenVINO auto (async queue only; sync pool falls back to 1) + # >=1 => fixed number of requests + num_requests: int | None = None + # When disabled, outputs may share OpenVINO-owned buffers and can be + # overwritten by subsequent inference calls. + copy_outputs: bool = True + + +@dataclass(slots=True) +class _InputSpec: + name: str + dtype: np.dtype[Any] | None + shape: tuple[int | None, ...] + + +def _lazy_import_openvino(): + try: + import openvino.runtime as ov_runtime # type: ignore + except Exception as exc: # pragma: no cover - missing optional dep + raise ImportError( + "OpenVINO is required. Install it via 'pip install openvino-dev'." 
+ ) from exc + return ov_runtime + + +def _normalize_device(device: str | OpenVINODevice) -> str: + if isinstance(device, OpenVINODevice): + return device.value + return str(device).upper() + + +def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]: + repeat_i = int(repeat) + warmup_i = int(warmup) + if repeat_i < 1: + raise ValueError("repeat must be >= 1.") + if warmup_i < 0: + raise ValueError("warmup must be >= 0.") + return repeat_i, warmup_i + + +class OpenVINOEngine: + def __init__( + self, + model_path: str | Path, + *, + device: str | OpenVINODevice = OpenVINODevice.auto, + config: OpenVINOConfig | None = None, + core: Any | None = None, + input_shapes: dict[str, Any] | None = None, + ) -> None: + ov = _lazy_import_openvino() + + self.model_path = str(model_path) + self.device = _normalize_device(device) + self._cfg = config or OpenVINOConfig() + self._core = core or ov.Core() + self._input_shapes = input_shapes + self._ov = ov + + if self._cfg.core_properties: + self._core.set_property(self._cfg.core_properties) + + self._type_map = self._build_type_map(ov) + self._compiled_model = self._compile_model() + self._input_specs = self._inspect_inputs() + self._output_ports = list(self._compiled_model.outputs) + self._output_names = [ + port.get_any_name() for port in self._output_ports + ] + self._copy_outputs = bool(self._cfg.copy_outputs) + self._request_pool = self._create_request_pool() + + def __call__(self, **inputs: np.ndarray) -> dict[str, np.ndarray]: + if len(inputs) == 1 and isinstance( + next(iter(inputs.values())), Mapping + ): + feed_dict = dict(next(iter(inputs.values()))) + else: + feed_dict = inputs + return self.run(feed_dict) + + def run(self, feed: Mapping[str, np.ndarray]) -> dict[str, np.ndarray]: + prepared = self._prepare_feed(feed) + infer_request = self._request_pool.get() + try: + infer_request.infer(prepared) + except Exception: + try: + replacement = self._compiled_model.create_infer_request() + except Exception: # pragma: no cover - unexpected runtime failure + replacement = infer_request + self._request_pool.put(replacement) + raise + else: + try: + return self._collect_outputs( + infer_request, copy_outputs=self._copy_outputs + ) + finally: + self._request_pool.put(infer_request) + + def create_async_queue( + self, + *, + num_requests: int | None = None, + copy_outputs: bool | None = None, + ) -> OpenVINOAsyncQueue: + """Create an async inference queue backed by AsyncInferQueue. + + This enables pipelining: while the device is running inference for one + request, your Python code can prepare the next input(s). 
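+
+        A minimal sketch (``engine`` and the input name ``"x"`` are
+        placeholders, not part of this API):
+
+            with engine.create_async_queue(num_requests=2) as q:
+                futs = [q.submit({"x": arr}) for arr in arrays]
+                outs = [f.result() for f in futs]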
+ + `num_requests` semantics: + - None: use `OpenVINOConfig.num_requests` (or a default) + - 0: let OpenVINO decide the number of requests (AUTO) + - >=1: fixed number of requests + """ + return OpenVINOAsyncQueue( + self, + num_requests=num_requests, + copy_outputs=copy_outputs, + ) + + def summary(self) -> dict[str, Any]: + inputs = [ + { + "name": spec.name, + "dtype": str(spec.dtype), + "shape": list(spec.shape), + } + for spec in self._input_specs + ] + outputs = [ + { + "name": port.get_any_name(), + "dtype": str(port.get_element_type()), + "shape": list( + self._partial_shape_to_tuple(port.get_partial_shape()) + ), + } + for port in self._output_ports + ] + return { + "model": self.model_path, + "device": self.device, + "inputs": inputs, + "outputs": outputs, + } + + def benchmark( + self, + feed: Mapping[str, np.ndarray], + *, + repeat: int = 100, + warmup: int = 10, + ) -> dict[str, Any]: + repeat, warmup = _validate_benchmark_params( + repeat=repeat, warmup=warmup + ) + prepared = self._prepare_feed(feed) + infer_request = self._compiled_model.create_infer_request() + + for _ in range(warmup): + infer_request.infer(prepared) + + latencies: list[float] = [] + t0 = time.perf_counter() + for _ in range(repeat): + start = time.perf_counter() + infer_request.infer(prepared) + latencies.append((time.perf_counter() - start) * 1e3) + total = time.perf_counter() - t0 + arr = np.asarray(latencies, dtype=np.float64) + return { + "repeat": repeat, + "warmup": warmup, + "throughput_fps": repeat / total if total else None, + "latency_ms": { + "mean": float(arr.mean()) if arr.size else None, + "median": float(np.median(arr)) if arr.size else None, + "p90": float(np.percentile(arr, 90)) if arr.size else None, + "p95": float(np.percentile(arr, 95)) if arr.size else None, + "min": float(arr.min()) if arr.size else None, + "max": float(arr.max()) if arr.size else None, + }, + } + + def benchmark_async( + self, + feed: Mapping[str, np.ndarray], + *, + repeat: int = 100, + warmup: int = 10, + num_requests: int | None = None, + ) -> dict[str, Any]: + """Benchmark throughput using AsyncInferQueue.""" + repeat, warmup = _validate_benchmark_params( + repeat=repeat, warmup=warmup + ) + prepared = self._prepare_feed(feed) + default_jobs = self._resolve_async_jobs( + self._cfg.num_requests, default=2 + ) + jobs = self._resolve_async_jobs(num_requests, default=default_jobs) + + # Warmup with synchronous inference to keep behaviour deterministic. 
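+        # The async phase below keeps a bounded window of in-flight requests
+        # (in_flight_limit): once the window is full, the oldest future is
+        # resolved before the next submit, so per-request latency is recorded
+        # while the device queue stays saturated.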
+ warm_req = self._compiled_model.create_infer_request() + for _ in range(warmup): + warm_req.infer(prepared) + + latencies: list[float] = [] + with self.create_async_queue(num_requests=jobs) as async_queue: + in_flight_limit = jobs if jobs > 0 else 2 + in_flight_limit = max(1, in_flight_limit) + in_flight: deque[tuple[Future[dict[str, np.ndarray]], float]] = ( + deque() + ) + + t0 = time.perf_counter() + for _ in range(repeat): + if len(in_flight) >= in_flight_limit: + fut, start = in_flight.popleft() + fut.result() + latencies.append((time.perf_counter() - start) * 1e3) + + start = time.perf_counter() + fut = async_queue.submit(prepared) + in_flight.append((fut, start)) + + while in_flight: + fut, start = in_flight.popleft() + fut.result() + latencies.append((time.perf_counter() - start) * 1e3) + total = time.perf_counter() - t0 + + arr = np.asarray(latencies, dtype=np.float64) + return { + "repeat": repeat, + "warmup": warmup, + "num_requests": jobs, + "throughput_fps": repeat / total if total else None, + "latency_ms": { + "mean": float(arr.mean()) if arr.size else None, + "median": float(np.median(arr)) if arr.size else None, + "p90": float(np.percentile(arr, 90)) if arr.size else None, + "p95": float(np.percentile(arr, 95)) if arr.size else None, + "min": float(arr.min()) if arr.size else None, + "max": float(arr.max()) if arr.size else None, + }, + } + + def _compile_model(self): + model = self._core.read_model(self.model_path) + self._maybe_reshape_model(model) + properties = dict(self._cfg.compile_properties or {}) + if self._cfg.cache_dir is not None: + properties["CACHE_DIR"] = str(self._cfg.cache_dir) + if self._cfg.num_streams is not None: + properties["NUM_STREAMS"] = str(self._cfg.num_streams) + if self._cfg.num_threads is not None: + properties["INFERENCE_NUM_THREADS"] = str(self._cfg.num_threads) + return self._core.compile_model(model, self.device, properties) + + def _create_request_pool(self) -> queue.Queue[Any]: + pool_size = self._resolve_pool_size(self._cfg.num_requests) + pool: queue.Queue[Any] = queue.Queue(maxsize=pool_size) + for _ in range(pool_size): + pool.put(self._compiled_model.create_infer_request()) + return pool + + def _resolve_pool_size(self, value: Any | None) -> int: + if value is None: + return 1 + requests = int(value) + if requests < 0: + raise ValueError("num_requests must be >= 0.") + if requests == 0: + return 1 + return requests + + def _resolve_async_jobs(self, value: Any | None, *, default: int) -> int: + if value is None: + return int(default) + requests = int(value) + if requests < 0: + raise ValueError("num_requests must be >= 0.") + return requests + + def _collect_outputs( + self, + infer_request: Any, + *, + copy_outputs: bool, + ) -> dict[str, np.ndarray]: + """Return output tensors as numpy arrays. + + When `copy_outputs=False`, the returned arrays may share OpenVINO-owned + buffers and can be overwritten by subsequent inference calls. + """ + outputs: dict[str, np.ndarray] = {} + for port, name in zip( + self._output_ports, self._output_names, strict=True + ): + tensor = infer_request.get_tensor(port) + outputs[name] = np.array(tensor.data, copy=copy_outputs) + return outputs + + def _maybe_reshape_model(self, model: Any) -> None: + if not self._input_shapes: + return + if not hasattr(model, "reshape"): + raise RuntimeError( + "OpenVINO model does not support reshape(), " + "but input_shapes were provided." 
+ ) + + reshape_map: dict[str, Any] = {} + for name, shape in self._input_shapes.items(): + reshape_map[name] = self._normalize_shape_dims(shape) + + model.reshape(reshape_map) + + def _normalize_shape_dims(self, shape: Any) -> Any: + if isinstance(shape, (list, tuple)): + dims: list[int] = [] + for dim in shape: + if dim is None: + raise ValueError( + "input_shapes must use concrete dimensions; got None." + ) + dims.append(int(dim)) + return tuple(dims) + return shape + + def _inspect_inputs(self) -> list[_InputSpec]: + specs: list[_InputSpec] = [] + for port in self._compiled_model.inputs: + specs.append( + _InputSpec( + name=port.get_any_name(), + dtype=self._type_map.get(port.get_element_type()), + shape=self._partial_shape_to_tuple( + port.get_partial_shape() + ), + ) + ) + return specs + + def _prepare_feed(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]: + prepared: dict[str, np.ndarray] = {} + for spec in self._input_specs: + if spec.name not in feed: + raise KeyError(f"Missing required input '{spec.name}'.") + array = np.asarray(feed[spec.name]) + if spec.dtype is not None and array.dtype != spec.dtype: + array = array.astype(spec.dtype, copy=False) + prepared[spec.name] = array + return prepared + + def _partial_shape_to_tuple(self, shape) -> tuple[int | None, ...]: + dims: list[int | None] = [] + for dim in shape: + if hasattr(dim, "is_static") and dim.is_static: + dims.append(int(dim.get_length())) + elif isinstance(dim, int): + dims.append(int(dim)) + else: + dims.append(None) + return tuple(dims) + + def _build_type_map(self, ov_module) -> dict[Any, np.dtype[Any]]: + type_cls = getattr(ov_module, "Type", None) + if type_cls is None: + return {} + mapping: dict[Any, np.dtype[Any]] = {} + pairs = { + "f32": np.float32, + "f16": np.float16, + "bf16": _BFLOAT16, + "i64": np.int64, + "i32": np.int32, + "i16": np.int16, + "i8": np.int8, + "u64": np.uint64, + "u32": np.uint32, + "u16": np.uint16, + "u8": np.uint8, + "boolean": np.bool_, + } + for attr, dtype in pairs.items(): + if hasattr(type_cls, attr): + mapping[getattr(type_cls, attr)] = dtype + return mapping + + def __repr__(self) -> str: # pragma: no cover - helper for debugging + return ( + f"OpenVINOEngine(model='{self.model_path}', device='{self.device}')" + ) + + +class OpenVINOAsyncQueue: + """A small wrapper around OpenVINO AsyncInferQueue. + + Use this to overlap preprocessing/postprocessing with device execution. + """ + + def __init__( + self, + engine: OpenVINOEngine, + *, + num_requests: int | None = None, + copy_outputs: bool | None = None, + ) -> None: + self._engine = engine + self._ov = engine._ov + self._copy_outputs = ( + engine._copy_outputs if copy_outputs is None else bool(copy_outputs) + ) + self._completion_queue_warned = False + default_jobs = engine._resolve_async_jobs( + engine._cfg.num_requests, default=2 + ) + jobs = engine._resolve_async_jobs(num_requests, default=default_jobs) + + infer_queue_cls = getattr(self._ov, "AsyncInferQueue", None) + if infer_queue_cls is None: + raise RuntimeError( + "Your OpenVINO installation does not expose AsyncInferQueue. " + "Please upgrade 'openvino' / 'openvino-dev' to a newer version." 
+ ) + + self._queue = infer_queue_cls(engine._compiled_model, jobs) + self._queue.set_callback(self._callback) + self._closed = False + self._request_id_lock = threading.Lock() + self._request_id_seq = 0 + + def _allocate_request_id(self) -> int: + with self._request_id_lock: + request_id = self._request_id_seq + self._request_id_seq += 1 + return request_id + + def submit( + self, + feed: Mapping[str, np.ndarray], + *, + request_id: Any | None = None, + completion_queue: Any | None = None, + ) -> Future[dict[str, np.ndarray]]: + """Start an async request and return a Future of raw model outputs. + + If ``completion_queue`` is provided, this enqueues ``(request_id, outputs)`` + in the OpenVINO callback (non-blocking; events are dropped if the queue is + full or incompatible). + + Note: ``outputs`` are the raw model outputs (before any decoding). When + ``copy_outputs=False``, the returned arrays may share OpenVINO-owned + buffers and can be overwritten by subsequent inference calls; consume them + immediately. + + The returned Future includes a ``request_id`` attribute that matches the + ID used for submission. If ``request_id`` is omitted, an integer is + auto-generated. You can also pass any correlation ID (e.g. a string). + """ + if self._closed: + raise RuntimeError("Async queue is closed.") + resolved_request_id = request_id + if resolved_request_id is None: + resolved_request_id = self._allocate_request_id() + future: Future[dict[str, np.ndarray]] = Future() + with contextlib.suppress(Exception): # pragma: no cover - defensive + future.request_id = resolved_request_id # type: ignore[attr-defined] + if completion_queue is not None: + with contextlib.suppress(Exception): # pragma: no cover - defensive + future.completion_queue = completion_queue # type: ignore[attr-defined] + prepared = self._engine._prepare_feed(feed) + self._queue.start_async( + prepared, + _AsyncUserdata( + future=future, + request_id=resolved_request_id, + completion_queue=completion_queue, + ), + ) + return future + + def wait_all(self) -> None: + self._queue.wait_all() + + def close(self) -> None: + if self._closed: + return + self.wait_all() + self._closed = True + + def __enter__(self) -> OpenVINOAsyncQueue: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + def _callback(self, infer_request: Any, userdata: Any) -> None: + completion_queue = None + request_id = None + if isinstance(userdata, _AsyncUserdata): + future = userdata.future + completion_queue = userdata.completion_queue + request_id = userdata.request_id + elif isinstance( + userdata, Future + ): # pragma: no cover - legacy defensive + future = userdata + completion_queue = getattr(future, "completion_queue", None) + request_id = getattr(future, "request_id", None) + else: # pragma: no cover - defensive + return + try: + outputs = self._engine._collect_outputs( + infer_request, copy_outputs=self._copy_outputs + ) + except Exception as exc: # pragma: no cover - surfaced via future + future.set_exception(exc) + else: + future.set_result(outputs) + if completion_queue is not None and request_id is not None: + item = (request_id, dict(outputs)) + try: + put_nowait = getattr(completion_queue, "put_nowait", None) + if callable(put_nowait): + put_nowait(item) + else: + completion_queue.put(item, block=False) + except queue.Full: + if not self._completion_queue_warned: + warnings.warn( + "completion_queue is full; dropping async completion " + "events to avoid blocking the OpenVINO callback thread.", + RuntimeWarning, + 
stacklevel=2, + ) + self._completion_queue_warned = True + except (TypeError, AttributeError): + if not self._completion_queue_warned: + warnings.warn( + "completion_queue does not support non-blocking put; " + "dropping async completion events to avoid blocking " + "the OpenVINO callback thread.", + RuntimeWarning, + stacklevel=2, + ) + self._completion_queue_warned = True + except Exception: # pragma: no cover - defensive + if not self._completion_queue_warned: + warnings.warn( + "completion_queue enqueue failed; dropping async " + "completion events to avoid blocking the OpenVINO " + "callback thread.", + RuntimeWarning, + stacklevel=2, + ) + self._completion_queue_warned = True + + +@dataclass(slots=True) +class _AsyncUserdata: + future: Future[dict[str, np.ndarray]] + request_id: Any | None + completion_queue: Any | None diff --git a/capybara/runtime.py b/capybara/runtime.py new file mode 100644 index 0000000..3ed4ab0 --- /dev/null +++ b/capybara/runtime.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib import import_module +from typing import Any, ClassVar, cast + +__all__ = ["Backend", "ProviderSpec", "Runtime"] + + +def _normalize_key(value: str | Runtime) -> str: + if isinstance(value, Runtime): + return value.name + return str(value).strip().lower() + + +@dataclass(frozen=True) +class ProviderSpec: + name: str + include_device: bool = False + + +@dataclass(frozen=True) +class Backend: + name: str + runtime_key: str + providers: tuple[ProviderSpec, ...] = () + device: str | None = None + description: str | None = None + + _REGISTRY: ClassVar[dict[str, dict[str, Backend]]] = {} + cpu: ClassVar[Backend] + cuda: ClassVar[Backend] + tensorrt: ClassVar[Backend] + tensorrt_rtx: ClassVar[Backend] + ov_cpu: ClassVar[Backend] + ov_gpu: ClassVar[Backend] + ov_npu: ClassVar[Backend] + + def __post_init__(self) -> None: + normalized_runtime = self.runtime_key.strip().lower() + normalized_name = self.name.strip().lower() + object.__setattr__(self, "runtime_key", normalized_runtime) + object.__setattr__(self, "name", normalized_name) + namespace = self._REGISTRY.setdefault(normalized_runtime, {}) + if normalized_name in namespace: + raise ValueError( + f"Backend '{self.name}' already registered for runtime '{self.runtime_key}'." + ) + namespace[normalized_name] = self + + @property + def runtime(self) -> str: + return self.runtime_key + + @classmethod + def available(cls, runtime: str | Runtime) -> tuple[Backend, ...]: + runtime_key = _normalize_key(runtime) + namespace = cls._REGISTRY.get(runtime_key, {}) + return tuple(namespace.values()) + + @classmethod + def from_any( + cls, + value: Any, + *, + runtime: str | Runtime | None = None, + ) -> Backend: + runtime_key = cls._resolve_runtime_key(runtime) + namespace = cls._REGISTRY.get(runtime_key, {}) + if isinstance(value, Backend): + if value.runtime_key != runtime_key: + raise ValueError( + f"Backend '{value.name}' is not registered for runtime '{runtime_key}'." + ) + return value + normalized = str(value).strip().lower() + try: + return namespace[normalized] + except KeyError: # pragma: no cover - defensive guard + options = ", ".join(namespace.keys()) or "" + raise ValueError( + f"Unsupported backend '{value}' for runtime '{runtime_key}'. 
" + f"Pick from [{options}]" + ) from None + + @classmethod + def _resolve_runtime_key(cls, runtime: str | Runtime | None) -> str: + if runtime is None: + if len(cls._REGISTRY) == 1: + return next(iter(cls._REGISTRY)) + raise ValueError( + "runtime must be specified when resolving backends." + ) + if isinstance(runtime, Runtime): + return runtime.name + return str(runtime).strip().lower() + + +@dataclass(frozen=True) +class Runtime: + name: str + backend_names: tuple[str, ...] + default_backend_name: str + description: str | None = None + + _REGISTRY: ClassVar[dict[str, Runtime]] = {} + onnx: ClassVar[Runtime] + openvino: ClassVar[Runtime] + pt: ClassVar[Runtime] + + def __post_init__(self) -> None: + normalized = self.name.strip().lower() + object.__setattr__(self, "name", normalized) + normalized_backends = tuple( + name.strip().lower() for name in self.backend_names + ) + object.__setattr__(self, "backend_names", normalized_backends) + default_backend = self.default_backend_name.strip().lower() + object.__setattr__(self, "default_backend_name", default_backend) + + if normalized in self._REGISTRY: + raise ValueError(f"Runtime '{self.name}' is already registered.") + self._REGISTRY[normalized] = self + + available = {backend.name for backend in Backend.available(normalized)} + requested = set(normalized_backends) + missing = requested - available + if missing: + raise ValueError( + f"Runtime '{self.name}' references unknown backend(s): {sorted(missing)}." + ) + if default_backend not in requested: + raise ValueError( + f"Default backend '{self.default_backend_name}' is not tracked " + f"for runtime '{self.name}'." + ) + + @classmethod + def from_any(cls, value: Any) -> Runtime: + if isinstance(value, Runtime): + return value + normalized = str(value).strip().lower() + try: + return cls._REGISTRY[normalized] + except KeyError: # pragma: no cover - defensive + options = ", ".join(cls._REGISTRY) or "" + raise ValueError( + f"Unsupported runtime '{value}'. Pick from [{options}]" + ) from None + + def available_backends(self) -> tuple[Backend, ...]: + namespace = { + backend.name: backend for backend in Backend.available(self) + } + order = [] + for name in self.backend_names: + order.append(namespace[name]) + return tuple(order) + + def normalize_backend(self, backend: Backend | str | None) -> Backend: + if backend is None: + return Backend.from_any(self.default_backend_name, runtime=self) + resolved = Backend.from_any(backend, runtime=self) + return resolved + + def auto_backend_name(self) -> str: + if self.name == "onnx": + providers = _get_available_onnx_providers() + for backend_name, provider_name in _ONNX_AUTO_PRIORITY: + if provider_name in providers: + return backend_name + return self.default_backend_name + if self.name == "pt": + has_torch, has_cuda = _get_torch_capabilities() + if has_torch and has_cuda: + return "cuda" + return self.default_backend_name + if self.name == "openvino": + devices = _get_openvino_devices() + for backend_name, device_prefix in _OPENVINO_AUTO_PRIORITY: + if any(str(dev).startswith(device_prefix) for dev in devices): + return backend_name + return self.default_backend_name + return self.default_backend_name + + def __str__(self) -> str: # pragma: no cover + return self.name + + +_ONNX_BACKENDS: tuple[Backend, ...] 
= ( + Backend( + name="cpu", + runtime_key="onnx", + providers=(ProviderSpec("CPUExecutionProvider"),), + description="Pure CPU execution provider.", + ), + Backend( + name="cuda", + runtime_key="onnx", + providers=( + ProviderSpec("CUDAExecutionProvider", include_device=True), + ProviderSpec("CPUExecutionProvider"), + ), + description="CUDA with CPU fallback.", + ), + Backend( + name="tensorrt", + runtime_key="onnx", + providers=( + ProviderSpec("TensorrtExecutionProvider", include_device=True), + ProviderSpec("CUDAExecutionProvider", include_device=True), + ProviderSpec("CPUExecutionProvider"), + ), + description="TensorRT backed by CUDA and CPU providers.", + ), + Backend( + name="tensorrt_rtx", + runtime_key="onnx", + providers=( + ProviderSpec("NvTensorRTRTXExecutionProvider", include_device=True), + ProviderSpec("CUDAExecutionProvider", include_device=True), + ProviderSpec("CPUExecutionProvider"), + ), + description="TensorRT RTX provider chain.", + ), +) + +_OPENVINO_BACKENDS: tuple[Backend, ...] = ( + Backend( + name="cpu", + runtime_key="openvino", + device="CPU", + description="Intel CPU device for OpenVINO.", + ), + Backend( + name="gpu", + runtime_key="openvino", + device="GPU", + description="Intel GPU device for OpenVINO.", + ), + Backend( + name="npu", + runtime_key="openvino", + device="NPU", + description="Intel NPU device for OpenVINO.", + ), +) + +_PT_BACKENDS: tuple[Backend, ...] = ( + Backend( + name="cpu", + runtime_key="pt", + device="cpu", + description="PyTorch CPU execution.", + ), + Backend( + name="cuda", + runtime_key="pt", + device="cuda", + description="PyTorch CUDA execution.", + ), +) + +Runtime.onnx = Runtime( + name="onnx", + backend_names=tuple(backend.name for backend in _ONNX_BACKENDS), + default_backend_name="cpu", + description="ONNXRuntime execution backend.", +) + +Runtime.openvino = Runtime( + name="openvino", + backend_names=tuple(backend.name for backend in _OPENVINO_BACKENDS), + default_backend_name="cpu", + description="Intel OpenVINO runtime.", +) + +Runtime.pt = Runtime( + name="pt", + backend_names=tuple(backend.name for backend in _PT_BACKENDS), + default_backend_name="cpu", + description="TorchScript/PT runtime.", +) + +# Convenience handles for legacy-style access. +Backend.cpu = Backend.from_any("cpu", runtime="onnx") +Backend.cuda = Backend.from_any("cuda", runtime="onnx") +Backend.tensorrt = Backend.from_any("tensorrt", runtime="onnx") +Backend.tensorrt_rtx = Backend.from_any("tensorrt_rtx", runtime="onnx") +Backend.ov_cpu = Backend.from_any("cpu", runtime="openvino") +Backend.ov_gpu = Backend.from_any("gpu", runtime="openvino") +Backend.ov_npu = Backend.from_any("npu", runtime="openvino") + + +_ONNX_AUTO_PRIORITY: tuple[tuple[str, str], ...] = ( + # Prefer pure CUDA when available to avoid TensorRT dependency issues. + ("cuda", "CUDAExecutionProvider"), + ("tensorrt_rtx", "NvTensorRTRTXExecutionProvider"), + ("tensorrt", "TensorrtExecutionProvider"), +) + +_OPENVINO_AUTO_PRIORITY: tuple[tuple[str, str], ...] = ( + # Prefer GPU when available, then NPU; fall back to CPU/default otherwise. 
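+    # For example, if Core().available_devices lists ["CPU", "GPU.0"],
+    # Runtime.openvino.auto_backend_name() returns "gpu" (prefix match on
+    # "GPU"); with only ["CPU"] it falls back to the default backend "cpu".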
+ ("gpu", "GPU"), + ("npu", "NPU"), +) + + +def _get_available_onnx_providers() -> set[str]: + try: # pragma: no cover - optional dependency + import onnxruntime as ort # type: ignore + except Exception: + return set() + try: + return set(ort.get_available_providers()) + except Exception: # pragma: no cover - runtime query failure + return set() + + +def _get_torch_capabilities() -> tuple[bool, bool]: + try: # pragma: no cover - optional dependency + import torch # type: ignore + except Exception: + return False, False + try: + return True, bool(torch.cuda.is_available()) + except Exception: # pragma: no cover - runtime query failure + return True, False + + +def _get_openvino_devices() -> set[str]: + try: # pragma: no cover - optional dependency + ov = cast(Any, import_module("openvino.runtime")) + except Exception: + return set() + try: + core = ov.Core() + return {str(dev) for dev in getattr(core, "available_devices", [])} + except Exception: # pragma: no cover - runtime query failure + return set() diff --git a/capybara/structures/__init__.py b/capybara/structures/__init__.py index 513e8a4..499894f 100644 --- a/capybara/structures/__init__.py +++ b/capybara/structures/__init__.py @@ -1,4 +1,35 @@ -from .boxes import * -from .functionals import * -from .keypoints import * -from .polygons import * +from __future__ import annotations + +from .boxes import Box, Boxes, BoxMode +from .functionals import ( + calc_angle, + is_inside_box, + jaccard_index, + pairwise_intersection, + pairwise_ioa, + pairwise_iou, + poly_angle, + polygon_iou, +) +from .keypoints import Keypoints, KeypointsList +from .polygons import JOIN_STYLE, Polygon, Polygons, order_points_clockwise + +__all__ = [ + "JOIN_STYLE", + "Box", + "BoxMode", + "Boxes", + "Keypoints", + "KeypointsList", + "Polygon", + "Polygons", + "calc_angle", + "is_inside_box", + "jaccard_index", + "order_points_clockwise", + "pairwise_intersection", + "pairwise_ioa", + "pairwise_iou", + "poly_angle", + "polygon_iou", +] diff --git a/capybara/structures/boxes.py b/capybara/structures/boxes.py index 2735ade..30bb57e 100644 --- a/capybara/structures/boxes.py +++ b/capybara/structures/boxes.py @@ -1,25 +1,17 @@ +from collections.abc import Sequence from enum import Enum, unique -from typing import Any, List, Tuple, Union +from typing import Any, Union, overload from warnings import warn import numpy as np from ..typing import _Number -__all__ = ['BoxMode', 'Box', 'Boxes'] +__all__ = ["Box", "BoxMode", "Boxes"] _BoxMode = Union["BoxMode", int, str] -_Box = Union[ - np.ndarray, - Tuple[_Number, _Number, _Number, _Number], - "Box" -] - -_Boxes = Union[ - np.ndarray, - List[_Box], - "Boxes" -] +_Box = Union[np.ndarray, Sequence[_Number], "Box"] +_Boxes = Union[np.ndarray, Sequence[_Box], "Boxes"] @unique @@ -49,7 +41,9 @@ class BoxMode(Enum): """ @staticmethod - def convert(box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode) -> np.ndarray: + def convert( + box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode + ) -> np.ndarray: """ Convert function for box format converting @@ -82,8 +76,9 @@ def convert(box: np.ndarray, from_mode: _BoxMode, to_mode: _BoxMode) -> np.ndarr elif from_mode == BoxMode.CXCYWH and to_mode == BoxMode.XYWH: arr[..., :2] -= arr[..., 2:] / 2 else: - raise NotImplementedError( - f"Conversion from BoxMode {str(from_mode)} to {str(to_mode)} is not supported yet") + raise NotImplementedError( # pragma: no cover + f"Conversion from BoxMode {from_mode!s} to {to_mode!s} is not supported yet" + ) return arr @staticmethod @@ -95,16 
+90,15 @@ def align_code(box_mode: _BoxMode): elif isinstance(box_mode, BoxMode): return box_mode else: - raise TypeError(f'Given `box_mode` is not int, str, or BoxMode.') + raise TypeError("Given `box_mode` is not int, str, or BoxMode.") class Box: - def __init__( self, array: _Box, box_mode: _BoxMode = BoxMode.XYXY, - is_normalized: bool = False + is_normalized: bool = False, ): """ Args: @@ -121,12 +115,20 @@ def __init__( self._xywh = BoxMode.convert(self._array, self.box_mode, BoxMode.XYWH) def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self._array)}), {str(BoxMode(self.box_mode))}" + return f"{self.__class__.__name__}({self._array!s}), {BoxMode(self.box_mode)!s}" def __len__(self): return self._array.shape[0] - def __getitem__(self, item) -> float: + @overload + def __getitem__(self, item: int) -> float: ... + + @overload + def __getitem__(self, item: slice) -> np.ndarray: ... + + def __getitem__(self, item: int | slice) -> float | np.ndarray: + if isinstance(item, int): + return float(self._array[item]) return self._array[item] def __eq__(self, value: object) -> bool: @@ -135,20 +137,20 @@ def __eq__(self, value: object) -> bool: return np.allclose(self._array, value._array) def _check_valid_array(self, array: Any) -> np.ndarray: - cond1 = isinstance(array, tuple) and len(array) == 4 - cond2 = isinstance(array, list) and len(array) == 4 - cond3 = isinstance( - array, np.ndarray) and array.ndim == 1 and len(array) == 4 - cond4 = isinstance(array, self.__class__) - if not (cond1 or cond2 or cond3 or cond4): - raise TypeError(f'Input array must be {_Box}, but got {type(array)}.') - if cond3: - array = array.astype('float32') - elif cond4: + if isinstance(array, Box): array = array.numpy() - else: - array = np.array(array, dtype='float32') - return array + + if isinstance(array, np.ndarray): + if array.ndim != 1 or len(array) != 4: + raise TypeError( + f"Input array must be {_Box}, but got shape {array.shape}." + ) + return array.astype("float32") + + if isinstance(array, (tuple, list)) and len(array) == 4: + return np.array(array, dtype="float32") + + raise TypeError(f"Input array must be {_Box}, but got {type(array)}.") def convert(self, to_mode: _BoxMode) -> "Box": """ @@ -161,21 +163,33 @@ def convert(self, to_mode: _BoxMode) -> "Box": Converted Box object. """ transed = BoxMode.convert(self._array, self.box_mode, to_mode) - return self.__class__(transed, to_mode) + return self.__class__( + transed, + to_mode, + is_normalized=self.is_normalized, + ) def copy(self) -> Any: - """ Create a copy of the Box object. """ - return self.__class__(self._array, self.box_mode) + """Create a copy of the Box object.""" + return self.__class__( + self._array, + self.box_mode, + is_normalized=self.is_normalized, + ) def numpy(self) -> np.ndarray: - """ Convert the Box object to a numpy array. """ + """Convert the Box object to a numpy array.""" return self._array.copy() def square(self) -> "Box": - """ Convert the box to a square box. """ - arr = self.convert('CXCYWH').numpy() + """Convert the box to a square box.""" + arr = self.convert("CXCYWH").numpy() arr[2:] = arr[2:].min() - return self.__class__(arr, 'CXCYWH').convert(self.box_mode) + return self.__class__( + arr, + "CXCYWH", + is_normalized=self.is_normalized, + ).convert(self.box_mode) def normalize(self, w: int, h: int) -> "Box": """ @@ -189,7 +203,7 @@ def normalize(self, w: int, h: int) -> "Box": Normalized Box object. 
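 
         Example (illustrative):
             Box([10, 10, 50, 50], "XYXY").normalize(100, 100)
             # -> Box with coordinates [0.1, 0.1, 0.5, 0.5]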
""" if self.is_normalized: - warn(f'Normalized box is forced to do normalization.') + warn("Normalized box is forced to do normalization.", stacklevel=2) arr = self._array.copy() arr[::2] = arr[::2] / w arr[1::2] = arr[1::2] / h @@ -207,7 +221,10 @@ def denormalize(self, w: int, h: int) -> "Box": Denormalized Box object. """ if not self.is_normalized: - warn(f'Non-normalized box is forced to do denormalization.') + warn( + "Non-normalized box is forced to do denormalization.", + stacklevel=2, + ) arr = self._array.copy() arr[::2] = arr[::2] * w arr[1::2] = arr[1::2] * h @@ -233,7 +250,12 @@ def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box": arr = BoxMode.convert(self._array, self.box_mode, BoxMode.XYXY) arr[0::2] = np.clip(arr[0::2], max(xmin, 0), xmax) arr[1::2] = np.clip(arr[1::2], max(ymin, 0), ymax) - return self.__class__(arr, self.box_mode) + clipped = self.__class__( + arr, + BoxMode.XYXY, + is_normalized=self.is_normalized, + ) + return clipped.convert(self.box_mode) def shift(self, shift_x: float, shift_y: float) -> "Box": """ @@ -248,9 +270,18 @@ def shift(self, shift_x: float, shift_y: float) -> "Box": """ arr = self._xywh.copy() arr[:2] += (shift_x, shift_y) - return self.__class__(arr, "XYWH").convert(self.box_mode) + return self.__class__( + arr, + "XYWH", + is_normalized=self.is_normalized, + ).convert(self.box_mode) - def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = None) -> "Box": + def scale( + self, + dsize: tuple[int, int] | None = None, + fx: float | None = None, + fy: float | None = None, + ) -> "Box": """ Method to scale Box with a given scale. @@ -282,13 +313,13 @@ def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = Non arr[3] += dy else: if fx is not None: - fx = arr[2] * (fx - 1) - arr[0] -= fx / 2 - arr[2] += fx + delta_x = arr[2] * (fx - 1) + arr[0] -= delta_x / 2 + arr[2] += delta_x if fy is not None: - fy = arr[3] * (fy - 1) - arr[1] -= fy / 2 - arr[3] += fy + delta_y = arr[3] * (fy - 1) + arr[1] -= delta_y / 2 + arr[3] += delta_y return self.__class__(arr, "XYWH").convert(self.box_mode) @@ -296,16 +327,17 @@ def to_list(self) -> list: return self._array.tolist() def tolist(self) -> list: - """ Alias of `to_list` (numpy style) """ + """Alias of `to_list` (numpy style)""" return self.to_list() def to_polygon(self): from .polygons import Polygon + arr = self._xywh.copy() if (arr[2:] <= 0).any(): raise ValueError( - 'Some element in Box has invaild value, which width or ' - 'height is smaller than zero or other unexpected reasons.' + "Some element in Box has invaild value, which width or " + "height is smaller than zero or other unexpected reasons." ) p1 = arr[:2] p2 = np.stack([arr[0::2].sum(), arr[1]]) @@ -313,59 +345,60 @@ def to_polygon(self): p4 = np.stack([arr[0], arr[1::2].sum()]) return Polygon(np.stack([p1, p2, p3, p4]), self.is_normalized) - @ property + @property def width(self) -> np.ndarray: - """ Get width of the box. """ + """Get width of the box.""" return self._xywh[2] - @ property + @property def height(self) -> np.ndarray: - """ Get height of the box. """ + """Get height of the box.""" return self._xywh[3] - @ property + @property def left_top(self) -> np.ndarray: - """ Get the left-top point of the box. """ + """Get the left-top point of the box.""" return self._xywh[0:2] - @ property + @property def right_bottom(self) -> np.ndarray: - """ Get the right-bottom point of the box. 
""" + """Get the right-bottom point of the box.""" return self._xywh[0:2] + self._xywh[2:4] - @ property + @property def left_bottom(self) -> np.ndarray: - """ Get the left_bottom point of the box. """ - return self._xywh[0:2] + [0, self._xywh[3]] + """Get the left_bottom point of the box.""" + xywh = np.asarray(self._xywh) + return xywh[0:2] + np.array([0, xywh[3]], dtype=xywh.dtype) - @ property + @property def right_top(self) -> np.ndarray: - """ Get the right_top point of the box. """ - return self._xywh[0:2] + [self._xywh[2], 0] + """Get the right_top point of the box.""" + xywh = np.asarray(self._xywh) + return xywh[0:2] + np.array([xywh[2], 0], dtype=xywh.dtype) - @ property + @property def area(self) -> np.ndarray: - """ Get the area of the boxes. """ + """Get the area of the boxes.""" return self._xywh[2] * self._xywh[3] - @ property + @property def aspect_ratio(self) -> np.ndarray: - """ Compute the aspect ratios (widths / heights) of the boxes. """ + """Compute the aspect ratios (widths / heights) of the boxes.""" return self._xywh[2] / self._xywh[3] - @ property + @property def center(self) -> np.ndarray: - """ Compute the center of the box. """ + """Compute the center of the box.""" return self._xywh[:2] + self._xywh[2:] / 2 class Boxes: - def __init__( self, array: _Boxes, box_mode: _BoxMode = BoxMode.XYXY, - is_normalized: bool = False + is_normalized: bool = False, ): self.box_mode = BoxMode.align_code(box_mode) self.is_normalized = is_normalized @@ -373,16 +406,33 @@ def __init__( self._xywh = BoxMode.convert(self._array, box_mode, BoxMode.XYWH) def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self._array)}), {str(BoxMode(self.box_mode))}" + return f"{self.__class__.__name__}({self._array!s}), {BoxMode(self.box_mode)!s}" def __len__(self): return self._array.shape[0] + @overload + def __getitem__(self, item: int) -> "Box": ... + + @overload + def __getitem__(self, item: list[int] | slice | np.ndarray) -> "Boxes": ... + def __getitem__(self, item) -> Union["Box", "Boxes"]: if isinstance(item, int): - return Box(self._array[item], self.box_mode, is_normalized=self.is_normalized) + return Box( + self._array[item], + self.box_mode, + is_normalized=self.is_normalized, + ) if isinstance(item, (list, slice, np.ndarray)): - return self.__class__(self._array[item], self.box_mode, is_normalized=self.is_normalized) + return self.__class__( + self._array[item], + self.box_mode, + is_normalized=self.is_normalized, + ) + raise TypeError( + "Boxes indices must be int, slice, list[int], or numpy array." 
+ ) def __iter__(self) -> Any: for i in range(len(self)): @@ -395,22 +445,34 @@ def __eq__(self, value: object) -> bool: def _check_valid_array(self, array: Any) -> np.ndarray: cond1 = isinstance(array, list) - cond2 = isinstance( - array, np.ndarray) and array.ndim == 2 and array.shape[-1] == 4 - cond3 = isinstance( - array, np.ndarray) and array.ndim == 1 and len(array) == 0 + cond2 = ( + isinstance(array, np.ndarray) + and array.ndim == 2 + and array.shape[-1] == 4 + ) + cond3 = ( + isinstance(array, np.ndarray) + and array.ndim == 1 + and len(array) == 0 + ) cond4 = isinstance(array, self.__class__) if not (cond1 or cond2 or cond3 or cond4): - raise TypeError(f'Input array must be {_Boxes}.') + raise TypeError(f"Input array must be {_Boxes}.") if cond1: for i, x in enumerate(array): try: - array[i] = Box(x, box_mode=self.box_mode, is_normalized=self.is_normalized).numpy() - except TypeError: - raise TypeError(f'Input array[{i}] must be {_Box}.') + array[i] = Box( + x, + box_mode=self.box_mode, + is_normalized=self.is_normalized, + ).numpy() + except TypeError as exc: + raise TypeError( + f"Input array[{i}] must be {_Box}." + ) from exc if cond4: array = [box.convert(self.box_mode).numpy() for box in array] - array = np.array(array, dtype='float32').copy() + array = np.array(array, dtype="float32").copy() return array def convert(self, to_mode: _BoxMode) -> "Boxes": @@ -424,20 +486,34 @@ def convert(self, to_mode: _BoxMode) -> "Boxes": Converted Box object. """ transed = BoxMode.convert(self._array, self.box_mode, to_mode) - return self.__class__(transed, to_mode) + return self.__class__( + transed, + to_mode, + is_normalized=self.is_normalized, + ) def copy(self) -> Any: - """ Create a copy of the Box object. """ - return self.__class__(self._array, self.box_mode) + """Create a copy of the Box object.""" + return self.__class__( + self._array, + self.box_mode, + is_normalized=self.is_normalized, + ) def numpy(self) -> np.ndarray: - """ Convert the Box object to a numpy array. """ + """Convert the Box object to a numpy array.""" return self._array.copy() def square(self) -> "Boxes": - arr = self.convert('CXCYWH').numpy() - arr[..., 2:] = arr[..., 2:].max(1) - return self.__class__(arr, 'CXCYWH').convert(self.box_mode) + arr = self.convert("CXCYWH").numpy() + # Use per-box maximum side length, keeping each box centered. + side = arr[..., 2:].max(axis=1, keepdims=True) + arr[..., 2:] = side + return self.__class__( + arr, + "CXCYWH", + is_normalized=self.is_normalized, + ).convert(self.box_mode) def normalize(self, w: int, h: int) -> "Boxes": """ @@ -451,7 +527,7 @@ def normalize(self, w: int, h: int) -> "Boxes": Normalized Box object. """ if self.is_normalized: - warn(f'Normalized box is forced to do normalization.') + warn("Normalized box is forced to do normalization.", stacklevel=2) arr = self._array.copy() arr[:, ::2] = arr[:, ::2] / w arr[:, 1::2] = arr[:, 1::2] / h @@ -469,13 +545,16 @@ def denormalize(self, w: int, h: int) -> "Boxes": Denormalized Boxes object. 
""" if not self.is_normalized: - warn(f'Non-normalized box is forced to do denormalization.') + warn( + "Non-normalized box is forced to do denormalization.", + stacklevel=2, + ) arr = self._array.copy() arr[:, ::2] = arr[:, ::2] * w arr[:, 1::2] = arr[:, 1::2] * h return self.__class__(arr, self.box_mode, is_normalized=False) - def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box": + def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Boxes": """ Method to clip the box by limiting x coordinates to the range [xmin, xmax] and y coordinates to the range [ymin, ymax]. @@ -495,7 +574,12 @@ def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Box": arr = BoxMode.convert(self._array, self.box_mode, BoxMode.XYXY) arr[:, 0::2] = np.clip(arr[:, 0::2], max(xmin, 0), xmax) arr[:, 1::2] = np.clip(arr[:, 1::2], max(ymin, 0), ymax) - return self.__class__(arr, self.box_mode) + clipped = self.__class__( + arr, + BoxMode.XYXY, + is_normalized=self.is_normalized, + ) + return clipped.convert(self.box_mode) def shift(self, shift_x: float, shift_y: float) -> "Boxes": """ @@ -510,9 +594,18 @@ def shift(self, shift_x: float, shift_y: float) -> "Boxes": """ arr = self._xywh.copy() arr[:, :2] += (shift_x, shift_y) - return self.__class__(arr, "XYWH").convert(self.box_mode) + return self.__class__( + arr, + "XYWH", + is_normalized=self.is_normalized, + ).convert(self.box_mode) - def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = None) -> "Boxes": + def scale( + self, + dsize: tuple[int, int] | None = None, + fx: float | None = None, + fy: float | None = None, + ) -> "Boxes": """ Method to scale Box with a given scale. @@ -544,38 +637,42 @@ def scale(self, dsize: Tuple[int, int] = None, fx: float = None, fy: float = Non arr[:, 3] += dy else: if fx is not None: - fx = arr[:, 2] * (fx - 1) - arr[:, 0] -= fx / 2 - arr[:, 2] += fx + delta_x = arr[:, 2] * (fx - 1) + arr[:, 0] -= delta_x / 2 + arr[:, 2] += delta_x if fy is not None: - fy = arr[3] * (fy - 1) - arr[:, 1] -= fy / 2 - arr[:, 3] += fy + delta_y = arr[:, 3] * (fy - 1) + arr[:, 1] -= delta_y / 2 + arr[:, 3] += delta_y return self.__class__(arr, "XYWH").convert(self.box_mode) def get_empty_index(self) -> np.ndarray: - """ Get the index of empty boxes. """ + """Get the index of empty boxes.""" return np.where((self._xywh[:, 2] <= 0) | (self._xywh[:, 3] <= 0))[0] def drop_empty(self) -> "Boxes": - """ Drop the empty boxes. """ - return self.__class__(self._array[(self._xywh[:, 2] > 0) & (self._xywh[:, 3] > 0)], self.box_mode) + """Drop the empty boxes.""" + return self.__class__( + self._array[(self._xywh[:, 2] > 0) & (self._xywh[:, 3] > 0)], + self.box_mode, + ) def to_list(self) -> list: return self._array.tolist() def tolist(self) -> list: - """ Alias of `to_list` (numpy style) """ + """Alias of `to_list` (numpy style)""" return self.to_list() def to_polygons(self): from .polygons import Polygons + arr = self._xywh.copy() if (arr[:, 2:] <= 0).any(): raise ValueError( - 'Some element in Boxes has invaild value, which width or ' - 'height is smaller than zero or other unexpected reasons.' + "Some element in Boxes has invaild value, which width or " + "height is smaller than zero or other unexpected reasons." ) p1 = arr[:, :2] @@ -584,37 +681,37 @@ def to_polygons(self): p4 = np.stack([arr[:, 0], arr[:, 1::2].sum(1)], axis=1) return Polygons(np.stack([p1, p2, p3, p4], axis=1), self.is_normalized) - @ property + @property def width(self) -> np.ndarray: - """ Get width of the box. 
""" + """Get width of the box.""" return self._xywh[:, 2] - @ property + @property def height(self) -> np.ndarray: - """ Get height of the box. """ + """Get height of the box.""" return self._xywh[:, 3] - @ property + @property def left_top(self) -> np.ndarray: - """ Get the left-top point of the box. """ + """Get the left-top point of the box.""" return self._xywh[:, :2] - @ property + @property def right_bottom(self) -> np.ndarray: - """ Get the right-bottom point of the box. """ + """Get the right-bottom point of the box.""" return self._xywh[:, :2] + self._xywh[:, 2:4] - @ property + @property def area(self) -> np.ndarray: - """ Get the area of the boxes. """ + """Get the area of the boxes.""" return self._xywh[:, 2] * self._xywh[:, 3] - @ property + @property def aspect_ratio(self) -> np.ndarray: - """ Compute the aspect ratios (widths / heights) of the boxes. """ + """Compute the aspect ratios (widths / heights) of the boxes.""" return self._xywh[:, 2] / self._xywh[:, 3] - @ property + @property def center(self) -> np.ndarray: - """ Compute the center of the box. """ + """Compute the center of the box.""" return self._xywh[:, :2] + self._xywh[:, 2:] / 2 diff --git a/capybara/structures/functionals.py b/capybara/structures/functionals.py index acc880b..091495c 100644 --- a/capybara/structures/functionals.py +++ b/capybara/structures/functionals.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple, Union - import cv2 import numpy as np from shapely.geometry import Polygon as ShapelyPolygon @@ -9,9 +7,14 @@ from .polygons import Polygon __all__ = [ - 'pairwise_intersection', 'pairwise_iou', 'pairwise_ioa', - 'jaccard_index', 'polygon_iou', 'is_inside_box', 'calc_angle', - 'poly_angle', + "calc_angle", + "is_inside_box", + "jaccard_index", + "pairwise_intersection", + "pairwise_ioa", + "pairwise_iou", + "poly_angle", + "polygon_iou", ] @@ -27,10 +30,10 @@ def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> np.ndarray: ndarray: intersection, sized [N, M]. """ if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes): - raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.') + raise TypeError("Input type of boxes1 and boxes2 must be Boxes.") - boxes1_ = boxes1.convert('XYXY').numpy() - boxes2_ = boxes2.convert('XYXY').numpy() + boxes1_ = boxes1.convert("XYXY").numpy() + boxes2_ = boxes2.convert("XYXY").numpy() lt = np.maximum(boxes1_[:, None, :2], boxes2_[:, :2]) rb = np.minimum(boxes1_[:, None, 2:], boxes2_[:, 2:]) width_height = (rb - lt).clip(min=0) @@ -51,12 +54,12 @@ def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> np.ndarray: ndarray: IoU, sized [N,M]. """ if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes): - raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.') + raise TypeError("Input type of boxes1 and boxes2 must be Boxes.") if np.any(boxes1._xywh[:, 2:] <= 0) or np.any(boxes2._xywh[:, 2:] <= 0): raise ValueError( - 'Some boxes in Boxes has invaild value, which width or ' - 'height is smaller than zero or other unexpected reasons, ' + "Some boxes in Boxes has invaild value, which width or " + "height is smaller than zero or other unexpected reasons, " 'try to run "drop_empty()" at first.' ) @@ -77,12 +80,12 @@ def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> np.ndarray: ndarray: IoA, sized [N,M]. 
""" if not isinstance(boxes1, Boxes) or not isinstance(boxes2, Boxes): - raise TypeError(f'Input type of boxes1 and boxes2 must be Boxes.') + raise TypeError("Input type of boxes1 and boxes2 must be Boxes.") if np.any(boxes1._xywh[:, 2:] <= 0) or np.any(boxes2._xywh[:, 2:] <= 0): raise ValueError( - 'Some boxes in Boxes has invaild value, which width or ' - 'height is smaller than zero or other unexpected reasons, ' + "Some boxes in Boxes has invaild value, which width or " + "height is smaller than zero or other unexpected reasons, " 'try to run "drop_empty()" at first.' ) @@ -95,7 +98,7 @@ def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> np.ndarray: def jaccard_index( pred_poly: np.ndarray, gt_poly: np.ndarray, - image_size: Tuple[int, int], + image_size: tuple[int, int], ) -> float: """ Reference : https://github.com/jchazalon/smartdoc15-ch1-eval @@ -114,29 +117,26 @@ def jaccard_index( """ if pred_poly.shape != (4, 2) or gt_poly.shape != (4, 2): - raise ValueError(f'Input polygon must be 4-point polygon.') + raise ValueError("Input polygon must be 4-point polygon.") if image_size is None: - raise ValueError(f'Input image size must be provided.') + raise ValueError("Input image size must be provided.") pred_poly = pred_poly.astype(np.float32) gt_poly = gt_poly.astype(np.float32) height, width = image_size - object_coord_target = np.array([ - [0, 0], - [width, 0], - [width, height], - [0, height]] + object_coord_target = np.array( + [[0, 0], [width, 0], [width, height], [0, height]] ).astype(np.float32) - M = cv2.getPerspectiveTransform( + matrix = cv2.getPerspectiveTransform( gt_poly.reshape(-1, 1, 2), object_coord_target[None, ...], ) transformed_pred_coords = cv2.perspectiveTransform( - pred_poly.reshape(-1, 1, 2), M + pred_poly.reshape(-1, 1, 2), matrix ) try: @@ -157,10 +157,10 @@ def jaccard_index( area_inter = area_min jaccard_index = area_inter / area_union - except: - # 通常錯誤來自於: + except Exception: + # 通常錯誤來自於: # TopologyException: Input geom 1 is invalid: Ring Self-intersection - # 表示多邊形自己交叉了,這時候就直接給 0 + # 表示多邊形自己交叉了, 這時候就直接給 0 jaccard_index = 0 return jaccard_index @@ -177,18 +177,18 @@ def polygon_iou(poly1: Polygon, poly2: Polygon): float: IoU. 
""" if not isinstance(poly1, Polygon) or not isinstance(poly2, Polygon): - raise TypeError(f'Input type of poly1 and poly2 must be Polygon.') + raise TypeError("Input type of poly1 and poly2 must be Polygon.") - poly1 = poly1.numpy().astype(np.float32) - poly2 = poly2.numpy().astype(np.float32) + poly1_arr = poly1.numpy().astype(np.float32) + poly2_arr = poly2.numpy().astype(np.float32) try: - poly1 = ShapelyPolygon(poly1) - poly2 = ShapelyPolygon(poly2) - poly_inter = poly1 & poly2 + poly1_shape = ShapelyPolygon(poly1_arr) + poly2_shape = ShapelyPolygon(poly2_arr) + poly_inter = poly1_shape.intersection(poly2_shape) - area_target = poly1.area - area_test = poly2.area + area_target = poly1_shape.area + area_test = poly2_shape.area area_inter = poly_inter.area area_union = area_test + area_target - area_inter @@ -200,16 +200,16 @@ def polygon_iou(poly1: Polygon, poly2: Polygon): area_inter = area_min iou = area_inter / area_union - except: - # 通常錯誤來自於: + except Exception: + # 通常錯誤來自於: # TopologyException: Input geom 1 is invalid: Ring Self-intersection - # 表示多邊形自己交叉了,這時候就直接給 0 + # 表示多邊形自己交叉了, 這時候就直接給 0 iou = 0 return iou -def is_inside_box(x: Union[Box, Keypoints, Polygon], box: Box) -> np.bool_: +def is_inside_box(x: Box | Keypoints | Polygon, box: Box) -> np.bool_: cond1 = x._array >= box.left_top cond2 = x._array <= box.right_bottom return np.concatenate((cond1, cond2), axis=-1).all() @@ -238,8 +238,8 @@ def calc_angle(v1, v2): def poly_angle( poly1: Polygon, - poly2: Optional[Polygon] = None, - base_vector: Tuple[int, int] = (0, 1) + poly2: Polygon | None = None, + base_vector: tuple[int, int] = (0, 1), ) -> float: """ Calculate the angle between two polygons or a polygon and a base vector. @@ -252,7 +252,10 @@ def _get_angle(poly): return vector1 + vector2 v1 = _get_angle(poly1) - v2 = _get_angle(poly2) if poly2 is not None else np.array( - base_vector, dtype='float32') + v2 = ( + _get_angle(poly2) + if poly2 is not None + else np.array(base_vector, dtype="float32") + ) return calc_angle(v1, v2) diff --git a/capybara/structures/keypoints.py b/capybara/structures/keypoints.py index ae8e5d9..f621e45 100644 --- a/capybara/structures/keypoints.py +++ b/capybara/structures/keypoints.py @@ -1,7 +1,8 @@ -from typing import Any, List, Tuple, Union +import colorsys +from collections.abc import Sequence +from typing import Any, TypeAlias, Union from warnings import warn -import matplotlib import numpy as np from ..typing import _Number @@ -9,17 +10,40 @@ __all__ = ["Keypoints", "KeypointsList"] -_Keypoints = Union[ +def _colormap_bytes(cmap: str, steps: np.ndarray) -> np.ndarray: + try: # pragma: no cover - optional dependency + import matplotlib # type: ignore + + try: + color_map = matplotlib.colormaps[cmap] + except Exception: + color_map = matplotlib.colormaps["rainbow"] + return np.asarray(color_map(steps, bytes=True), dtype=np.uint8) + except Exception: + out = np.empty((len(steps), 4), dtype=np.uint8) + for i, t in enumerate(steps): + hue = float(t) + if len(steps) > 1 and hue >= 1.0: + hue = 1.0 - (1.0 / len(steps)) + r, g, b = colorsys.hsv_to_rgb(hue, 1.0, 1.0) + out[i, 0] = int(r * 255) + out[i, 1] = int(g * 255) + out[i, 2] = int(b * 255) + out[i, 3] = 255 + return out + + +_Keypoints: TypeAlias = Union[ np.ndarray, - List[np.ndarray], - List[Tuple[_Number, _Number]], - List[Tuple[_Number, _Number, _Number]], + Sequence[np.ndarray], + Sequence[tuple[_Number, _Number]], + Sequence[tuple[_Number, _Number, _Number]], "Keypoints", ] - -_KeypointsList = Union[ +_KeypointsList: TypeAlias = 
Union[ np.ndarray, - List[_Keypoints], + Sequence[_Keypoints], + "KeypointsList", ] @@ -32,18 +56,19 @@ class Keypoints: * v=2: labeled and visible """ - def __init__(self, array: _Keypoints, cmap="rainbow", is_normalized: bool = False): + def __init__( + self, array: _Keypoints, cmap="rainbow", is_normalized: bool = False + ): self._array = self._check_valid_array(array) steps = np.linspace(0.0, 1.0, self._array.shape[-2]) - color_map = matplotlib.colormaps[cmap] - self._point_colors = np.array(color_map(steps, bytes=True))[..., :3].tolist() + self._point_colors = _colormap_bytes(str(cmap), steps)[..., :3].tolist() self._is_normalized = is_normalized def __len__(self) -> int: return self._array.shape[0] def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self._array)})" + return f"{self.__class__.__name__}({self._array!s})" def __eq__(self, value: object) -> bool: if not isinstance(value, self.__class__): @@ -52,25 +77,34 @@ def __eq__(self, value: object) -> bool: def _check_valid_array(self, array: Any) -> np.ndarray: cond1 = isinstance(array, np.ndarray) - cond2 = isinstance(array, list) and all(isinstance(x, (tuple, np.ndarray)) for x in array) + cond2 = isinstance(array, list) and all( + isinstance(x, (tuple, np.ndarray)) for x in array + ) cond3 = isinstance(array, self.__class__) if not (cond1 or cond2 or cond3): - raise TypeError(f"Input array is not {_Keypoints}, but got {type(array)}.") + raise TypeError( + f"Input array is not {_Keypoints}, but got {type(array)}." + ) - if cond3: - array = array.numpy() - else: - array = np.array(array, dtype="float32") + array = array.numpy() if cond3 else np.array(array, dtype="float32") if not array.ndim == 2: - raise ValueError(f"Input array ndim = {array.ndim} is not 2, which is invalid.") + raise ValueError( + f"Input array ndim = {array.ndim} is not 2, which is invalid." + ) if array.shape[-1] not in [2, 3]: - raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not in [2, 3], which is invalid.") + raise ValueError( + f"Input array's shape[-1] = {array.shape[-1]} is not in [2, 3], which is invalid." + ) - if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()): - raise ValueError("Given array is invalid because of its labels. (array[..., 2])") + if array.shape[-1] == 3 and not ( + (array[..., 2] <= 2).all() and (array[..., 2] >= 0).all() + ): + raise ValueError( + "Given array is invalid because of its labels. 
(array[..., 2])"
+            )

         return array.copy()

     def numpy(self) -> np.ndarray:
@@ -91,7 +125,10 @@ def scale(self, fx: float, fy: float) -> "Keypoints":

     def normalize(self, w: float, h: float) -> "Keypoints":
         if self.is_normalized:
-            warn("Normalized keypoints are forced to do normalization.")
+            warn(
+                "Normalized keypoints are forced to do normalization.",
+                stacklevel=2,
+            )
         arr = self._array.copy()
         arr[..., :2] = arr[..., :2] / (w, h)
         kpts = self.__class__(arr)
@@ -100,7 +137,10 @@ def denormalize(self, w: float, h: float) -> "Keypoints":
         if not self.is_normalized:
-            warn("Non-normalized keypoints is forced to do denormalization.")
+            warn(
+                "Non-normalized keypoints are forced to do denormalization.",
+                stacklevel=2,
+            )
         arr = self._array.copy()
         arr[..., :2] = arr[..., :2] * (w, h)
         kpts = self.__class__(arr)
@@ -112,22 +152,29 @@ def is_normalized(self) -> bool:
         return self._is_normalized

     @property
-    def point_colors(self) -> List[Tuple[int, int, int]]:
-        return [tuple([int(x) for x in cs]) for cs in self._point_colors]
+    def point_colors(self) -> list[tuple[int, int, int]]:
+        return [
+            (int(cs[0]), int(cs[1]), int(cs[2])) for cs in self._point_colors
+        ]

-    @point_colors.setter
-    def set_point_colors(self, cmap: str):
+    def set_point_colors(self, cmap: str) -> None:
         steps = np.linspace(0.0, 1.0, self._array.shape[-2])
-        self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+        self._point_colors = _colormap_bytes(str(cmap), steps)[..., :3].tolist()
+
+    @point_colors.setter
+    def point_colors(self, cmap: str) -> None:
+        self.set_point_colors(cmap)


 class KeypointsList:
-    def __init__(self, array: _KeypointsList, cmap="rainbow", is_normalized: bool = False) -> None:
+    def __init__(
+        self, array: _KeypointsList, cmap="rainbow", is_normalized: bool = False
+    ) -> None:
         self._array = self._check_valid_array(array).copy()
         self._is_normalized = is_normalized
         if len(self._array):
             steps = np.linspace(0.0, 1.0, self._array.shape[-2])
-            self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+            self._point_colors = _colormap_bytes(str(cmap), steps)
         else:
             self._point_colors = None

@@ -136,8 +183,12 @@ def __len__(self) -> int:

     def __getitem__(self, item) -> Any:
         if isinstance(item, int):
-            return Keypoints(self._array[item], is_normalized=self.is_normalized)
-        return KeypointsList(self._array[item], is_normalized=self.is_normalized)
+            return Keypoints(
+                self._array[item], is_normalized=self.is_normalized
+            )
+        return KeypointsList(
+            self._array[item], is_normalized=self.is_normalized
+        )

     def __setitem__(self, item, value):
         if not isinstance(value, (Keypoints, KeypointsList)):
@@ -151,7 +202,7 @@ def __iter__(self) -> Any:
             yield self[i]

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({str(self._array)})"
+        return f"{self.__class__.__name__}({self._array!s})"

     def __eq__(self, value: object) -> bool:
         if not isinstance(value, self.__class__):
@@ -161,17 +212,16 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
         cond1 = isinstance(array, np.ndarray)
         cond2 = isinstance(array, list) and len(array) == 0
-        cond3 = (
-            isinstance(array, list)
-            and (
-                all(isinstance(x, (np.ndarray, Keypoints)) for x in array)
-                or all(isinstance(y, tuple) for x in array for y in x)
-            )
+        cond3 = isinstance(array, list) and (
+            all(isinstance(x, (np.ndarray, Keypoints)) for x in array)
+            or all(isinstance(y, tuple) for x in array for y in x)
         )
         cond4 = isinstance(array, self.__class__)
         if not (cond1 or cond2 or cond3 or cond4):
-            raise TypeError(f"Input array is not {_KeypointsList}, but got {type(array)}.")
+            raise TypeError(
+                f"Input array is not {_KeypointsList}, but got {type(array)}."
+            )

         if cond4:
             array = array.numpy()
@@ -184,13 +234,21 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
             return array

         if array.ndim != 3:
-            raise ValueError(f"Input array's ndim = {array.ndim} is not 3, which is invalid.")
+            raise ValueError(
+                f"Input array's ndim = {array.ndim} is not 3, which is invalid."
+            )

         if array.shape[-1] not in [2, 3]:
-            raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not 2 or 3, which is invalid.")
+            raise ValueError(
+                f"Input array's shape[-1] = {array.shape[-1]} is not 2 or 3, which is invalid."
+            )

-        if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()):
-            raise ValueError("Given array is invalid because of its labels. (array[..., 2])")
+        if array.shape[-1] == 3 and not (
+            (array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()
+        ):
+            raise ValueError(
+                "Given array is invalid because of its labels. (array[..., 2])"
+            )

         return array

@@ -212,7 +270,10 @@ def scale(self, fx: float, fy: float) -> Any:

     def normalize(self, w: float, h: float) -> "KeypointsList":
         if self.is_normalized:
-            warn("Normalized keypoints_list is forced to do normalization.")
+            warn(
+                "Normalized keypoints_list is forced to do normalization.",
+                stacklevel=2,
+            )
         arr = self._array.copy()
         arr[..., :2] = arr[..., :2] / (w, h)
         kpts_list = self.__class__(arr)
@@ -221,7 +282,10 @@ def denormalize(self, w: float, h: float) -> "KeypointsList":
         if not self.is_normalized:
-            warn("Non-normalized box is forced to do denormalization.")
+            warn(
+                "Non-normalized keypoints_list is forced to do denormalization.",
+                stacklevel=2,
+            )
         arr = self._array.copy()
         arr[..., :2] = arr[..., :2] * (w, h)
         kpts_list = self.__class__(arr)
@@ -233,16 +297,24 @@ def is_normalized(self) -> bool:
         return self._is_normalized

     @property
-    def point_colors(self):
-        return [tuple(c) for c in self._point_colors[..., :3].tolist()]
+    def point_colors(self) -> list[tuple[int, int, int]]:
+        if self._point_colors is None:
+            return []
+        colors = self._point_colors[..., :3].tolist()
+        return [(int(c[0]), int(c[1]), int(c[2])) for c in colors]
+
+    def set_point_colors(self, cmap: str) -> None:
+        if self._point_colors is None:
+            return
+        steps = np.linspace(0.0, 1.0, self._array.shape[-2])
+        self._point_colors = _colormap_bytes(str(cmap), steps)

     @point_colors.setter
-    def set_point_colors(self, cmap: str):
-        steps = np.linspace(0.0, 1.0, self._array.shape[-2])
-        self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
+    def point_colors(self, cmap: str) -> None:
+        self.set_point_colors(cmap)

     @classmethod
-    def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":
+    def cat(cls, keypoints_lists: list["KeypointsList"]) -> "KeypointsList":
         """
         Concatenates a list of KeypointsList into a single KeypointsList
@@ -260,7 +332,17 @@ def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":
         if len(keypoints_lists) == 0:
             raise ValueError("Given keypoints_list is empty.")

-        if not all(isinstance(keypoints_list, KeypointsList) for keypoints_list in keypoints_lists):
-            raise TypeError("All type of elements in keypoints_lists must be KeypointsList.")
+        if not all(
+            isinstance(keypoints_list, KeypointsList)
+            for keypoints_list in keypoints_lists
+        ):
+            raise TypeError(
+                "All type of
elements in keypoints_lists must be KeypointsList." + ) - return cls(np.concatenate([keypoints_list.numpy() for keypoints_list in keypoints_lists], axis=0)) + return cls( + np.concatenate( + [keypoints_list.numpy() for keypoints_list in keypoints_lists], + axis=0, + ) + ) diff --git a/capybara/structures/polygons.py b/capybara/structures/polygons.py index 5105912..3365e2c 100644 --- a/capybara/structures/polygons.py +++ b/capybara/structures/polygons.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from copy import deepcopy -from typing import Any, List, Tuple, Union +from typing import Any, Union, cast, overload from warnings import warn import cv2 @@ -10,26 +11,20 @@ from ..typing import _Number __all__ = [ - 'Polygon', 'Polygons', 'order_points_clockwise', 'JOIN_STYLE', + "JOIN_STYLE", + "Polygon", + "Polygons", + "order_points_clockwise", ] -_Polygon = Union[ - np.ndarray, - List[Tuple[_Number, _Number]], - "Polygon" -] - -_Polygons = Union[ - np.ndarray, - List["Polygon"], - List[np.ndarray], - List[List[Tuple[_Number, _Number]]], - "Polygons" -] +_Polygon = Union[np.ndarray, Sequence[Sequence[_Number]], "Polygon"] +_Polygons = Union[np.ndarray, Sequence[_Polygon], "Polygons"] -def order_points_clockwise(pts: np.ndarray, inverse: bool = False) -> np.ndarray: - """ Order the 4 points clockwise. +def order_points_clockwise( + pts: np.ndarray, inverse: bool = False +) -> np.ndarray: + """Order the 4 points clockwise. Args: pts (np.ndarray): @@ -45,7 +40,7 @@ def order_points_clockwise(pts: np.ndarray, inverse: bool = False) -> np.ndarray """ if pts.shape != (4, 2): - raise ValueError('Input array `pts` must be of shape (4, 2).') + raise ValueError("Input array `pts` must be of shape (4, 2).") x_sorted = pts[np.argsort(pts[:, 0]), :] left_most = x_sorted[:2, :] @@ -69,11 +64,7 @@ class Polygon: has shape (K, 2) where K is the number of points per instance. """ - def __init__( - self, - array: _Polygon, - is_normalized: bool = False - ): + def __init__(self, array: _Polygon, is_normalized: bool = False): """ Args: array (Union[np.ndarray, list]): A Nx2 or Nx1x2 matrix. @@ -82,12 +73,18 @@ def __init__( self.is_normalized = is_normalized def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self._array)})" + return f"{self.__class__.__name__}({self._array!s})" def __len__(self) -> int: return self._array.shape[0] - def __getitem__(self, item) -> float: + @overload + def __getitem__(self, item: int) -> np.ndarray: ... + + @overload + def __getitem__(self, item: slice) -> np.ndarray: ... 
+ + def __getitem__(self, item): return self._array[item] def __eq__(self, value: object) -> bool: @@ -97,32 +94,27 @@ def __eq__(self, value: object) -> bool: def _check_valid_array(self, array: _Polygon) -> np.ndarray: if isinstance(array, (list, tuple)): - array = np.array(array, dtype='float32') + array = np.array(array, dtype="float32") if isinstance(array, Polygon): array = array.numpy() - cond1 = isinstance( - array, np.ndarray) and array.ndim == 3 and array.shape[1] == 1 - cond2 = isinstance( - array, np.ndarray) and array.ndim == 2 and array.shape[1] == 2 - cond3 = isinstance( - array, np.ndarray) and array.ndim == 1 and len(array) == 0 - cond4 = isinstance( - array, np.ndarray) and array.ndim == 1 and len(array) == 2 - cond5 = isinstance(array, self.__class__) - if not (cond1 or cond2 or cond3 or cond4 or cond5): - raise TypeError(f'Input array must be {_Polygon}.') - if cond3 or cond4: + if not isinstance(array, np.ndarray): + raise TypeError(f"Input array must be {_Polygon}.") + if array.ndim == 1 and len(array) in {0, 2}: array = array[None] - if cond1: + if array.ndim == 3 and array.shape[1] == 1: array = np.squeeze(array, axis=1) - return array.astype('float32') + if array.ndim == 2 and array.shape[1] == 2: + return array.astype("float32") + if array.ndim == 2 and array.shape == (1, 0): + return array.astype("float32") + raise TypeError(f"Input array must be {_Polygon}.") def copy(self) -> "Polygon": - """ Create a copy of the Polygon object. """ + """Create a copy of the Polygon object.""" return self.__class__(self._array) def numpy(self) -> np.ndarray: - """ Convert the Polygon object to a numpy array. """ + """Convert the Polygon object to a numpy array.""" return self._array.copy() def normalize(self, w: float, h: float) -> "Polygon": @@ -137,7 +129,10 @@ def normalize(self, w: float, h: float) -> "Polygon": Normalized Polygon object. """ if self.is_normalized: - warn(f'Normalized polygon is forced to do normalization.') + warn( + "Normalized polygon is forced to do normalization.", + stacklevel=2, + ) arr = self._array.copy() arr = arr / (w, h) return self.__class__(arr, is_normalized=True) @@ -154,7 +149,10 @@ def denormalize(self, w: float, h: float) -> "Polygon": Denormalized Polygon object. 
""" if not self.is_normalized: - warn(f'Non-normalized polygon is forced to do denormalization.') + warn( + "Non-normalized polygon is forced to do denormalization.", + stacklevel=2, + ) arr = self._array.copy() arr = arr * (w, h) return self.__class__(arr, is_normalized=False) @@ -198,7 +196,7 @@ def shift(self, shift_x: float, shift_y: float) -> "Polygon": def scale( self, distance: int, - join_style: JOIN_STYLE = JOIN_STYLE.mitre + join_style: int = JOIN_STYLE.mitre, ) -> "Polygon": """ Returns an approximate representation of all points within a given distance @@ -211,14 +209,15 @@ def scale( These values are also enumerated by the object shapely.geometry.JOIN_STYLE """ poly = _Polygon_shapely(self._array).buffer( - distance, join_style=join_style) + distance, join_style=cast(Any, join_style) + ) if isinstance(poly, MultiPolygon): poly = max(poly.geoms, key=lambda p: p.area) if isinstance(poly, _Polygon_shapely) and not poly.exterior.is_empty: pts = np.zeros_like(self._array) - for x, y in zip(*poly.exterior.xy): + for x, y in zip(*poly.exterior.xy, strict=True): pt = np.array([x, y]) dist = np.linalg.norm(pt - self._array, axis=1) pts[dist.argmin()] = pt @@ -237,15 +236,18 @@ def to_convexhull(self) -> "Polygon": return self.__class__(hull) def to_min_boxpoints(self) -> "Polygon": - """ Converts polygon to the min area bounding box. """ + """Converts polygon to the min area bounding box.""" min_box = cv2.boxPoints(self.min_box).round(4) min_box = order_points_clockwise(np.array(min_box)) return self.__class__(min_box) - def to_box(self, box_mode: str = 'xyxy'): - """ Converts polygon to the bounding box. """ + def to_box(self, box_mode: str = "xyxy"): + """Converts polygon to the bounding box.""" from .boxes import Box - return Box(self.boundingbox, "xywh", self.is_normalized).convert(box_mode) + + return Box(self.boundingbox, "xywh", self.is_normalized).convert( + box_mode + ) def to_list(self, flatten: bool = False) -> list: if flatten: @@ -254,7 +256,7 @@ def to_list(self, flatten: bool = False) -> list: return self._array.tolist() def tolist(self, flatten: bool = False) -> list: - """ Alias of `to_list` (numpy style) """ + """Alias of `to_list` (numpy style)""" return self.to_list(flatten=flatten) def is_empty(self, threshold: int = 3) -> bool: @@ -264,86 +266,95 @@ def is_empty(self, threshold: int = 3) -> bool: """ if not isinstance(threshold, int): raise TypeError( - f'Input threshold type error, expected "int", got "{type(threshold)}".') + f'Input threshold type error, expected "int", got "{type(threshold)}".' + ) return len(self) < threshold @property def moments(self) -> dict: - """ Get the moment of area. """ + """Get the moment of area.""" return cv2.moments(self._array) @property def area(self) -> float: - """ Get the region area. """ - return self.moments['m00'] + """Get the region area.""" + return self.moments["m00"] @property def arclength(self) -> float: - """ Get the region arc length. """ + """Get the region arc length.""" return cv2.arcLength(self._array, closed=True) @property def centroid(self) -> np.ndarray: - """ Get the mass centers. """ - return np.array([ - self.moments['m10'] / (self.moments['m00'] + 1e-5), - self.moments['m01'] / (self.moments['m00'] + 1e-5) - ]) + """Get the mass centers.""" + return np.array( + [ + self.moments["m10"] / (self.moments["m00"] + 1e-5), + self.moments["m01"] / (self.moments["m00"] + 1e-5), + ] + ) @property - def boundingbox(self) -> np.ndarray: - """ Get the bounding box. 
""" + def boundingbox(self): + """Get the bounding box.""" from .boxes import Box - bbox = cv2.boundingRect(self._array) + + bbox = np.array(cv2.boundingRect(self._array), dtype="float32") if not self.is_normalized: - bbox = bbox - np.array([0, 0, 1, 1]) - return Box(bbox, 'xywh') + bbox = bbox - np.array([0, 0, 1, 1], dtype="float32") + return Box(bbox, "xywh", is_normalized=self.is_normalized) @property - def min_circle(self) -> Tuple[Tuple[int, int], int]: - """ Get the min closed circle. """ + def min_circle(self) -> tuple[Sequence[float], float]: + """Get the min closed circle.""" return cv2.minEnclosingCircle(self._array) @property - def min_box(self) -> Tuple[Tuple[int, int], Tuple[int, int], int]: - """ Get the min area rectangle. """ + def min_box( + self, + ) -> tuple[Sequence[float], Sequence[float], float]: + """Get the min area rectangle.""" return cv2.minAreaRect(self._array) @property def orientation(self) -> float: - """ Get the min area rectangle. """ + """Get the min area rectangle.""" _, _, angle = self.min_box return angle @property - def min_box_wh(self) -> Tuple[float, float]: - """ Get the min area rectangle. """ + def min_box_wh(self) -> tuple[float, float]: + """Get the min area rectangle.""" _, (w, h), _ = self.min_box return w, h @property def extent(self) -> float: - """ Ratio of pixels in the region to pixels in the total bounding box. """ + """Ratio of pixels in the region to pixels in the total bounding box.""" _, _, w, h = self.boundingbox return self.area / (w * h) @property def solidity(self) -> float: - """ Ratio of pixels in the region to pixels of the convex hull image. """ + """Ratio of pixels in the region to pixels of the convex hull image.""" return self.area / (self.to_convexhull().area + 1e-5) class Polygons: - def __init__(self, polygons: _Polygons, is_normalized: bool = False): - if not isinstance(polygons, (list, np.ndarray)): + if isinstance(polygons, Polygons): + polygons = polygons._polygons + elif not isinstance(polygons, (list, tuple, np.ndarray)): raise TypeError( - f'Input type error: "{polygons}", must be list or np.ndarray type.') + f'Input type error: "{polygons}", must be list or np.ndarray type.' + ) self.is_normalized = is_normalized + self._polygons: list[Polygon] self._polygons = [Polygon(p, is_normalized) for p in polygons] def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self._polygons)})" + return f"{self.__class__.__name__}({self._polygons!s})" def __len__(self) -> int: return len(self._polygons) @@ -355,8 +366,17 @@ def __iter__(self) -> Any: def __eq__(self, value: object) -> bool: if not isinstance(value, self.__class__): return False - is_same = sum([t == x for t, x in zip(value, self)]) == len(self) - return is_same + if len(value) != len(self): + return False + return all(t == x for t, x in zip(value, self, strict=True)) + + @overload + def __getitem__(self, item: int) -> Polygon: ... + + @overload + def __getitem__( + self, item: list[int] | slice | np.ndarray + ) -> "Polygons": ... 
def __getitem__(self, item) -> Union["Polygons", "Polygon"]: """ @@ -383,12 +403,13 @@ def __getitem__(self, item) -> Union["Polygons", "Polygon"]: elif isinstance(item, slice): output = Polygons(self._polygons[item]) elif isinstance(item, np.ndarray): - if item.dtype == 'bool': + if item.dtype == "bool": item = np.argwhere(item).flatten() output = Polygons([self._polygons[i] for i in item]) else: raise TypeError( - 'Input item type error, expected to be int, list, ndarray or slice.') + "Input item type error, expected to be int, list, ndarray or slice." + ) return output def is_empty(self, threshold: int = 3) -> np.ndarray: @@ -400,102 +421,117 @@ def to_min_boxpoints(self) -> "Polygons": def to_convexhull(self) -> "Polygons": return Polygons([poly.to_convexhull() for poly in self._polygons]) - def to_boxes(self, box_mode: str = 'xyxy'): + def to_boxes(self, box_mode: str = "xyxy"): from .boxes import Boxes - return Boxes(self.boundingbox, 'xywh', self.is_normalized).convert(box_mode) + + return Boxes(self.boundingbox, "xywh", self.is_normalized).convert( + box_mode + ) def drop_empty(self, threshold: int = 3) -> "Polygons": - return Polygons([p for p in self._polygons if not p.is_empty(threshold)]) + return Polygons( + [p for p in self._polygons if not p.is_empty(threshold)] + ) def copy(self): return self.__class__(deepcopy(self._polygons)) def normalize(self, w: float, h: float) -> "Polygons": if self.is_normalized: - warn(f'Normalized polygons are forced to do normalization.') + warn( + "Normalized polygons are forced to do normalization.", + stacklevel=2, + ) _polygons = [x.normalize(w, h) for x in self._polygons] polygons = self.__class__(_polygons, is_normalized=True) return polygons def denormalize(self, w: float, h: float) -> "Polygons": if not self.is_normalized: - warn(f'Non-normalized polygons are forced to do denormalization.') + warn( + "Non-normalized polygons are forced to do denormalization.", + stacklevel=2, + ) _polygons = [x.denormalize(w, h) for x in self._polygons] polygons = self.__class__(_polygons, is_normalized=False) return polygons def clip(self, xmin: int, ymin: int, xmax: int, ymax: int) -> "Polygons": - return Polygons([p.clip(xmin, ymin, xmax, ymax) for p in self._polygons]) + return Polygons( + [p.clip(xmin, ymin, xmax, ymax) for p in self._polygons] + ) def shift(self, shift_x: float, shift_y: float) -> "Polygons": return Polygons([p.shift(shift_x, shift_y) for p in self._polygons]) def scale(self, distance: int) -> "Polygons": - return Polygons([p.scale(distance) for p in self._polygons]).drop_empty() + return Polygons( + [p.scale(distance) for p in self._polygons] + ).drop_empty() def numpy(self, flatten: bool = False): len_polys = np.array([len(p) for p in self._polygons], dtype=np.int32) if (len_polys == len_polys.mean()).all(): - return np.array(self.to_list(flatten=flatten)).astype('float32') + return np.array(self.to_list(flatten=flatten)).astype("float32") else: return np.array(self._polygons, dtype=object) def to_list(self, flatten: bool = False) -> list: - """ Convert boxes to list. + """Convert boxes to list. - Args: + Args: - is_flatten (bool): - True -> Output format (Nx(Mx2)): + is_flatten (bool): + True -> Output format (Nx(Mx2)): + [ + [p11, p12, p13, p14], + [p21, p22, p23, p24], + ..., + ]. + False -> Output format (NxMx2): + [ [ - [p11, p12, p13, p14], - [p21, p22, p23, p24], - ..., - ]. 
- False -> Output format (NxMx2): + [p11], + [p12], + [p13], + [p14] + ], [ - [ - [p11], - [p12], - [p13], - [p14] - ], - [ - [p21], - [p22], - [p23], - [p24] - ], - ..., - ]. + [p21], + [p22], + [p23], + [p24] + ], + ..., + ]. """ return [p.to_list(flatten) for p in self._polygons] def tolist(self, flatten: bool = False) -> list: - """ Alias of `to_list` (numpy style) """ + """Alias of `to_list` (numpy style)""" return self.to_list(flatten=flatten) - @ property + @property def moments(self) -> list: return [poly.moments for poly in self._polygons] - @ property + @property def min_circle(self) -> list: return [poly.min_circle for poly in self._polygons] - @ property + @property def min_box(self) -> list: return [poly.min_box for poly in self._polygons] - @ property + @property def area(self) -> np.ndarray: return np.array([poly.area for poly in self._polygons]) - @ property + @property def arclength(self) -> np.ndarray: return np.array([poly.arclength for poly in self._polygons]) - @ property + @property def centroid(self) -> np.ndarray: return np.array([poly.centroid for poly in self._polygons]) @@ -503,15 +539,15 @@ def centroid(self) -> np.ndarray: def boundingbox(self) -> np.ndarray: return np.array([poly.boundingbox for poly in self._polygons]) - @ property + @property def extent(self) -> np.ndarray: return np.array([poly.extent for poly in self._polygons]) - @ property + @property def solidity(self) -> np.ndarray: return np.array([poly.solidity for poly in self._polygons]) - @ property + @property def orientation(self) -> np.ndarray: return np.array([poly.orientation for poly in self._polygons]) @@ -524,31 +560,34 @@ def from_image( cls, image: np.ndarray, mode: int = cv2.RETR_EXTERNAL, - method: int = cv2.CHAIN_APPROX_SIMPLE + method: int = cv2.CHAIN_APPROX_SIMPLE, ) -> "Polygons": if not isinstance(image, np.ndarray): - raise TypeError('Input image must be a np.ndarray.') + raise TypeError("Input image must be a np.ndarray.") contours, _ = cv2.findContours(image, mode=mode, method=method) if len(contours) > 0: contours = [c for c in contours if c.shape[0] > 1] return cls(list(contours)) @classmethod - def cat(cls, polygons_list: List["Polygons"]) -> "Polygons": + def cat(cls, polygons_list: list["Polygons"]) -> "Polygons": """ Concatenates a list of Polygon into a single Polygons. Returns: Polygon: the concatenated Polygon """ if not isinstance(polygons_list, list): - raise TypeError('Given polygon_list should be a list.') + raise TypeError("Given polygon_list should be a list.") if len(polygons_list) == 0: - raise ValueError('Given polygon_list is empty.') + raise ValueError("Given polygon_list is empty.") - if not all(isinstance(polygons, Polygons) for polygons in polygons_list): + if not all( + isinstance(polygons, Polygons) for polygons in polygons_list + ): raise TypeError( - 'All type of elements in polygon_list must be Polygon.') + "All type of elements in polygon_list must be Polygon." 
+ ) _polygons = [] for polys in polygons_list: diff --git a/capybara/torchengine/__init__.py b/capybara/torchengine/__init__.py new file mode 100644 index 0000000..49b4b95 --- /dev/null +++ b/capybara/torchengine/__init__.py @@ -0,0 +1,3 @@ +from .engine import TorchEngine, TorchEngineConfig + +__all__ = ["TorchEngine", "TorchEngineConfig"] diff --git a/capybara/torchengine/engine.py b/capybara/torchengine/engine.py new file mode 100644 index 0000000..44e65eb --- /dev/null +++ b/capybara/torchengine/engine.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import time +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +__all__ = ["TorchEngine", "TorchEngineConfig"] + + +def _lazy_import_torch(): + try: # pragma: no cover - optional dependency + import torch # type: ignore + except Exception as exc: # pragma: no cover - surfaced in init + raise ImportError( + "PyTorch is required to use Runtime.pt. Please install torch>=2.0." + ) from exc + return torch + + +def _validate_benchmark_params(*, repeat: Any, warmup: Any) -> tuple[int, int]: + repeat_i = int(repeat) + warmup_i = int(warmup) + if repeat_i < 1: + raise ValueError("repeat must be >= 1.") + if warmup_i < 0: + raise ValueError("warmup must be >= 0.") + return repeat_i, warmup_i + + +@dataclass(slots=True) +class TorchEngineConfig: + dtype: str | Any | None = None + cuda_sync: bool = True + + +class TorchEngine: + """Thin wrapper around torch.jit.ScriptModule for AngiDetection.""" + + def __init__( + self, + model_path: str | Path, + *, + device: str | Any = "cuda", + output_names: Sequence[str] | None = None, + config: TorchEngineConfig | None = None, + ) -> None: + self.model_path = str(model_path) + self._torch = _lazy_import_torch() + self._cfg = config or TorchEngineConfig() + self.device = self._normalize_device(device) + self.output_names = tuple(output_names or ()) + self.dtype = self._normalize_dtype(self._cfg.dtype) + self._model = self._load_model() + + def __call__(self, **inputs: Any) -> dict[str, np.ndarray]: + if len(inputs) == 1 and isinstance( + next(iter(inputs.values())), Mapping + ): + feed_dict = dict(next(iter(inputs.values()))) + else: + feed_dict = inputs + return self.run(feed_dict) + + def run(self, feed: Mapping[str, Any]) -> dict[str, np.ndarray]: + prepared = self._prepare_feed(feed) + outputs = self._forward(prepared) + return self._format_outputs(outputs) + + def benchmark( + self, + inputs: Mapping[str, Any], + *, + repeat: int = 100, + warmup: int = 10, + cuda_sync: bool | None = None, + ) -> dict[str, Any]: + repeat, warmup = _validate_benchmark_params( + repeat=repeat, warmup=warmup + ) + prepared = self._prepare_feed(inputs) + sync = self._should_sync(cuda_sync) + with self._torch.inference_mode(): + for _ in range(warmup): + self._forward(prepared) + if sync: + self._sync() + + latencies: list[float] = [] + t0 = time.perf_counter() + with self._torch.inference_mode(): + for _ in range(repeat): + if sync: + self._sync() + start = time.perf_counter() + self._forward(prepared) + if sync: + self._sync() + latencies.append((time.perf_counter() - start) * 1e3) + total = time.perf_counter() - t0 + arr = np.asarray(latencies, dtype=np.float64) + return { + "repeat": repeat, + "warmup": warmup, + "throughput_fps": repeat / total if total else None, + "latency_ms": { + "mean": float(arr.mean()) if arr.size else None, + "median": float(np.median(arr)) if arr.size else None, + "p90": 
float(np.percentile(arr, 90)) if arr.size else None, + "p95": float(np.percentile(arr, 95)) if arr.size else None, + "min": float(arr.min()) if arr.size else None, + "max": float(arr.max()) if arr.size else None, + }, + } + + def summary(self) -> dict[str, Any]: + return { + "model": self.model_path, + "device": str(self.device), + "dtype": str(self.dtype), + "outputs": list(self.output_names), + } + + # Internal helpers ----------------------------------------------------- + def _load_model(self): + torch = self._torch + model = torch.jit.load(self.model_path, map_location=self.device) + model.eval() + with torch.no_grad(): + model.to(self.device) + if self.dtype == torch.float16: + model.half() + elif self.dtype == torch.float32: + model.float() + else: + model.to(dtype=self.dtype) + return model + + def _prepare_feed(self, feed: Mapping[str, Any]) -> list[Any]: + if not isinstance(feed, Mapping): + raise TypeError("TorchEngine feed must be a mapping.") + prepared: list[Any] = [] + for value in feed.values(): + tensor = self._as_tensor(value) + tensor = tensor.to(self.device) + tensor = tensor.to(self.dtype) + prepared.append(tensor) + return prepared + + def _forward(self, inputs: Sequence[Any]) -> Any: + if len(inputs) == 1: + return self._model(inputs[0]) + return self._model(*inputs) + + def _format_outputs(self, outputs: Any) -> dict[str, np.ndarray]: + torch = self._torch + if isinstance(outputs, Mapping): + return { + str(key): self._tensor_to_numpy(value) + for key, value in outputs.items() + } + if isinstance(outputs, (list, tuple)): + names = self._normalize_output_names(len(outputs)) + return { + names[idx]: self._tensor_to_numpy(value) + for idx, value in enumerate(outputs) + } + if torch.is_tensor(outputs): + name = self.output_names[0] if self.output_names else "output" + return {name: self._tensor_to_numpy(outputs)} + raise TypeError( + "Unsupported TorchScript output. Expected tensor/dict/sequence." + ) + + def _normalize_output_names(self, count: int) -> tuple[str, ...]: + if self.output_names: + if len(self.output_names) != count: + raise ValueError( + f"output_names has {len(self.output_names)} entries but " + f"model produced {count} outputs." 
+ ) + return self.output_names + return tuple(f"output_{idx}" for idx in range(count)) + + def _tensor_to_numpy(self, tensor: Any) -> np.ndarray: + torch = self._torch + if not torch.is_tensor(tensor): + raise TypeError("Model outputs must be torch.Tensor instances.") + array = tensor.detach().to("cpu") + if array.dtype != torch.float32: + array = array.to(torch.float32) + return array.contiguous().numpy() + + def _as_tensor(self, value: Any): + torch = self._torch + if torch.is_tensor(value): + return value + arr = np.asarray(value, dtype=np.float32) + return torch.from_numpy(arr) + + def _normalize_device(self, device: Any): + torch = self._torch + torch_device_type = getattr(torch, "device", None) + if isinstance(torch_device_type, type) and isinstance( + device, torch_device_type + ): + return device + return torch.device(device) + + def _normalize_dtype(self, dtype: Any | None): + torch = self._torch + if dtype is None or ( + isinstance(dtype, str) and dtype.strip().lower() == "auto" + ): + token = Path(self.model_path).name.lower() + if "fp16" in token and self._device_is_cuda(): + return torch.float16 + return torch.float32 + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + normalized = dtype.strip().lower() + if normalized in {"fp16", "float16", "half"}: + return torch.float16 + if normalized in {"fp32", "float32"}: + return torch.float32 + raise ValueError(f"Unsupported dtype specification '{dtype}'.") + + def _device_is_cuda(self) -> bool: + return getattr(self.device, "type", "") == "cuda" + + def _should_sync(self, override: bool | None) -> bool: + if override is not None: + return bool(override and self._device_is_cuda()) + return bool(self._cfg.cuda_sync and self._device_is_cuda()) + + def _sync(self) -> None: + if self._device_is_cuda(): + self._torch.cuda.synchronize(self.device) diff --git a/capybara/typing.py b/capybara/typing.py index 542d665..3bfdaaa 100644 --- a/capybara/typing.py +++ b/capybara/typing.py @@ -1,5 +1,3 @@ -from typing import Union - import numpy as np -_Number = Union[np.number, int, float] +_Number = np.number | int | float diff --git a/capybara/utils/__init__.py b/capybara/utils/__init__.py index 1a31110..1f9cdc0 100644 --- a/capybara/utils/__init__.py +++ b/capybara/utils/__init__.py @@ -1,7 +1,21 @@ -from .custom_path import * -from .custom_tqdm import * -from .files_utils import * -from .powerdict import * -from .system_info import * -from .time import * -from .utils import * +from __future__ import annotations + +from pathlib import Path + +from .custom_path import copy_path, get_curdir, rm_path +from .powerdict import PowerDict +from .time import Timer, now +from .utils import colorstr, download_from_google, make_batch + +__all__ = [ + "Path", + "PowerDict", + "Timer", + "colorstr", + "copy_path", + "download_from_google", + "get_curdir", + "make_batch", + "now", + "rm_path", +] diff --git a/capybara/utils/custom_path.py b/capybara/utils/custom_path.py index 0a32200..b4a2857 100644 --- a/capybara/utils/custom_path.py +++ b/capybara/utils/custom_path.py @@ -1,14 +1,10 @@ import shutil from pathlib import Path -from typing import Union -__all__ = ['Path', 'get_curdir', 'rm_path'] +__all__ = ["Path", "get_curdir", "rm_path"] -def get_curdir( - path: Union[str, Path], - absolute: bool = True -) -> Path: +def get_curdir(path: str | Path, absolute: bool = True) -> Path: """ Function to get the path of current workspace. 
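A minimal usage sketch of the revised path helpers (not part of the diff; assumes `capybara` is importable, and the scratch directory below is hypothetical):

import tempfile
from pathlib import Path

from capybara.utils.custom_path import get_curdir, rm_path

# rm_path now removes non-empty directories via shutil.rmtree (the old
# Path.rmdir failed on them); files and symlinks still go through unlink.
workdir = Path(tempfile.mkdtemp())  # throwaway scratch directory
(workdir / "sub").mkdir()
(workdir / "sub" / "note.txt").write_text("x")
rm_path(workdir / "sub")
assert not (workdir / "sub").exists()

# get_curdir resolves the parent directory of the given file path.
print(get_curdir(__file__))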
@@ -23,15 +19,15 @@ def get_curdir( return path.parent.resolve() if absolute else path.parent -def rm_path(path: Union[str, Path]): +def rm_path(path: str | Path): pth = Path(path) - if pth.is_dir(): - pth.rmdir() - else: - pth.unlink() + if pth.is_dir() and not pth.is_symlink(): + shutil.rmtree(pth) + return + pth.unlink() -def copy_path(path_src: Union[str, Path], path_dst: Union[str, Path]): +def copy_path(path_src: str | Path, path_dst: str | Path): if not Path(path_src).is_file(): raise ValueError(f'Input path: "{path_src}" is invalid.') shutil.copy(path_src, path_dst) diff --git a/capybara/utils/custom_tqdm.py b/capybara/utils/custom_tqdm.py index 747bedd..67e0e56 100644 --- a/capybara/utils/custom_tqdm.py +++ b/capybara/utils/custom_tqdm.py @@ -1,24 +1,23 @@ from collections.abc import Sized +from typing import Any, cast from tqdm import tqdm as _tqdm -__all__ = ['Tqdm'] +__all__ = ["Tqdm"] class Tqdm(_tqdm): def __init__(self, iterable=None, desc=None, smoothing=0, **kwargs): - - if 'total' in kwargs: - total = kwargs.pop('total', None) + if "total" in kwargs: + total = kwargs.pop("total", None) else: total = len(iterable) if isinstance(iterable, Sized) else None super().__init__( - iterable=iterable, + iterable=cast(Any, iterable), desc=desc, total=total, smoothing=smoothing, dynamic_ncols=True, - **kwargs + **kwargs, ) diff --git a/capybara/utils/files_utils.py b/capybara/utils/files_utils.py index 47821ea..31a7e35 100644 --- a/capybara/utils/files_utils.py +++ b/capybara/utils/files_utils.py @@ -1,7 +1,7 @@ import errno import hashlib import os -from typing import Any, List, Tuple, Union +from typing import Any import dill import numpy as np @@ -13,12 +13,19 @@ from .custom_tqdm import Tqdm __all__ = [ - 'gen_md5', 'get_files', 'load_json', 'dump_json', 'load_pickle', - 'dump_pickle', 'load_yaml', 'dump_yaml', 'img_to_md5', + "dump_json", + "dump_pickle", + "dump_yaml", + "gen_md5", + "get_files", + "img_to_md5", + "load_json", + "load_pickle", + "load_yaml", ] -def gen_md5(file: Union[str, Path], block_size: int = 256 * 128) -> str: +def gen_md5(file: str | Path, block_size: int = 256 * 128) -> str: """ This function generates an MD5 hash for the given file. @@ -32,9 +39,9 @@ def gen_md5(file: Union[str, Path], block_size: int = 256 * 128) -> str: Returns: md5 (str) """ - with open(str(file), 'rb') as f: + with open(str(file), "rb") as f: md5 = hashlib.md5() - for chunk in iter(lambda: f.read(block_size), b''): + for chunk in iter(lambda: f.read(block_size), b""): md5.update(chunk) return str(md5.hexdigest()) @@ -47,7 +54,7 @@ def img_to_md5(img: np.ndarray) -> str: return str(md5_hash.hexdigest()) -def load_json(path: Union[Path, str], **kwargs) -> dict: +def load_json(path: Path | str, **kwargs) -> dict: """ Function to read ujson.
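A minimal behavioral sketch of the rm_path change above: real directories are now removed recursively via shutil.rmtree, while a symlink to a directory is merely unlinked. Paths are temporary and illustrative, and the snippet assumes a POSIX-style filesystem.

import tempfile
from pathlib import Path

from capybara.utils import rm_path

base = Path(tempfile.mkdtemp())
(base / "real_dir" / "sub").mkdir(parents=True)
(base / "link_dir").symlink_to(base / "real_dir")

rm_path(base / "link_dir")               # symlink: unlink only
assert (base / "real_dir").exists()      # the target directory is untouched
rm_path(base / "real_dir")               # real directory: recursive delete
assert not (base / "real_dir").exists()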
@@ -57,12 +64,12 @@ def load_json(path: Union[Path, str], **kwargs) -> dict: Returns: dict: ujson load to dictionary """ - with open(str(path), 'r') as f: + with open(str(path)) as f: data = ujson.load(f, **kwargs) return data -def dump_json(obj: Any, path: Union[str, Path] = None, **kwargs) -> None: +def dump_json(obj: Any, path: str | Path | None = None, **kwargs) -> None: """ Function to write obj to ujson @@ -71,23 +78,23 @@ def dump_json(obj: Any, path: Union[str, Path] = None, **kwargs) -> None: path (Union[str, Path]): ujson file's path """ dump_options = { - 'sort_keys': False, - 'indent': 2, - 'ensure_ascii': False, - 'escape_forward_slashes': False, + "sort_keys": False, + "indent": 2, + "ensure_ascii": False, + "escape_forward_slashes": False, } dump_options.update(kwargs) if path is None: - path = Path.cwd() / 'tmp.json' + path = Path.cwd() / "tmp.json" - with open(str(path), 'w') as f: + with open(str(path), "w") as f: ujson.dump(obj, f, **dump_options) def get_files( - folder: Union[str, Path], - suffix: Union[str, List[str], Tuple[str]] = None, + folder: str | Path, + suffix: str | list[str] | tuple[str, ...] | None = None, recursive: bool = True, return_pathlib: bool = True, sort_path: bool = True, @@ -126,25 +133,28 @@ def get_files( folder = Path(folder) if not folder.is_dir(): raise FileNotFoundError( - errno.ENOENT, os.strerror(errno.ENOENT), str(folder)) + errno.ENOENT, os.strerror(errno.ENOENT), str(folder) + ) if not isinstance(suffix, (str, list, tuple)) and suffix is not None: - raise TypeError('suffix must be a string, list or tuple.') + raise TypeError("suffix must be a string, list or tuple.") # checking suffix suffix = [suffix] if isinstance(suffix, str) else suffix if suffix is not None and ignore_letter_case: suffix = [s.lower() for s in suffix] - if recursive: - files_gen = folder.rglob('*') - else: - files_gen = folder.glob('*') + files_gen = folder.rglob("*") if recursive else folder.glob("*") files = [] for f in Tqdm(files_gen, leave=False): - if suffix is None or (ignore_letter_case and f.suffix.lower() in suffix) \ - or (not ignore_letter_case and f.suffix in suffix): + if not f.is_file(): + continue + if ( + suffix is None + or (ignore_letter_case and f.suffix.lower() in suffix) + or (not ignore_letter_case and f.suffix in suffix) + ): files.append(f.absolute()) if not return_pathlib: @@ -156,7 +166,7 @@ def get_files( return files -def load_pickle(path: Union[str, Path]): +def load_pickle(path: str | Path): """ Function to load a pickle. @@ -166,11 +176,11 @@ def load_pickle(path: Union[str, Path]): Returns: loaded_pickle (dict): loaded pickle. """ - with open(str(path), 'rb') as f: + with open(str(path), "rb") as f: return dill.load(f) -def dump_pickle(obj, path: Union[str, Path]): +def dump_pickle(obj, path: str | Path): """ Function to dump an obj to a pickle file. @@ -178,11 +188,11 @@ def dump_pickle(obj, path: Union[str, Path]): obj: object to be dump. path (Union[str, Path]): file path. """ - with open(str(path), 'wb') as f: + with open(str(path), "wb") as f: dill.dump(obj, f) -def load_yaml(path: Union[Path, str]) -> dict: +def load_yaml(path: Path | str) -> dict: """ Function to read yaml. 
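An illustrative call against the revised get_files above, which now filters out directories before suffix matching; the folder and suffixes are placeholders.

from capybara.utils.files_utils import get_files

# Match .jpg/.png regardless of letter case and get back pathlib.Path objects.
images = get_files(
    "./data", suffix=[".jpg", ".png"], recursive=True, ignore_letter_case=True
)
print(f"{len(images)} image files found")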
@@ -192,12 +202,12 @@ def load_yaml(path: Union[Path, str]) -> dict: Returns: dict: yaml load to dictionary """ - with open(str(path), 'r') as f: + with open(str(path)) as f: data = yaml.load(f, Loader=yaml.FullLoader) return data -def dump_yaml(obj, path: Union[str, Path] = None, **kwargs): +def dump_yaml(obj, path: str | Path | None = None, **kwargs) -> None: """ Function to dump an obj to a yaml file. @@ -205,14 +215,11 @@ def dump_yaml(obj, path: Union[str, Path] = None, **kwargs): obj: object to be dump. path (Union[str, Path]): file path. """ - dump_options = { - 'indent': 2, - 'sort_keys': True - } + dump_options = {"indent": 2, "sort_keys": True} dump_options.update(kwargs) if path is None: - path = Path.cwd() / 'tmp.yaml' + path = Path.cwd() / "tmp.yaml" - with open(str(path), 'w') as f: + with open(str(path), "w") as f: yaml.dump(obj, f, **dump_options) diff --git a/capybara/utils/powerdict.py b/capybara/utils/powerdict.py index c85a08c..a71b15f 100644 --- a/capybara/utils/powerdict.py +++ b/capybara/utils/powerdict.py @@ -1,14 +1,22 @@ from collections.abc import Mapping from pprint import pprint +from typing import Any -from .files_utils import (dump_json, dump_pickle, dump_yaml, load_json, - load_pickle, load_yaml) +from .files_utils import ( + dump_json, + dump_pickle, + dump_yaml, + load_json, + load_pickle, + load_yaml, +) -__all__ = ['PowerDict'] +__all__ = ["PowerDict"] +_MISSING = object() -class PowerDict(dict): +class PowerDict(dict): def __init__(self, d=None, **kwargs): """ This class is used to create a namespace dictionary with freeze and melt functions. @@ -25,12 +33,23 @@ def __init__(self, d=None, **kwargs): self._frozen = False + def __getattr__(self, key: str) -> Any: + try: + return self[key] + except KeyError as exc: + raise AttributeError(key) from exc + def __set(self, key, value): if not self.is_frozen: - if isinstance(value, Mapping) and not isinstance(value, self.__class__): + if isinstance(value, Mapping) and not isinstance( + value, self.__class__ + ): value = self.__class__(value) if isinstance(value, (list, tuple)): - value = [self.__class__(v) if isinstance(v, dict) else v for v in value] + value = [ + self.__class__(v) if isinstance(v, dict) else v + for v in value + ] super().__setattr__(key, value) super().__setitem__(key, value) else: @@ -44,49 +63,57 @@ def __del(self, key): raise ValueError(f"PowerDict is frozen. '{key}' cannot be del.") def __setattr__(self, key, value): - if key == '_frozen': + if key == "_frozen": super().__setattr__(key, value) else: self.__set(key, value) def __delattr__(self, key): - if key == '_frozen': + if key == "_frozen": raise KeyError("Can not del '_frozen'.") else: self.__del(key) def __setitem__(self, key, value): - if key == '_frozen': + if key == "_frozen": raise KeyError("Can not set '_frozen' as an item.") else: self.__set(key, value) def __delitem__(self, key): - if key == '_frozen': + if key == "_frozen": raise KeyError("There is not _frozen in items.") else: self.__del(key) def update(self, e=None, **f): if self.is_frozen: - raise Warning(f'PowerDict is frozen and cannot be update.') + raise ValueError("PowerDict is frozen. 
Update is not allowed.") + + if e is None: + d: dict = {} else: - d = e or dict() - d.update(f) - for k in d: - setattr(self, k, d[k]) + d = dict(e) + d.update(f) + for key, value in d.items(): + setattr(self, key, value) - def pop(self, key, d=None): + def pop(self, key, default=_MISSING): if self.is_frozen: - raise Warning(f"PowerDict is frozen and cannot be pop.") - else: - d = getattr(self, key, d) - delattr(self, key) - return d + raise ValueError("PowerDict is frozen. Pop is not allowed.") + + if key in self: + value = self[key] + del self[key] + return value + + if default is _MISSING: + raise KeyError(key) + return default def freeze(self): self._frozen = True - for v in self.values(): + for v in dict.values(self): if isinstance(v, PowerDict): v.freeze() if isinstance(v, (list, tuple)): @@ -96,7 +123,7 @@ def freeze(self): def melt(self): self._frozen = False - for v in self.values(): + for v in dict.values(self): if isinstance(v, PowerDict): v.melt() if isinstance(v, (list, tuple)): @@ -106,20 +133,22 @@ def melt(self): @property def is_frozen(self): - return getattr(self, '_frozen', False) + return getattr(self, "_frozen", False) def __deepcopy__(self, memo): if self._frozen: - raise Warning('PowerDict is frozen and cannot be copy.') + raise ValueError("PowerDict is frozen and cannot be copied.") return self.__class__(self) def to_dict(self): out = {} - for k, v in self.items(): + for k, v in dict.items(self): if isinstance(v, PowerDict): out[k] = v.to_dict() elif isinstance(v, list): - out[k] = [x.to_dict() if isinstance(x, PowerDict) else x for x in v] + out[k] = [ + x.to_dict() if isinstance(x, PowerDict) else x for x in v + ] else: out[k] = v @@ -135,7 +164,7 @@ def to_dict(self): def to_txt(self, path): d = self.to_dict() - with open(path, 'w') as f: + with open(path, "w") as f: pprint(d, f) def to_pickle(self, path): diff --git a/capybara/utils/system_info.py b/capybara/utils/system_info.py index 3840f8e..9e88002 100644 --- a/capybara/utils/system_info.py +++ b/capybara/utils/system_info.py @@ -1,13 +1,18 @@ import platform import socket import subprocess +from importlib import import_module +from typing import Any, cast -import psutil +import psutil # type: ignore import requests __all__ = [ - "get_package_versions", "get_gpu_cuda_versions", "get_system_info", - "get_cpu_info", "get_external_ip" + "get_cpu_info", + "get_external_ip", + "get_gpu_cuda_versions", + "get_package_versions", + "get_system_info", ] @@ -22,28 +27,32 @@ def get_package_versions(): # PyTorch try: - import torch + import torch # type: ignore + versions_info["PyTorch Version"] = torch.__version__ except Exception as e: versions_info["PyTorch Error"] = str(e) # PyTorch Lightning try: - import pytorch_lightning as pl - versions_info["PyTorch Lightning Version"] = pl.__version__ + import pytorch_lightning as pl # type: ignore + + versions_info["PyTorch Lightning Version"] = str( + getattr(pl, "__version__", "unknown") + ) except Exception as e: versions_info["PyTorch Lightning Error"] = str(e) # TensorFlow try: - import tensorflow as tf + tf = cast(Any, import_module("tensorflow")) versions_info["TensorFlow Version"] = tf.__version__ except Exception as e: versions_info["TensorFlow Error"] = str(e) # Keras try: - import keras + keras = cast(Any, import_module("keras")) versions_info["Keras Version"] = keras.__version__ except Exception as e: versions_info["Keras Error"] = str(e) @@ -51,27 +60,31 @@ def get_package_versions(): # NumPy try: import numpy as np + versions_info["NumPy Version"] =
np.__version__ except Exception as e: versions_info["NumPy Error"] = str(e) # Pandas try: - import pandas as pd + import pandas as pd # type: ignore + versions_info["Pandas Version"] = pd.__version__ except Exception as e: versions_info["Pandas Error"] = str(e) # Scikit-learn try: - import sklearn + import sklearn # type: ignore + versions_info["Scikit-learn Version"] = sklearn.__version__ except Exception as e: versions_info["Scikit-learn Error"] = str(e) # OpenCV try: - import cv2 + import cv2 # type: ignore + versions_info["OpenCV Version"] = cv2.__version__ except Exception as e: versions_info["OpenCV Error"] = str(e) @@ -93,15 +106,16 @@ def get_gpu_cuda_versions(): # Attempt to retrieve CUDA version using PyTorch try: - import torch - cuda_version = torch.version.cuda + import torch # type: ignore + + cuda_version = getattr(getattr(torch, "version", None), "cuda", None) except ImportError: pass # If not retrieved via PyTorch, try using TensorFlow if not cuda_version: try: - import tensorflow as tf + tf = cast(Any, import_module("tensorflow")) cuda_version = tf.version.COMPILER_VERSION except ImportError: pass @@ -109,25 +123,33 @@ def get_gpu_cuda_versions(): # If still not retrieved, try using CuPy if not cuda_version: try: - import cupy + cupy = cast(Any, import_module("cupy")) cuda_version = cupy.cuda.runtime.runtimeGetVersion() except ImportError: - cuda_version = "Error: None of PyTorch, TensorFlow, or CuPy are installed." + cuda_version = ( + "Error: None of PyTorch, TensorFlow, or CuPy are installed." + ) # Try to get Nvidia driver version using nvidia-smi command try: - smi_output = subprocess.check_output([ - 'nvidia-smi', - '--query-gpu=driver_version', - '--format=csv,noheader,nounits' - ]).decode('utf-8').strip() - nvidia_driver_version = smi_output.split('\n')[0] + smi_output = ( + subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=driver_version", + "--format=csv,noheader,nounits", + ] + ) + .decode("utf-8") + .strip() + ) + nvidia_driver_version = smi_output.split("\n")[0] except Exception as e: nvidia_driver_version = f"Error getting NVIDIA driver version: {e}" return { "CUDA Version": cuda_version, - "NVIDIA Driver Version": nvidia_driver_version + "NVIDIA Driver Version": nvidia_driver_version, } @@ -147,15 +169,21 @@ def get_cpu_info(): elif platform.system() == "Linux": # For Linux command = "cat /proc/cpuinfo | grep 'model name' | uniq" - return subprocess.check_output(command, shell=True).strip().decode().split(":")[1].strip() + return ( + subprocess.check_output(command, shell=True) + .strip() + .decode() + .split(":")[1] + .strip() + ) else: return "N/A" def get_external_ip(): try: - response = requests.get('https://httpbin.org/ip') - return response.json()['origin'] + response = requests.get("https://httpbin.org/ip") + return response.json()["origin"] except Exception as e: return f"Error obtaining IP: {e}" @@ -171,19 +199,31 @@ def get_system_info(): "OS Version": platform.platform(), "CPU Model": get_cpu_info(), "Physical CPU Cores": psutil.cpu_count(logical=False), - "Logical CPU Cores (incl. hyper-threading)": psutil.cpu_count(logical=True), - "Total RAM (GB)": round(psutil.virtual_memory().total / (1024 ** 3), 2), - "Available RAM (GB)": round(psutil.virtual_memory().available / (1024 ** 3), 2), - "Disk Total (GB)": round(psutil.disk_usage('/').total / (1024 ** 3), 2), - "Disk Used (GB)": round(psutil.disk_usage('/').used / (1024 ** 3), 2), - "Disk Free (GB)": round(psutil.disk_usage('/').free / (1024 ** 3), 2) + "Logical CPU Cores (incl. 
hyper-threading)": psutil.cpu_count( + logical=True + ), + "Total RAM (GB)": round(psutil.virtual_memory().total / (1024**3), 2), + "Available RAM (GB)": round( + psutil.virtual_memory().available / (1024**3), 2 + ), + "Disk Total (GB)": round(psutil.disk_usage("/").total / (1024**3), 2), + "Disk Used (GB)": round(psutil.disk_usage("/").used / (1024**3), 2), + "Disk Free (GB)": round(psutil.disk_usage("/").free / (1024**3), 2), } # Try to fetch GPU information using nvidia-smi command try: - gpu_info = subprocess.check_output( - ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader,nounits'] - ).decode('utf-8').strip() + gpu_info = ( + subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=name", + "--format=csv,noheader,nounits", + ] + ) + .decode("utf-8") + .strip() + ) info["GPU Info"] = gpu_info except Exception: info["GPU Info"] = "N/A or Error" @@ -191,22 +231,26 @@ def get_system_info(): # Get network information addrs = psutil.net_if_addrs() info["IPV4 Address"] = [ - addr.address for addr in addrs.get('enp5s0', []) if addr.family == socket.AF_INET + addr.address + for addr in addrs.get("enp5s0", []) + if addr.family == socket.AF_INET ] info["IPV4 Address (External)"] = get_external_ip() # Determine platform and choose correct address family for MAC - if hasattr(socket, 'AF_LINK'): - AF_LINK = socket.AF_LINK - elif hasattr(psutil, 'AF_LINK'): - AF_LINK = psutil.AF_LINK - else: - raise Exception( - "Cannot determine the correct AF_LINK value for this platform.") + af_link = getattr(socket, "AF_LINK", None) + if af_link is None: + af_link = getattr(psutil, "AF_LINK", None) + if af_link is None: + raise RuntimeError( + "Cannot determine the correct AF_LINK value for this platform." + ) info["MAC Address"] = [ - addr.address for addr in addrs.get('enp5s0', []) if addr.family == AF_LINK + addr.address + for addr in addrs.get("enp5s0", []) + if addr.family == af_link ] return info diff --git a/capybara/utils/time.py b/capybara/utils/time.py index 31209cf..2937264 100644 --- a/capybara/utils/time.py +++ b/capybara/utils/time.py @@ -1,27 +1,26 @@ import time from datetime import datetime from time import struct_time -from typing import Union import numpy as np from .utils import colorstr __all__ = [ - 'Timer', - 'now', - 'timestamp2datetime', - 'timestamp2time', - 'timestamp2str', - 'time2datetime', - 'time2timestamp', - 'time2str', - 'datetime2time', - 'datetime2timestamp', - 'datetime2str', - 'str2time', - 'str2datetime', - 'str2timestamp', + "Timer", + "datetime2str", + "datetime2time", + "datetime2timestamp", + "now", + "str2datetime", + "str2time", + "str2timestamp", + "time2datetime", + "time2str", + "time2timestamp", + "timestamp2datetime", + "timestamp2str", + "timestamp2time", ] __doc__ = """ @@ -32,18 +31,18 @@ ==========|========================================================== Directive | Meaning ==========|========================================================== - %a | Locale’s abbreviated weekday name. - %A | Locale’s full weekday name. - %b | Locale’s abbreviated month name. - %B | Locale’s full month name. - %c | Locale’s appropriate date and time representation. + %a | Locale's abbreviated weekday name. + %A | Locale's full weekday name. + %b | Locale's abbreviated month name. + %B | Locale's full month name. + %c | Locale's appropriate date and time representation. %d | Day of the month as a decimal number [01,31]. %H | Hour (24-hour clock) as a decimal number [00,23]. %I | Hour (12-hour clock) as a decimal number [01,12]. 
%j | Day of the year as a decimal number [001,366]. %m | Month as a decimal number [01,12]. %M | Minute as a decimal number [00,59]. - %p | Locale’s equivalent of either AM or PM. + %p | Locale's equivalent of either AM or PM. %S | Second as a decimal number [00,61]. %U | Week number of the year (Sunday as the first day of the week) | as a decimal number [00,53]. All days in a new year preceding @@ -52,8 +51,8 @@ %W | Week number of the year (Monday as the first day of the week) | as a decimal number [00,53]. All days in a new year preceding | the first Monday are considered to be in week 0. - %x | Locale’s appropriate date representation. - %X | Locale’s appropriate time representation. + %x | Locale's appropriate date representation. + %X | Locale's appropriate time representation. %y | Year without century as a decimal number [00,99]. %Y | Year with century as a decimal number. %z | Time zone offset indicating a positive or negative time difference @@ -95,27 +94,33 @@ def testing_function(*args, **kwargs): do something... """ - def __init__(self, precision: int = 5, desc: str = None, verbose: bool = False): + def __init__( + self, + precision: int = 5, + desc: str | None = None, + verbose: bool = False, + ): self.precision = precision self.desc = desc self.verbose = verbose self.__record = [] def tic(self): - """ start timer """ + """start timer""" if self.desc is not None and self.verbose: - print(colorstr(self.desc, 'yellow')) + print(colorstr(self.desc, "yellow")) self.time = time.perf_counter() def toc(self, verbose=False): - """ get time lag from start """ - if getattr(self, 'time', None) is None: + """get time lag from start""" + if getattr(self, "time", None) is None: raise ValueError( - f'The timer has not been started. Tic the timer first.') + "The timer has not been started. Tic the timer first." + ) total = round(time.perf_counter() - self.time, self.precision) if verbose or self.verbose: - print(colorstr(f'Cost: {total} sec', 'white')) + print(colorstr(f"Cost: {total} sec", "white")) self.__record.append(total) return total @@ -126,10 +131,12 @@ def warp(*args, **kwargs): result = fcn(*args, **kwargs) self.toc() return result + return warp def __enter__(self): self.tic() + return self def __exit__(self, type, value, traceback): self.dt = self.toc(True) @@ -137,28 +144,28 @@ def __exit__(self, type, value, traceback): def clear_record(self): self.__record = [] - @ property + @property def mean(self): if len(self.__record): return np.array(self.__record).mean().round(self.precision) - @ property + @property def max(self): if len(self.__record): return np.array(self.__record).max().round(self.precision) - @ property + @property def min(self): if len(self.__record): return np.array(self.__record).min().round(self.precision) - @ property + @property def std(self): if len(self.__record): return np.array(self.__record).std().round(self.precision) -def now(typ: str = 'timestamp', fmt: str = None): +def now(typ: str = "timestamp", fmt: str | None = None): """ Get now time. Specify the output type of time, or give the formatted rule to get the time string, eg: now(fmt='%Y-%m-%d'). @@ -171,82 +178,82 @@ def now(typ: str = 'timestamp', fmt: str = None): Raises: ValueError: Unsupported type error. 
""" - if typ == 'timestamp': + if typ == "timestamp": t = time.time() - elif typ == 'datetime': + elif typ == "datetime": t = datetime.now() - elif typ == 'time': + elif typ == "time": t = time.gmtime(time.time()) else: - raise ValueError(f'Unsupported input {typ} type of time.') + raise ValueError(f"Unsupported input {typ} type of time.") - if fmt != None: + if fmt is not None: t = timestamp2str(time.time(), fmt=fmt) return t -def timestamp2datetime(ts: Union[int, float]): +def timestamp2datetime(ts: int | float): return datetime.fromtimestamp(ts) -def timestamp2time(ts: Union[int, float]): +def timestamp2time(ts: int | float): return time.localtime(ts) -def timestamp2str(ts: Union[int, float], fmt: str): +def timestamp2str(ts: int | float, fmt: str): return time2str(timestamp2time(ts), fmt) def time2datetime(t: struct_time): if not isinstance(t, struct_time): - raise TypeError(f'Input type: {type(t)} error.') + raise TypeError(f"Input type: {type(t)} error.") return datetime(*t[0:6]) def time2timestamp(t: struct_time): if not isinstance(t, struct_time): - raise TypeError(f'Input type: {type(t)} error.') + raise TypeError(f"Input type: {type(t)} error.") return time.mktime(t) def time2str(t: struct_time, fmt: str): if not isinstance(t, struct_time): - raise TypeError(f'Input type: {type(t)} error.') + raise TypeError(f"Input type: {type(t)} error.") return time.strftime(fmt, t) def datetime2time(dt: datetime): if not isinstance(dt, datetime): - raise TypeError(f'Input type: {type(dt)} error.') + raise TypeError(f"Input type: {type(dt)} error.") return dt.timetuple() def datetime2timestamp(dt: datetime): if not isinstance(dt, datetime): - raise TypeError(f'Input type: {type(dt)} error.') + raise TypeError(f"Input type: {type(dt)} error.") return dt.timestamp() def datetime2str(dt: datetime, fmt: str): if not isinstance(dt, datetime): - raise TypeError(f'Input type: {type(dt)} error.') + raise TypeError(f"Input type: {type(dt)} error.") return dt.strftime(fmt) def str2time(s: str, fmt: str): if not isinstance(s, str): - raise TypeError(f'Input type: {type(s)} error.') + raise TypeError(f"Input type: {type(s)} error.") return time.strptime(s, fmt) def str2datetime(s: str, fmt: str): if not isinstance(s, str): - raise TypeError(f'Input type: {type(s)} error.') + raise TypeError(f"Input type: {type(s)} error.") return datetime.strptime(s, fmt) def str2timestamp(s: str, fmt: str): if not isinstance(s, str): - raise TypeError(f'Input type: {type(s)} error.') + raise TypeError(f"Input type: {type(s)} error.") return time2timestamp(str2time(s, fmt)) diff --git a/capybara/utils/utils.py b/capybara/utils/utils.py index e2b3984..b2782a0 100644 --- a/capybara/utils/utils.py +++ b/capybara/utils/utils.py @@ -1,7 +1,8 @@ -import os import re +from collections.abc import Generator, Iterable +from pathlib import Path from pprint import pprint -from typing import Any, Generator, Iterable, List, Union +from typing import Any, cast import requests from bs4 import BeautifulSoup @@ -10,15 +11,16 @@ from ..enums import COLORSTR, FORMATSTR __all__ = [ - 'make_batch', 'colorstr', 'pprint', - 'download_from_google', + "colorstr", + "download_from_google", + "make_batch", + "pprint", ] def make_batch( - data: Union[Iterable, Generator], - batch_size: int -) -> Generator[List, None, None]: + data: Iterable | Generator, batch_size: int +) -> Generator[list, None, None]: """ This function is used to make data to batched data. 
@@ -41,8 +43,8 @@ def make_batch( def colorstr( obj: Any, - color: Union[COLORSTR, int, str] = COLORSTR.BLUE, - fmt: Union[FORMATSTR, int, str] = FORMATSTR.BOLD + color: COLORSTR | int | str = COLORSTR.BLUE, + fmt: FORMATSTR | int | str = FORMATSTR.BOLD, ) -> str: """ This function makes a colorful string for Python. @@ -66,11 +68,13 @@ def colorstr( fmt = fmt.upper() color_code = COLORSTR.obj_to_enum(color).value format_code = FORMATSTR.obj_to_enum(fmt).value - color_string = f'\033[{format_code};{color_code}m{obj}\033[0m' + color_string = f"\033[{format_code};{color_code}m{obj}\033[0m" return color_string -def download_from_google(file_id: str, file_name: str, target: str = "."): +def download_from_google( + file_id: str, file_name: str, target: str | Path = "." +) -> Path: """ Downloads a file from Google Drive, handling potential confirmation tokens for large files. @@ -103,16 +107,13 @@ def download_from_google(file_id: str, file_name: str, target: str = "."): target="./downloads" ) """ - # 第一次嘗試:docs.google.com/uc?export=download&id=檔案ID + # First attempt: docs.google.com/uc?export=download&id=<file id> base_url = "https://docs.google.com/uc" session = requests.Session() - params = { - "export": "download", - "id": file_id - } + params = {"export": "download", "id": file_id} response = session.get(base_url, params=params, stream=True) - # 如果已經出現 Content-Disposition,代表直接拿到檔案 + # If Content-Disposition is already present, the response carries the file directly if "content-disposition" not in response.headers: # First, try to take the token from the cookies token = None @@ -121,37 +122,44 @@ def download_from_google(file_id: str, file_name: str, target: str = "."): token = v break - # 如果 cookies 沒有,就從 HTML 解析 + # If the cookies hold no token, parse the HTML instead if not token: soup = BeautifulSoup(response.text, "html.parser") - # 常見情況:HTML 裡面有一個 form#download-form + # Common case: the HTML contains a form#download-form download_form = soup.find("form", {"id": "download-form"}) - if download_form and download_form.get("action"): - # 將 action 裡的網址抓出來,可能是 drive.usercontent.google.com/download - download_url = download_form["action"] + download_form_tag = cast(Any, download_form) + if download_form_tag and download_form_tag.get("action"): + # Pull the URL out of the action attribute; it is usually drive.usercontent.google.com/download + download_url = str(download_form_tag["action"]) # Collect all of the hidden form fields - hidden_inputs = download_form.find_all( - "input", {"type": "hidden"}) + hidden_inputs = download_form_tag.find_all( + "input", {"type": "hidden"} + ) form_params = {} for inp in hidden_inputs: - if inp.get("name") and inp.get("value") is not None: - form_params[inp["name"]] = inp["value"] + inp_tag = cast(Any, inp) + name = inp_tag.get("name") + value = inp_tag.get("value") + if name and value is not None: + form_params[str(name)] = str(value) # Re-issue the GET request with these parameters - # 注意:原本 action 可能只是相對路徑,這裡直接用完整網址 + # Note: the action may be just a relative path, so use the full URL here response = session.get( - download_url, params=form_params, stream=True) + download_url, params=form_params, stream=True + ) else: # Otherwise, some pages embed confirm=xxx directly in the HTML - match = re.search(r'confirm=([0-9A-Za-z-_]+)', response.text) + match = re.search(r"confirm=([0-9A-Za-z-_]+)", response.text) if match: token = match.group(1) # Request docs.google.com again with the confirm token attached params["confirm"] = token - response = session.get( - base_url, params=params, stream=True) + response = session.get(base_url, params=params, stream=True) else: - raise Exception("無法在回應中找到下載連結或確認參數,下載失敗。") + raise Exception( + "Could not find a download link or confirmation token in the response; download failed." + ) else: # Retry immediately with the token taken from the cookies @@ -159,25 +167,30 @@ def download_from_google(file_id: str,
file_name: str, target: str = "."): response = session.get(base_url, params=params, stream=True) # Make sure the download directory exists - os.makedirs(target, exist_ok=True) - file_path = os.path.join(target, file_name) + target_path = Path(target) + target_path.mkdir(parents=True, exist_ok=True) + file_path = target_path / file_name - # 開始把檔案 chunk 寫到本地,附帶進度條 + # Stream the file to disk in chunks, with a progress bar try: - total_size = int(response.headers.get('content-length', 0)) - with open(file_path, "wb") as f, tqdm( - desc=file_name, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: + total_size = int(response.headers.get("content-length", 0)) + with ( + open(file_path, "wb") as f, + tqdm( + desc=file_name, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar, + ): for chunk in response.iter_content(chunk_size=32768): if chunk: f.write(chunk) bar.update(len(chunk)) print(f"File successfully downloaded to: {file_path}") + return file_path except Exception as e: - raise Exception(f"File download failed: {e}") + raise RuntimeError(f"File download failed: {e}") from e diff --git a/capybara/vision/__init__.py b/capybara/vision/__init__.py index 79ee22d..be3a5ef 100644 --- a/capybara/vision/__init__.py +++ b/capybara/vision/__init__.py @@ -1,7 +1,17 @@ -from .functionals import * -from .geometric import * -from .improc import * -from .ipcam import * -from .morphology import * -from .videotools import * -from .visualization import * +from __future__ import annotations + +from . import ( + functionals, + geometric, + improc, + morphology, + videotools, +) + +__all__ = [ + "functionals", + "geometric", + "improc", + "morphology", + "videotools", +] diff --git a/capybara/vision/functionals.py b/capybara/vision/functionals.py index cc40437..fbdb81e 100644 --- a/capybara/vision/functionals.py +++ b/capybara/vision/functionals.py @@ -1,5 +1,5 @@ from bisect import bisect_left -from typing import List, Optional, Tuple, Union +from typing import Literal, overload import cv2 import numpy as np @@ -9,26 +9,31 @@ from .geometric import imresize __all__ = [ - 'meanblur', 'gaussianblur', 'medianblur', 'imcvtcolor', 'imadjust', 'pad', - 'imcropbox', 'imcropboxes', 'imbinarize', 'centercrop', 'imresize_and_pad_if_need' + "centercrop", + "gaussianblur", + "imadjust", + "imbinarize", + "imcropbox", + "imcropboxes", + "imcvtcolor", + "imresize_and_pad_if_need", + "meanblur", + "medianblur", + "pad", ] -_Ksize = Union[int, Tuple[int, int], np.ndarray] +_Ksize = int | tuple[int, int] | np.ndarray -def _check_ksize(ksize: _Ksize) -> Tuple[int, int]: +def _check_ksize(ksize: _Ksize) -> tuple[int, int]: if isinstance(ksize, int): - ksize = (ksize, ksize) - elif isinstance(ksize, tuple) and len(ksize) == 2 \ - and all(isinstance(val, int) for val in ksize): - ksize = tuple(ksize) - elif isinstance(ksize, np.ndarray) and ksize.ndim == 0: - ksize = (int(ksize), int(ksize)) - else: - raise TypeError(f'The input ksize = {ksize} is invalid.') - - ksize = tuple(int(val) for val in ksize) - return ksize + return ksize, ksize + if isinstance(ksize, tuple) and len(ksize) == 2: + return int(ksize[0]), int(ksize[1]) + if isinstance(ksize, np.ndarray) and ksize.ndim == 0: + value = int(ksize) + return value, value + raise TypeError(f"The input ksize = {ksize} is invalid.") def meanblur(img: np.ndarray, ksize: _Ksize = 3, **kwargs) -> np.ndarray: @@ -52,7 +57,9 @@ def meanblur(img: np.ndarray, ksize: _Ksize = 3, **kwargs) -> np.ndarray: return cv2.blur(img, ksize=ksize, **kwargs) -def gaussianblur(img: np.ndarray, ksize:
_Ksize = 3, sigmaX: int = 0, **kwargs) -> np.ndarray: +def gaussianblur( + img: np.ndarray, ksize: _Ksize = 3, sigma_x: int = 0, **kwargs +) -> np.ndarray: """ Apply Gaussian blur to the input image. @@ -65,7 +72,7 @@ def gaussianblur(img: np.ndarray, ksize: _Ksize = 3, sigmaX: int = 0, **kwargs) size will be used. If a tuple (k_height, k_width) is provided, a rectangular kernel of the specified size will be used. Defaults to 3. - sigmaX (int, optional): + sigma_x (int, optional): The standard deviation in the X direction for Gaussian kernel. Defaults to 0. @@ -73,7 +80,8 @@ def gaussianblur(img: np.ndarray, ksize: _Ksize = 3, sigmaX: int = 0, **kwargs) np.ndarray: The blurred image. """ ksize = _check_ksize(ksize) - return cv2.GaussianBlur(img, ksize=ksize, sigmaX=sigmaX, **kwargs) + sigma_x = int(kwargs.pop("sigmaX", sigma_x)) + return cv2.GaussianBlur(img, ksize=ksize, sigmaX=sigma_x, **kwargs) def medianblur(img: np.ndarray, ksize: int = 3, **kwargs) -> np.ndarray: @@ -94,7 +102,7 @@ def medianblur(img: np.ndarray, ksize: int = 3, **kwargs) -> np.ndarray: return cv2.medianBlur(img, ksize=ksize, **kwargs) -def imcvtcolor(img: np.ndarray, cvt_mode: Union[int, str]) -> np.ndarray: +def imcvtcolor(img: np.ndarray, cvt_mode: int | str) -> np.ndarray: """ Convert the color space of the input image. @@ -113,18 +121,24 @@ def imcvtcolor(img: np.ndarray, cvt_mode: Union[int, str]) -> np.ndarray: Raises: ValueError: If the input cvt_mode is invalid or not supported. """ - code = getattr(cv2, f'COLOR_{cvt_mode}', None) - if code is None: - raise ValueError(f'Input cvt_mode: "{cvt_mode}" is invaild.') + if isinstance(cvt_mode, (int, np.integer)): + code = int(cvt_mode) + else: + mode = str(cvt_mode).upper() + if mode.startswith("COLOR_"): + mode = mode[len("COLOR_") :] + code = getattr(cv2, f"COLOR_{mode}", None) + if code is None: + raise ValueError(f'Input cvt_mode: "{cvt_mode}" is invalid.') img = cv2.cvtColor(img.copy(), code) return img def imadjust( img: np.ndarray, - rng_out: Tuple[int, int] = (0, 255), + rng_out: tuple[int, int] = (0, 255), gamma: float = 1.0, - color_base: str = 'BGR' + color_base: str = "BGR", ) -> np.ndarray: """ Adjust the intensity of an image.
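A small sketch of the sigmaX-to-sigma_x rename above: the legacy keyword is still honored because it is popped from kwargs before the cv2 call. The input image is synthetic.

import numpy as np

from capybara.vision.functionals import gaussianblur

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
out_new = gaussianblur(img, ksize=5, sigma_x=1)
out_old = gaussianblur(img, ksize=5, sigmaX=1)    # deprecated spelling, same result
assert np.array_equal(out_new, out_old)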
@@ -161,7 +175,7 @@ def imadjust( """ is_trans_hsv = False if img.ndim == 3: - img_hsv = imcvtcolor(img, f'{color_base}2HSV') + img_hsv = imcvtcolor(img, f"{color_base}2HSV") v = img_hsv[..., 2] is_trans_hsv = True else: @@ -171,7 +185,7 @@ def imadjust( total = v.size low_bound, upp_bound = total * 0.01, total * 0.99 - hist, _ = np.histogram(v.ravel(), 256, [0, 256]) + hist, _ = np.histogram(v.ravel(), 256, (0, 256)) cdf = hist.cumsum() rng_in = [bisect_left(cdf, low_bound), bisect_left(cdf, upp_bound)] if (rng_in[0] == rng_in[1]) or (rng_in[1] == 0): @@ -182,21 +196,21 @@ def imadjust( dist_out = rng_out[1] - rng_out[0] dst = np.clip((np.clip(v, rng_in[0], None) - rng_in[0]) / dist_in, 0, 1) - dst = (dst ** gamma) * dist_out + rng_out[0] - dst = np.clip(dst, rng_out[0], rng_out[1]).astype('uint8') + dst = (dst**gamma) * dist_out + rng_out[0] + dst = np.clip(dst, rng_out[0], rng_out[1]).astype("uint8") if is_trans_hsv: img_hsv[..., 2] = dst - dst = imcvtcolor(img_hsv, f'HSV2{color_base}') + dst = imcvtcolor(img_hsv, f"HSV2{color_base}") return dst def pad( img: np.ndarray, - pad_size: Union[int, Tuple[int, int], Tuple[int, int, int, int]], - pad_value: Optional[Union[int, Tuple[int, int, int]]] = 0, - pad_mode: Union[str, int, BORDER] = BORDER.CONSTANT + pad_size: int | tuple[int, int] | tuple[int, int, int, int], + pad_value: int | tuple[int, ...] | None = 0, + pad_mode: str | int | BORDER = BORDER.CONSTANT, ) -> np.ndarray: """ Pad the input image with specified padding size and mode. @@ -227,29 +241,62 @@ def pad( np.ndarray: The padded image. """ if isinstance(pad_size, int): - left = right = top = bottom = pad_size - elif len(pad_size) == 2: - top = bottom = pad_size[0] - left = right = pad_size[1] - elif len(pad_size) == 4: - top, bottom, left, right = pad_size + left = right = top = bottom = int(pad_size) + elif isinstance(pad_size, tuple) and len(pad_size) == 2: + top = bottom = int(pad_size[0]) + left = right = int(pad_size[1]) + elif isinstance(pad_size, tuple) and len(pad_size) == 4: + top, bottom, left, right = (int(v) for v in pad_size) else: raise ValueError( - f'pad_size is not an int, a tuple with 2 ints, or a tuple with 4 ints.') + "pad_size is not an int, a tuple with 2 ints, or a tuple with 4 ints." + ) pad_mode = BORDER.obj_to_enum(pad_mode) + + if pad_value is None: + pad_value = 0 + if pad_mode == BORDER.CONSTANT: - if img.ndim == 3 and isinstance(pad_value, int): - pad_value = (pad_value, ) * img.shape[-1] - cond1 = img.ndim == 3 and len(pad_value) == img.shape[-1] - cond2 = img.ndim == 2 and isinstance(pad_value, int) - if not (cond1 or cond2): + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = int(img.shape[-1]) + else: + raise ValueError("img must be a 2D or 3D numpy image.") + + if isinstance(pad_value, int): + if channels == 1: + pad_value = int(pad_value) + else: + pad_value = tuple(int(pad_value) for _ in range(channels)) + elif isinstance(pad_value, tuple): + if channels == 1: + if len(pad_value) == 1: + pad_value = int(pad_value[0]) + else: + raise ValueError( + "pad_value must be an int when padding a grayscale image." + ) + elif len(pad_value) == channels: + pad_value = tuple(int(v) for v in pad_value) + else: + raise ValueError( + f"channel of image is {channels} but pad_value is {pad_value}." + ) + else: raise ValueError( - f'channel of image is {img.shape[-1]} but length of fill is {len(pad_value)}.') + "pad_value must be an int or a tuple matching channel count." 
+ ) img = cv2.copyMakeBorder( - src=img, top=top, bottom=bottom, left=left, right=right, - borderType=pad_mode, value=pad_value + src=img, + top=top, + bottom=bottom, + left=left, + right=right, + borderType=pad_mode.value, + value=pad_value, ) return img @@ -257,7 +304,7 @@ def pad( def imcropbox( img: np.ndarray, - box: Union[Box, np.ndarray], + box: Box | np.ndarray, use_pad: bool = False, ) -> np.ndarray: """ @@ -302,7 +349,8 @@ def imcropbox( x1, y1, x2, y2 = box.astype(int) else: raise TypeError( - f'Input box is not of type Box or NumPy array with 4 elements.') + "Input box is not of type Box or NumPy array with 4 elements." + ) im_h, im_w = img.shape[:2] crop_x1 = max(0, x1) @@ -325,9 +373,9 @@ def imcropbox( def imcropboxes( img: np.ndarray, - boxes: Union[Boxes, np.ndarray], + boxes: Boxes | np.ndarray, use_pad: bool = False, -) -> List[np.ndarray]: +) -> list[np.ndarray]: """ Crop the input image using multiple boxes. """ @@ -335,9 +383,7 @@ def imcropboxes( def imbinarize( - img: np.ndarray, - threth: int = cv2.THRESH_BINARY, - color_base: str = 'BGR' + img: np.ndarray, threth: int = cv2.THRESH_BINARY, color_base: str = "BGR" ) -> np.ndarray: """ Function for image binarize. @@ -367,7 +413,7 @@ def imbinarize( Binary image. """ if img.ndim == 3: - img = imcvtcolor(img, f'{color_base}2GRAY') + img = imcvtcolor(img, f"{color_base}2GRAY") _, dst = cv2.threshold(img, 0, 255, type=threth + cv2.THRESH_OTSU) return dst @@ -383,22 +429,49 @@ def centercrop(img: np.ndarray) -> np.ndarray: Returns: np.ndarray: The cropped image. """ - box = Box([0, 0, img.shape[1], img.shape[0]]).square() + box = Box((0, 0, int(img.shape[1]), int(img.shape[0]))).square() return imcropbox(img, box) +@overload +def imresize_and_pad_if_need( + img: np.ndarray, + max_h: int, + max_w: int, + interpolation: str | int | INTER = INTER.BILINEAR, + pad_value: int | tuple[int, int, int] | None = 0, + pad_mode: str | int | BORDER = BORDER.CONSTANT, + return_scale: Literal[False] = False, +) -> np.ndarray: ... + + +@overload def imresize_and_pad_if_need( img: np.ndarray, max_h: int, max_w: int, - interpolation: Union[str, int, INTER] = INTER.BILINEAR, - pad_value: Optional[Union[int, Tuple[int, int, int]]] = 0, - pad_mode: Union[str, int, BORDER] = BORDER.CONSTANT, + interpolation: str | int | INTER = INTER.BILINEAR, + pad_value: int | tuple[int, int, int] | None = 0, + pad_mode: str | int | BORDER = BORDER.CONSTANT, + return_scale: Literal[True] = True, +) -> tuple[np.ndarray, float]: ... 
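A quick check of the tightened pad_value validation above: scalar values broadcast across channels, while a tuple whose length does not match the channel count now fails fast. Shapes are illustrative.

import numpy as np

from capybara.vision.functionals import pad

img = np.zeros((8, 8, 3), dtype=np.uint8)
padded = pad(img, pad_size=(2, 4), pad_value=255)   # 255 is broadcast to (255, 255, 255)
assert padded.shape == (12, 16, 3)

try:
    pad(img, pad_size=2, pad_value=(255, 255))      # two values for a three-channel image
except ValueError as err:
    print(err)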
+ + +def imresize_and_pad_if_need( + img: np.ndarray, + max_h: int, + max_w: int, + interpolation: str | int | INTER = INTER.BILINEAR, + pad_value: int | tuple[int, int, int] | None = 0, + pad_mode: str | int | BORDER = BORDER.CONSTANT, return_scale: bool = False, -): +) -> np.ndarray | tuple[np.ndarray, float]: raw_h, raw_w = img.shape[:2] scale = min(max_h / raw_h, max_w / raw_w) - dst_h, dst_w = min(int(raw_h * scale), max_h), min(int(raw_w * scale), max_w) + dst_h, dst_w = ( + min(int(raw_h * scale), max_h), + min(int(raw_w * scale), max_w), + ) img = imresize( img, (dst_h, dst_w), diff --git a/capybara/vision/geometric.py b/capybara/vision/geometric.py index 0cb072b..88f9bd5 100644 --- a/capybara/vision/geometric.py +++ b/capybara/vision/geometric.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import Literal, overload import cv2 import numpy as np @@ -7,17 +7,38 @@ from ..structures import Polygon, Polygons, order_points_clockwise __all__ = [ - 'imresize', 'imrotate90', 'imrotate', 'imwarp_quadrangle', - 'imwarp_quadrangles' + "imresize", + "imrotate", + "imrotate90", + "imwarp_quadrangle", + "imwarp_quadrangles", ] +@overload def imresize( img: np.ndarray, - size: Tuple[int, int], - interpolation: Union[str, int, INTER] = INTER.BILINEAR, + size: tuple[int | None, int | None], + interpolation: str | int | INTER = INTER.BILINEAR, + return_scale: Literal[False] = False, +) -> np.ndarray: ... + + +@overload +def imresize( + img: np.ndarray, + size: tuple[int | None, int | None], + interpolation: str | int | INTER = INTER.BILINEAR, + return_scale: Literal[True] = True, +) -> tuple[np.ndarray, float, float]: ... + + +def imresize( + img: np.ndarray, + size: tuple[int | None, int | None], + interpolation: str | int | INTER = INTER.BILINEAR, return_scale: bool = False, -) -> Union[np.ndarray, Tuple[np.ndarray, float, float]]: +) -> np.ndarray | tuple[np.ndarray, float, float]: """ This function is used to resize image. @@ -50,10 +71,13 @@ def imresize( scale = h / raw_h w = int(raw_w * scale + 0.5) # round to nearest integer + if h is None or w is None: + raise ValueError("`size` must provide at least one dimension.") + resized_img = cv2.resize(img, (w, h), interpolation=interpolation.value) if return_scale: - if 'scale' not in locals(): # calculate scale if not already done + if "scale" not in locals(): # calculate scale if not already done w_scale = w / raw_w h_scale = h / raw_h else: @@ -81,13 +105,13 @@ def imrotate( img: np.ndarray, angle: float, scale: float = 1, - interpolation: Union[str, int, INTER] = INTER.BILINEAR, - bordertype: Union[str, int, BORDER] = BORDER.CONSTANT, - bordervalue: Union[int, Tuple[int, int, int]] = None, + interpolation: str | int | INTER = INTER.BILINEAR, + bordertype: str | int | BORDER = BORDER.CONSTANT, + bordervalue: int | tuple[int, ...] | None = None, expand: bool = True, - center: Tuple[int, int] = None, + center: tuple[int, int] | None = None, ) -> np.ndarray: - ''' + """ Rotate the image by angle. Args: @@ -113,49 +137,76 @@ def imrotate( Returns: rotated img: rotated img. 
- ''' + """ bordertype = BORDER.obj_to_enum(bordertype) - bordervalue = (bordervalue,) * \ - 3 if isinstance(bordervalue, int) else bordervalue interpolation = INTER.obj_to_enum(interpolation) + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = int(img.shape[-1]) + else: + raise ValueError("img must be a 2D or 3D numpy image.") + + if bordervalue is None: + bordervalue = 0 + elif isinstance(bordervalue, int): + bordervalue = ( + int(bordervalue) + if channels == 1 + else tuple(int(bordervalue) for _ in range(channels)) + ) + elif isinstance(bordervalue, tuple): + if channels == 1 and len(bordervalue) == 1: + bordervalue = int(bordervalue[0]) + elif len(bordervalue) == channels: + bordervalue = tuple(int(v) for v in bordervalue) + else: + raise ValueError( + f"channel of image is {channels} but bordervalue is {bordervalue}." + ) + h, w = img.shape[:2] center = center or (w / 2, h / 2) - M = cv2.getRotationMatrix2D(center, angle=angle, scale=scale) + matrix = cv2.getRotationMatrix2D(center, angle=angle, scale=scale) if expand: - cos = np.abs(M[0, 0]) - sin = np.abs(M[0, 1]) + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) # compute the new bounding dimensions of the image - nW = int((h * sin) + (w * cos)) + 1 - nH = int((h * cos) + (w * sin)) + 1 + new_w = int((h * sin) + (w * cos)) + 1 + new_h = int((h * cos) + (w * sin)) + 1 # adjust the rotation matrix to take into account translation - M[0, 2] += (nW / 2) - center[0] - M[1, 2] += (nH / 2) - center[1] + matrix[0, 2] += (new_w / 2) - center[0] + matrix[1, 2] += (new_h / 2) - center[1] # perform the actual rotation and return the image dst = cv2.warpAffine( - img, M, (nW, nH), - flags=interpolation, - borderMode=bordertype, - borderValue=bordervalue + img, + matrix, + (new_w, new_h), + flags=interpolation.value, + borderMode=bordertype.value, + borderValue=bordervalue, ) else: dst = cv2.warpAffine( - img, M, (w, h), - flags=interpolation, - borderMode=bordertype, - borderValue=bordervalue + img, + matrix, + (w, h), + flags=interpolation.value, + borderMode=bordertype.value, + borderValue=bordervalue, ) - return dst.astype('uint8') + return dst def imwarp_quadrangle( img: np.ndarray, - polygon: Union[Polygon, np.ndarray], - dst_size: Tuple[int, int] = None, + polygon: Polygon | np.ndarray, + dst_size: tuple[int, int] | None = None, do_order_points: bool = True, ) -> np.ndarray: """ @@ -188,12 +239,12 @@ def imwarp_quadrangle( polygon = Polygon(polygon) if not isinstance(polygon, Polygon): - raise TypeError( - f'Input type of polygon {type(polygon)} not supported.') + raise TypeError(f"Input type of polygon {type(polygon)} not supported.") if len(polygon) != 4: raise ValueError( - f'Input polygon, which is not contain 4 points is invalid.') + "Input polygon, which is not contain 4 points is invalid." 
+ ) if dst_size is None: width, height = polygon.min_box_wh @@ -209,10 +260,9 @@ def imwarp_quadrangle( if do_order_points: src_pts = order_points_clockwise(src_pts) - dst_pts = np.array([[0, 0], - [width, 0], - [width, height], - [0, height]], dtype="float32") + dst_pts = np.array( + [[0, 0], [width, 0], [width, height], [0, height]], dtype="float32" + ) matrix = cv2.getPerspectiveTransform(src_pts, dst_pts) return cv2.warpPerspective(img, matrix, (width, height)) @@ -221,9 +271,9 @@ def imwarp_quadrangle( def imwarp_quadrangles( img: np.ndarray, polygons: Polygons, - dst_size: Tuple[int, int] = None, + dst_size: tuple[int, int] | None = None, do_order_points: bool = True, -) -> List[np.ndarray]: +) -> list[np.ndarray]: """ Apply a 4-point perspective transform to an image using a given polygons. @@ -246,11 +296,11 @@ def imwarp_quadrangles( """ if not isinstance(polygons, Polygons): raise TypeError( - f'Input type of polygons {type(polygons)} not supported.') + f"Input type of polygons {type(polygons)} not supported." + ) return [ imwarp_quadrangle( - img, poly, - dst_size=dst_size, - do_order_points=do_order_points - ) for poly in polygons + img, poly, dst_size=dst_size, do_order_points=do_order_points + ) + for poly in polygons ] diff --git a/capybara/vision/improc.py b/capybara/vision/improc.py index 39e59db..51e66c1 100644 --- a/capybara/vision/improc.py +++ b/capybara/vision/improc.py @@ -1,6 +1,9 @@ +import os +import tempfile import warnings +from contextlib import suppress from pathlib import Path -from typing import Any, List, Union +from typing import Any, cast import cv2 import numpy as np @@ -15,27 +18,27 @@ from .geometric import imrotate90 __all__ = [ - "imread", - "imwrite", - "imencode", - "imdecode", - "img_to_b64", - "img_to_b64str", "b64_to_img", - "b64str_to_img", "b64_to_npy", + "b64str_to_img", "b64str_to_npy", + "get_orientation_code", + "imdecode", + "imencode", + "img_to_b64", + "img_to_b64str", + "imread", + "imwrite", + "is_numpy_img", + "jpgdecode", + "jpgencode", + "jpgread", "npy_to_b64", "npy_to_b64str", "npyread", "pdf2imgs", - "jpgencode", - "jpgdecode", - "jpgread", - "pngencode", "pngdecode", - "is_numpy_img", - "get_orientation_code", + "pngencode", ] jpeg = TurboJPEG() @@ -45,47 +48,52 @@ def is_numpy_img(x: Any) -> bool: """ x == ndarray (H x W x C) """ - return isinstance(x, np.ndarray) and (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])) + return isinstance(x, np.ndarray) and ( + x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]) + ) -def get_orientation_code(stream: Union[str, Path, bytes]): - code = None +def get_orientation_code(stream: str | Path | bytes): try: exif_dict = piexif.load(stream) - if piexif.ImageIFD.Orientation in exif_dict["0th"]: - orientation = exif_dict["0th"][piexif.ImageIFD.Orientation] - if orientation == 3: - code = ROTATE.ROTATE_180 - elif orientation == 6: - code = ROTATE.ROTATE_90 - elif orientation == 8: - code = ROTATE.ROTATE_270 - finally: - return code - - -def jpgencode(img: np.ndarray, quality: int = 90) -> Union[bytes, None]: + except Exception: + return None + + orientation = exif_dict.get("0th", {}).get(piexif.ImageIFD.Orientation) + if orientation == 3: + return ROTATE.ROTATE_180 + if orientation == 6: + return ROTATE.ROTATE_90 + if orientation == 8: + return ROTATE.ROTATE_270 + return None + + +def jpgencode(img: np.ndarray, quality: int = 90) -> bytes | None: byte_ = None if is_numpy_img(img): - try: - byte_ = jpeg.encode(img, quality=quality) - except Exception as _: - pass + with 
suppress(Exception): + encoded = jpeg.encode(img, quality=quality) + if isinstance(encoded, tuple): + encoded = encoded[0] + byte_ = cast(bytes, encoded) return byte_ -def jpgdecode(byte_: bytes) -> Union[np.ndarray, None]: +def jpgdecode(byte_: bytes) -> np.ndarray | None: try: bgr_array = jpeg.decode(byte_) code = get_orientation_code(byte_) - bgr_array = imrotate90(bgr_array, code) if code is not None else bgr_array + bgr_array = ( + imrotate90(bgr_array, code) if code is not None else bgr_array + ) except Exception as _: bgr_array = None return bgr_array -def jpgread(img_file: Union[str, Path]) -> Union[np.ndarray, None]: +def jpgread(img_file: str | Path) -> np.ndarray | None: with open(str(img_file), "rb") as f: binary_img = f.read() bgr_array = jpgdecode(binary_img) @@ -93,17 +101,19 @@ def jpgread(img_file: Union[str, Path]) -> Union[np.ndarray, None]: return bgr_array -def pngencode(img: np.ndarray, compression: int = 1) -> Union[bytes, None]: +def pngencode(img: np.ndarray, compression: int = 1) -> bytes | None: byte_ = None if is_numpy_img(img): - try: - byte_ = cv2.imencode(".png", img, params=[int(cv2.IMWRITE_PNG_COMPRESSION), compression])[1].tobytes() - except Exception as _: - pass + with suppress(Exception): + byte_ = cv2.imencode( + ".png", + img, + params=[int(cv2.IMWRITE_PNG_COMPRESSION), compression], + )[1].tobytes() return byte_ -def pngdecode(byte_: bytes) -> Union[np.ndarray, None]: +def pngdecode(byte_: bytes) -> np.ndarray | None: try: enc = np.frombuffer(byte_, "uint8") img = cv2.imdecode(enc, cv2.IMREAD_COLOR) @@ -112,14 +122,25 @@ def pngdecode(byte_: bytes) -> Union[np.ndarray, None]: return img -def imencode(img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG) -> Union[bytes, None]: - IMGTYP = IMGTYP.obj_to_enum(IMGTYP) - encode_fn = jpgencode if IMGTYP == IMGTYP.JPEG else pngencode - byte_ = encode_fn(img) - return byte_ - - -def imdecode(byte_: bytes) -> Union[np.ndarray, None]: +def imencode( + img: np.ndarray, + imgtyp: str | int | IMGTYP = IMGTYP.JPEG, + **kwargs: object, +) -> bytes | None: + if "IMGTYP" in kwargs: + if imgtyp != IMGTYP.JPEG: + raise TypeError("imgtyp and IMGTYP were both provided.") + imgtyp = cast(str | int | IMGTYP, kwargs.pop("IMGTYP")) + if kwargs: + unexpected = ", ".join(sorted(kwargs)) + raise TypeError(f"Unexpected keyword arguments: {unexpected}") + + imgtyp_enum = IMGTYP.obj_to_enum(imgtyp) + encode_fn = jpgencode if imgtyp_enum == IMGTYP.JPEG else pngencode + return encode_fn(img) + + +def imdecode(byte_: bytes) -> np.ndarray | None: try: img = jpgdecode(byte_) img = pngdecode(byte_) if img is None else img @@ -128,11 +149,26 @@ def imdecode(byte_: bytes) -> Union[np.ndarray, None]: return img -def img_to_b64(img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG) -> Union[bytes, None]: - IMGTYP = IMGTYP.obj_to_enum(IMGTYP) - encode_fn = jpgencode if IMGTYP == IMGTYP.JPEG else pngencode +def img_to_b64( + img: np.ndarray, + imgtyp: str | int | IMGTYP = IMGTYP.JPEG, + **kwargs: object, +) -> bytes | None: + if "IMGTYP" in kwargs: + if imgtyp != IMGTYP.JPEG: + raise TypeError("imgtyp and IMGTYP were both provided.") + imgtyp = cast(str | int | IMGTYP, kwargs.pop("IMGTYP")) + if kwargs: + unexpected = ", ".join(sorted(kwargs)) + raise TypeError(f"Unexpected keyword arguments: {unexpected}") + + imgtyp_enum = IMGTYP.obj_to_enum(imgtyp) + encode_fn = jpgencode if imgtyp_enum == IMGTYP.JPEG else pngencode try: - b64 = pybase64.b64encode(encode_fn(img)) + encoded = encode_fn(img) + if encoded is None: + return 
None + b64 = pybase64.b64encode(encoded) except Exception as _: b64 = None return b64 @@ -142,18 +178,23 @@ def npy_to_b64(x: np.ndarray, dtype="float32") -> bytes: return pybase64.b64encode(x.astype(dtype).tobytes()) -def npy_to_b64str(x: np.ndarray, dtype="float32", string_encode: str = "utf-8") -> str: +def npy_to_b64str( + x: np.ndarray, dtype="float32", string_encode: str = "utf-8" +) -> str: return pybase64.b64encode(x.astype(dtype).tobytes()).decode(string_encode) def img_to_b64str( - img: np.ndarray, IMGTYP: Union[str, int, IMGTYP] = IMGTYP.JPEG, string_encode: str = "utf-8" -) -> Union[str, None]: - b64 = img_to_b64(img, IMGTYP) + img: np.ndarray, + imgtyp: str | int | IMGTYP = IMGTYP.JPEG, + string_encode: str = "utf-8", + **kwargs: object, +) -> str | None: + b64 = img_to_b64(img, imgtyp, **kwargs) return b64.decode(string_encode) if isinstance(b64, bytes) else None -def b64_to_img(b64: bytes) -> Union[np.ndarray, None]: +def b64_to_img(b64: bytes) -> np.ndarray | None: try: img = imdecode(pybase64.b64decode(b64)) except Exception as _: @@ -161,9 +202,11 @@ def b64_to_img(b64: bytes) -> Union[np.ndarray, None]: return img -def b64str_to_img(b64str: Union[str, None], string_encode: str = "utf-8") -> Union[np.ndarray, None]: +def b64str_to_img( + b64str: str | None, string_encode: str = "utf-8" +) -> np.ndarray | None: if b64str is None: - warnings.warn("b64str is None.") + warnings.warn("b64str is None.", stacklevel=2) return None if not isinstance(b64str, str): @@ -176,11 +219,15 @@ def b64_to_npy(x: bytes, dtype="float32") -> np.ndarray: return np.frombuffer(pybase64.b64decode(x), dtype=dtype) -def b64str_to_npy(x: bytes, dtype="float32", string_encode: str = "utf-8") -> np.ndarray: - return np.frombuffer(pybase64.b64decode(x.encode(string_encode)), dtype=dtype) +def b64str_to_npy( + x: str, dtype="float32", string_encode: str = "utf-8" +) -> np.ndarray: + return np.frombuffer( + pybase64.b64decode(x.encode(string_encode)), dtype=dtype + ) -def npyread(path: Union[str, Path]) -> Union[np.ndarray, None]: +def npyread(path: str | Path) -> np.ndarray | None: try: with open(str(path), "rb") as f: img = np.load(f) @@ -189,7 +236,9 @@ def npyread(path: Union[str, Path]) -> Union[np.ndarray, None]: return img -def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = False) -> Union[np.ndarray, None]: +def imread( + path: str | Path, color_base: str = "BGR", verbose: bool = False +) -> np.ndarray | None: """ This function reads an image from a given file path and converts its color base if necessary. 
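A hedged sketch of the deprecated-keyword shim above: the legacy uppercase IMGTYP keyword still routes through kwargs, but supplying both spellings raises TypeError. The IMGTYP import path is assumed from this diff.

import numpy as np

from capybara.enums import IMGTYP
from capybara.vision.improc import imencode

img = np.zeros((4, 4, 3), dtype=np.uint8)
new_style = imencode(img, imgtyp=IMGTYP.PNG)    # preferred spelling
old_style = imencode(img, IMGTYP=IMGTYP.PNG)    # legacy spelling still accepted
assert new_style == old_style                   # PNG encoding is deterministic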
@@ -215,8 +264,12 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals if not Path(path).exists(): raise FileExistsError(f"{path} can not found.") + color_base = color_base.upper() + if Path(path).suffix.lower() == ".heic": - heif_file = pillow_heif.open_heif(str(path), convert_hdr_to_8bit=True, bgr_mode=True) + heif_file = pillow_heif.open_heif( + str(path), convert_hdr_to_8bit=True, bgr_mode=True + ) img = np.asarray(heif_file) else: img = jpgread(path) @@ -224,7 +277,7 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals if img is None: if verbose: - warnings.warn("Got a None type image.") + warnings.warn("Got a None type image.", stacklevel=2) return if color_base != "BGR": @@ -235,7 +288,7 @@ def imread(path: Union[str, Path], color_base: str = "BGR", verbose: bool = Fals def imwrite( img: np.ndarray, - path: Union[str, Path] = None, + path: str | Path | None = None, color_base: str = "BGR", suffix: str = ".jpg", ) -> bool: @@ -260,10 +313,15 @@ def imwrite( color_base = color_base.upper() if color_base != "BGR": img = imcvtcolor(img, cvt_mode=f"{color_base}2BGR") - return cv2.imwrite(str(path) if path else f"tmp{suffix}", img) + if path is None: + fd, target = tempfile.mkstemp(prefix="capybara_", suffix=suffix) + os.close(fd) + else: + target = str(path) + return bool(cv2.imwrite(target, img)) -def pdf2imgs(stream: Union[str, Path, bytes]) -> Union[List[np.ndarray], None]: +def pdf2imgs(stream: str | Path | bytes) -> list[np.ndarray] | None: """ Function for converting a PDF document to numpy images. @@ -278,6 +336,8 @@ def pdf2imgs(stream: Union[str, Path, bytes]) -> Union[List[np.ndarray], None]: pil_imgs = convert_from_bytes(stream) else: pil_imgs = convert_from_path(stream) - return [imcvtcolor(np.array(img), cvt_mode="RGB2BGR") for img in pil_imgs] + return [ + imcvtcolor(np.array(img), cvt_mode="RGB2BGR") for img in pil_imgs + ] except Exception as _: return diff --git a/capybara/vision/ipcam/__init__.py b/capybara/vision/ipcam/__init__.py index 5bda3fc..e27b3bc 100644 --- a/capybara/vision/ipcam/__init__.py +++ b/capybara/vision/ipcam/__init__.py @@ -1,2 +1,5 @@ -from .app import * -from .camera import * +from __future__ import annotations + +from . 
import app, camera + +__all__ = ["app", "camera"] diff --git a/capybara/vision/ipcam/app.py b/capybara/vision/ipcam/app.py index b8f631a..ffbaff5 100644 --- a/capybara/vision/ipcam/app.py +++ b/capybara/vision/ipcam/app.py @@ -1,50 +1,60 @@ -from flask import Flask, Response, render_template_string +from flask import Flask, Response, render_template_string # type: ignore from ...utils import get_curdir from ..improc import jpgencode from .camera import IpcamCapture -__all__ = ['WebDemo'] +__all__ = ["WebDemo"] -def gen(cap, pipelines=[]): +def gen(cap, pipelines=None): + if pipelines is None: + pipelines = [] while True: frame = cap.get_frame() for f in pipelines: frame = f(frame) frame_bytes = jpgencode(frame) + if frame_bytes is None: + continue yield ( - b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n\r\n' + b"--frame\r\n" + b"Content-Type: image/jpeg\r\n\r\n" + frame_bytes + b"\r\n\r\n" ) class WebDemo: - def __init__( self, camera_ip: str, - color_base: str = 'BGR', - route: str = '/', - pipelines: list = [], + color_base: str = "BGR", + route: str = "/", + pipelines: list | None = None, ): + if pipelines is None: + pipelines = [] app = Flask(__name__) - @ app.route(route) + @app.route(route) def _index(): - with open(str(get_curdir(__file__) / 'video_streaming.html'), 'r', encoding='utf-8') as file: + with open( + str(get_curdir(__file__) / "video_streaming.html"), + encoding="utf-8", + ) as file: html_content = file.read() return render_template_string(html_content) - @ app.route('/video_feed') + @app.route("/video_feed") def _video_feed(): return Response( - gen(cap=IpcamCapture(camera_ip, color_base), pipelines=pipelines), - mimetype='multipart/x-mixed-replace; boundary=frame' + gen( + cap=IpcamCapture(camera_ip, color_base), pipelines=pipelines + ), + mimetype="multipart/x-mixed-replace; boundary=frame", ) self.app = app - def run(self, host='0.0.0.0', port=5001, debug=False, threaded=True): + def run(self, host="0.0.0.0", port=5001, debug=False, threaded=True): self.app.run(host=host, port=port, debug=debug, threaded=threaded) diff --git a/capybara/vision/ipcam/camera.py b/capybara/vision/ipcam/camera.py index 4868325..1d1ff60 100644 --- a/capybara/vision/ipcam/camera.py +++ b/capybara/vision/ipcam/camera.py @@ -5,12 +5,11 @@ from ..functionals import imcvtcolor -__all__ = ['IpcamCapture'] +__all__ = ["IpcamCapture"] class IpcamCapture: - - def __init__(self, url=0, color_base='BGR'): + def __init__(self, url: int | str = 0, color_base: str = "BGR") -> None: """ Initializes the IpcamCapture class. 
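Note on the WebDemo changes above: the mutable default pipelines=[] became None with an in-body default, so instances no longer share one list, and gen now skips frames whose JPEG encoding returns None instead of yielding a broken multipart chunk. A usage sketch; the RTSP URL and the to_gray stage are illustrative only:

    from capybara.vision.ipcam import app

    def to_gray(frame):
        # Illustrative pipeline stage: average the channels, keep a 3-channel shape.
        return frame.mean(axis=-1, keepdims=True).repeat(3, axis=-1).astype("uint8")

    demo = app.WebDemo("rtsp://192.0.2.1/stream", pipelines=[to_gray])
    # demo.run(port=5001)  # blocking; serves the MJPEG page on 0.0.0.0:5001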
@@ -43,7 +42,7 @@ def __init__(self, url=0, color_base='BGR'): self._lock = Lock() if self._h == 0 or self._w == 0: - raise ValueError(f'The image size is not supported.') + raise ValueError("The image size is not supported.") Thread(target=self._queryframe, daemon=True).start() @@ -52,8 +51,8 @@ def _queryframe(self): ret, frame = self._capture.read() if not ret: break # Stop the loop if the video stream has ended or is unreadable - if self.color_base != 'BGR': - frame = imcvtcolor(frame, cvt_mode=f'BGR2{self.color_base}') + if self.color_base != "BGR": + frame = imcvtcolor(frame, cvt_mode=f"BGR2{self.color_base}") with self._lock: self._frame = frame @@ -66,4 +65,7 @@ def get_frame(self): return frame def __iter__(self): - yield self.get_frame() + return self + + def __next__(self): + return self.get_frame() diff --git a/capybara/vision/morphology.py b/capybara/vision/morphology.py index f704235..3bd7b21 100644 --- a/capybara/vision/morphology.py +++ b/capybara/vision/morphology.py @@ -1,20 +1,23 @@ -from typing import Tuple, Union - import cv2 import numpy as np from ..enums import MORPH __all__ = [ - 'imerode', 'imdilate', 'imopen', 'imclose', - 'imgradient', 'imtophat', 'imblackhat', + "imblackhat", + "imclose", + "imdilate", + "imerode", + "imgradient", + "imopen", + "imtophat", ] def imerode( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Erosion: @@ -33,9 +36,9 @@ def imerode( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. {ksize}.") kstruct = MORPH.obj_to_enum(kstruct) @@ -44,8 +47,8 @@ def imerode( def imdilate( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Dilation: @@ -64,9 +67,9 @@ def imdilate( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. {ksize}.") kstruct = MORPH.obj_to_enum(kstruct) @@ -75,8 +78,8 @@ def imdilate( def imopen( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Opening: @@ -93,19 +96,21 @@ def imopen( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. 
{ksize}.") kstruct = MORPH.obj_to_enum(kstruct) - return cv2.morphologyEx(img, cv2.MORPH_OPEN, cv2.getStructuringElement(kstruct, ksize)) + return cv2.morphologyEx( + img, cv2.MORPH_OPEN, cv2.getStructuringElement(kstruct, ksize) + ) def imclose( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Closing: @@ -123,19 +128,21 @@ def imclose( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. {ksize}.") kstruct = MORPH.obj_to_enum(kstruct) - return cv2.morphologyEx(img, cv2.MORPH_CLOSE, cv2.getStructuringElement(kstruct, ksize)) + return cv2.morphologyEx( + img, cv2.MORPH_CLOSE, cv2.getStructuringElement(kstruct, ksize) + ) def imgradient( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Morphological Gradient: @@ -151,19 +158,21 @@ def imgradient( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. {ksize}.") kstruct = MORPH.obj_to_enum(kstruct) - return cv2.morphologyEx(img, cv2.MORPH_GRADIENT, cv2.getStructuringElement(kstruct, ksize)) + return cv2.morphologyEx( + img, cv2.MORPH_GRADIENT, cv2.getStructuringElement(kstruct, ksize) + ) def imtophat( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ) -> np.ndarray: """ Top Hat: @@ -179,19 +188,21 @@ def imtophat( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. {ksize}.") kstruct = MORPH.obj_to_enum(kstruct) - return cv2.morphologyEx(img, cv2.MORPH_TOPHAT, cv2.getStructuringElement(kstruct, ksize)) + return cv2.morphologyEx( + img, cv2.MORPH_TOPHAT, cv2.getStructuringElement(kstruct, ksize) + ) def imblackhat( img: np.ndarray, - ksize: Union[int, Tuple[int, int]] = (3, 3), - kstruct: Union[str, int, MORPH] = MORPH.RECT + ksize: int | tuple[int, int] = (3, 3), + kstruct: str | int | MORPH = MORPH.RECT, ): """ Black Hat: @@ -207,10 +218,12 @@ def imblackhat( MORPH.ELLIPSE}, Defaults to MORPH.RECT. """ if isinstance(ksize, int): - ksize = (ksize, ) * 2 + ksize = (ksize, ksize) elif not isinstance(ksize, tuple) or len(ksize) != 2: - raise TypeError(f'Got inappropriate type or shape of size. {ksize}.') + raise TypeError(f"Got inappropriate type or shape of size. 
{ksize}.") kstruct = MORPH.obj_to_enum(kstruct) - return cv2.morphologyEx(img, cv2.MORPH_BLACKHAT, cv2.getStructuringElement(kstruct, ksize)) + return cv2.morphologyEx( + img, cv2.MORPH_BLACKHAT, cv2.getStructuringElement(kstruct, ksize) + ) diff --git a/capybara/vision/videotools/__init__.py b/capybara/vision/videotools/__init__.py index 0872b94..42e9184 100644 --- a/capybara/vision/videotools/__init__.py +++ b/capybara/vision/videotools/__init__.py @@ -1,2 +1,6 @@ -from .video2frames import * -from .video2frames_v2 import * +from __future__ import annotations + +from .video2frames import video2frames +from .video2frames_v2 import video2frames_v2 + +__all__ = ["video2frames", "video2frames_v2"] diff --git a/capybara/vision/videotools/video2frames.py b/capybara/vision/videotools/video2frames.py index 9251464..cec8429 100644 --- a/capybara/vision/videotools/video2frames.py +++ b/capybara/vision/videotools/video2frames.py @@ -1,12 +1,13 @@ +import math from pathlib import Path -from typing import Any, List +from typing import Any import cv2 import numpy as np -__all__ = ['video2frames', 'is_video_file'] +__all__ = ["is_video_file", "video2frames"] -VIDEO_SUFFIX = ['.MOV', '.MP4', '.AVI', '.WEBM', '.3GP', '.MKV'] +VIDEO_SUFFIX = [".MOV", ".MP4", ".AVI", ".WEBM", ".3GP", ".MKV"] def is_video_file(x: Any) -> bool: @@ -17,9 +18,9 @@ def is_video_file(x: Any) -> bool: def video2frames( - video_path: str, - frame_per_sec: int = None, -) -> List[np.ndarray]: + video_path: str | Path, + frame_per_sec: int | None = None, +) -> list[np.ndarray]: """ Extracts the frames from a video using ray Inputs: @@ -31,7 +32,7 @@ def video2frames( frames (List[np.ndarray]): A list of frames. """ if not is_video_file(video_path): - raise TypeError(f'The video_path {video_path} is inappropriate.') + raise TypeError(f"The video_path {video_path} is inappropriate.") # get total_frames frames of video cap = cv2.VideoCapture(str(video_path)) @@ -40,11 +41,19 @@ def video2frames( return [] # Get the original FPS of the video - original_fps = cap.get(cv2.CAP_PROP_FPS) + original_fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0) # Calculate the interval for frame extraction - interval = 1 if frame_per_sec is None \ - else int(original_fps / frame_per_sec) + if frame_per_sec is None: + interval = 1 + else: + frame_per_sec_i = int(frame_per_sec) + if frame_per_sec_i <= 0: + raise ValueError("frame_per_sec must be > 0.") + if not math.isfinite(original_fps) or original_fps <= 0: + interval = 1 + else: + interval = max(1, int(original_fps / frame_per_sec_i)) frames = [] index = 0 diff --git a/capybara/vision/videotools/video2frames_v2.py b/capybara/vision/videotools/video2frames_v2.py index d50ae84..07a8c1f 100644 --- a/capybara/vision/videotools/video2frames_v2.py +++ b/capybara/vision/videotools/video2frames_v2.py @@ -1,6 +1,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from itertools import chain -from typing import Any, List +from typing import Any import cv2 import numpy as np @@ -15,7 +15,9 @@ def is_numpy_img(x: Any) -> bool: """ x == ndarray (H x W x C) """ - return isinstance(x, np.ndarray) and (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])) + return isinstance(x, np.ndarray) and ( + x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]) + ) def flatten_list(xs: list) -> list: @@ -38,18 +40,26 @@ def flatten_list(xs: list) -> list: def get_step_inds(start: int, end: int, num: int): if num > (end - start): raise ValueError("num is larger than the number of total frames.") - return 
np.around(np.linspace(start=start, stop=end, num=num, endpoint=False)).astype(int).tolist() + return ( + np.around(np.linspace(start=start, stop=end, num=num, endpoint=False)) + .astype(int) + .tolist() + ) def _extract_frames( - inds: List[int], video_path: str, max_size: int = 1920, color_base: str = "BGR", global_ind: int = 0 + inds: list[int], + video_path: str | Any, + max_size: int = 1920, + color_base: str = "BGR", + global_ind: int = 0, ): # check video path if not is_video_file(video_path): raise TypeError(f"The video_path {video_path} is inappropriate.") # open cap - cap = cv2.VideoCapture(video_path) + cap = cv2.VideoCapture(str(video_path)) # if start or end isn't specified lets assume 0 start = inds[0] end = inds[-1] + 1 @@ -62,7 +72,9 @@ def _process_frame(frame): if scale < 1: frame = imresize(frame, (dst_h, dst_w)) elif scale > 1: - frame = imresize(frame, (dst_h, dst_w), interpolation=cv2.INTER_AREA) + frame = imresize( + frame, (dst_h, dst_w), interpolation=cv2.INTER_LINEAR + ) if color_base.upper() != "BGR": frame = imcvtcolor(frame, cvt_mode=f"BGR2{color_base}") @@ -87,14 +99,14 @@ def _pickup_frame(): def video2frames_v2( - video_path: str, - frame_per_sec: int = None, + video_path: str | Any, + frame_per_sec: int | None = None, start_sec: float = 0, - end_sec: float = None, + end_sec: float | None = None, n_threads: int = 8, max_size: int = 1920, color_base: str = "BGR", -) -> List[np.ndarray]: +) -> list[np.ndarray]: """ Extracts the frames from a video using ray Inputs: @@ -129,7 +141,13 @@ def video2frames_v2( if total_frames == 0 or fps == 0: return [] - frame_per_sec = fps if frame_per_sec is None else frame_per_sec + n_threads = int(n_threads) + if n_threads < 1: + raise ValueError("n_threads must be >= 1.") + + frame_per_sec = fps if frame_per_sec is None else int(frame_per_sec) + if frame_per_sec <= 0: + raise ValueError("frame_per_sec must be > 0.") total_sec = total_frames / fps # get frame inds end_sec = total_sec if end_sec is None or end_sec > total_sec else end_sec @@ -140,19 +158,29 @@ def video2frames_v2( start_frame = round(start_sec * fps) end_frame = round(end_sec * fps) num = round(total_sec * frame_per_sec) + if num <= 0: + return [] frame_inds = get_step_inds(start_frame, end_frame, num) + if not frame_inds: # pragma: no cover + return [] out_frames = [] - with ThreadPoolExecutor(max_workers=n_threads) as executor: + worker_count = min(n_threads, len(frame_inds)) + chunk_size = max(1, (len(frame_inds) + worker_count - 1) // worker_count) + frame_inds_list = [ + frame_inds[i : i + chunk_size] + for i in range(0, len(frame_inds), chunk_size) + ] + + with ThreadPoolExecutor(max_workers=worker_count) as executor: ## -----start process---- ## - # split the frames into chunk lists - divide_size = round(len(frame_inds) / n_threads) + 1 - frame_inds_list = [frame_inds[i * divide_size : (i + 1) * divide_size] for i in range(n_threads)] future_to_frames = { - executor.submit(_extract_frames, inds, video_path, max_size, color_base, i): inds + executor.submit( + _extract_frames, inds, video_path, max_size, color_base, i + ): inds for i, inds in enumerate(frame_inds_list) } - out_frames = [[] for _ in range(n_threads)] + out_frames = [[] for _ in range(len(frame_inds_list))] for future in as_completed(future_to_frames): frames = future_to_frames[future] diff --git a/capybara/vision/visualization/__init__.py b/capybara/vision/visualization/__init__.py index c6df1f7..72918ad 100644 --- a/capybara/vision/visualization/__init__.py +++ 
b/capybara/vision/visualization/__init__.py @@ -1 +1,8 @@ -from .draw import * +from __future__ import annotations + +from typing import TYPE_CHECKING + +__all__ = ["draw", "utils"] + +if TYPE_CHECKING: # pragma: no cover + from . import draw, utils diff --git a/capybara/vision/visualization/draw.py b/capybara/vision/visualization/draw.py index aa5b256..8840fe2 100644 --- a/capybara/vision/visualization/draw.py +++ b/capybara/vision/visualization/draw.py @@ -2,31 +2,58 @@ import functools import hashlib from pathlib import Path -from typing import List, Tuple, Union import cv2 import matplotlib +import matplotlib.colors as mpl_colors import numpy as np from PIL import Image, ImageDraw, ImageFont -from ...structures import Polygons from ...structures.boxes import _Box, _Boxes from ...structures.keypoints import _Keypoints, _KeypointsList from ...structures.polygons import _Polygon, _Polygons -from ...utils import download_from_google, get_curdir +from ...utils import get_curdir from ..geometric import imresize -from .utils import (_Color, _Colors, _Point, _Points, _Scale, _Scales, - _Thickness, _Thicknesses, prepare_box, prepare_boxes, - prepare_color, prepare_colors, prepare_img, - prepare_keypoints, prepare_keypoints_list, prepare_point, - prepare_points, prepare_polygon, prepare_polygons, - prepare_scale, prepare_scales, prepare_thickness, - prepare_thicknesses) +from .utils import ( + _Color, + _Colors, + _Point, + _Points, + _Scale, + _Scales, + _Thickness, + _Thicknesses, + prepare_box, + prepare_boxes, + prepare_color, + prepare_colors, + prepare_img, + prepare_keypoints, + prepare_keypoints_list, + prepare_point, + prepare_points, + prepare_polygon, + prepare_polygons, + prepare_scale, + prepare_scales, + prepare_thickness, + prepare_thicknesses, +) __all__ = [ - 'draw_box', 'draw_boxes', 'draw_polygon', 'draw_polygons', 'draw_text', - 'generate_colors', 'draw_mask', 'draw_point', 'draw_points', 'draw_keypoints', - 'draw_keypoints_list', 'draw_detection', 'draw_detections', + "draw_box", + "draw_boxes", + "draw_detection", + "draw_detections", + "draw_keypoints", + "draw_keypoints_list", + "draw_mask", + "draw_point", + "draw_points", + "draw_polygon", + "draw_polygons", + "draw_text", + "generate_colors", ] DIR = get_curdir(__file__) @@ -34,16 +61,29 @@ DEFAULT_FONT_PATH = DIR / "NotoSansMonoCJKtc-VF.ttf" -if not (font_path := DIR / DEFAULT_FONT_PATH).exists(): - file_id = "1m6jvsBGKgQsxzpIoe4iEp_EqFxYXe7T1" - download_from_google(file_id, font_path.name, DIR) +def _load_font( + font_path: str | Path | None, + *, + size: int, +) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: + candidates: list[Path] = [] + if font_path is not None: + candidates.append(Path(font_path)) + candidates.append(DEFAULT_FONT_PATH) + + for candidate in candidates: + try: + return ImageFont.truetype(str(candidate), size=int(size)) + except Exception: + continue + return ImageFont.load_default() def draw_box( img: np.ndarray, box: _Box, color: _Color = (0, 255, 0), - thickness: _Thickness = 2 + thickness: _Thickness = 2, ) -> np.ndarray: """ Draws a bounding box on the image. 
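The import-time Google Drive download of the bundled font is gone; _load_font above tries the caller's font_path, then the packaged NotoSansMonoCJKtc-VF.ttf, and finally falls back to ImageFont.load_default(), so text drawing keeps working offline. Sketch of the observable behavior:

    import numpy as np
    from capybara.vision.visualization.draw import draw_text

    canvas = np.full((64, 256, 3), 255, dtype=np.uint8)
    # A bogus font path no longer raises: the loader falls back to the bundled
    # TTF, then to PIL's built-in bitmap font as a last resort.
    out = draw_text(canvas, "hello", (8, 8), font_path="/no/such/font.ttf")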
@@ -70,14 +110,16 @@ def draw_box( h, w = img.shape[:2] box = box.denormalize(w, h) x1, y1, x2, y2 = box.numpy().astype(int).tolist() - return cv2.rectangle(img, (x1, y1), (x2, y2), color=color, thickness=thickness) + return cv2.rectangle( + img, (x1, y1), (x2, y2), color=color, thickness=thickness + ) def draw_boxes( img: np.ndarray, boxes: _Boxes, colors: _Colors = (0, 255, 0), - thicknesses: _Thicknesses = 2 + thicknesses: _Thicknesses = 2, ) -> np.ndarray: """ Draws multiple bounding boxes on the image. @@ -101,7 +143,7 @@ def draw_boxes( boxes = prepare_boxes(boxes) colors = prepare_colors(colors, len(boxes)) thicknesses = prepare_thicknesses(thicknesses, len(boxes)) - for box, c, t in zip(boxes, colors, thicknesses): + for box, c, t in zip(boxes, colors, thicknesses, strict=True): draw_box(img, box, color=c, thickness=t) return img @@ -112,7 +154,7 @@ def draw_polygon( color: _Color = (0, 255, 0), thickness: _Thickness = 2, fillup=False, - **kwargs + **kwargs, ): """ Draw a polygon on the input image. @@ -148,8 +190,14 @@ def draw_polygon( if fillup: img = cv2.fillPoly(img, [polygon], color=color, **kwargs) else: - img = cv2.polylines(img, [polygon], isClosed=True, - color=color, thickness=thickness, **kwargs) + img = cv2.polylines( + img, + [polygon], + isClosed=True, + color=color, + thickness=thickness, + **kwargs, + ) return img @@ -160,7 +208,7 @@ def draw_polygons( colors: _Colors = (0, 255, 0), thicknesses: _Thicknesses = 2, fillup=False, - **kwargs + **kwargs, ): """ Draw polygons on the input image. @@ -196,9 +244,10 @@ def draw_polygons( polygons = prepare_polygons(polygons) colors = prepare_colors(colors, len(polygons)) thicknesses = prepare_thicknesses(thicknesses, len(polygons)) - for polygon, c, t in zip(polygons, colors, thicknesses): - draw_polygon(img, polygon, color=c, thickness=t, - fillup=fillup, **kwargs) + for polygon, c, t in zip(polygons, colors, thicknesses, strict=True): + draw_polygon( + img, polygon, color=c, thickness=t, fillup=fillup, **kwargs + ) return img @@ -208,8 +257,8 @@ def draw_text( location: _Point, color: _Color = (0, 0, 0), text_size: int = 12, - font_path: Union[str, Path] = None, - **kwargs + font_path: str | Path | None = None, + **kwargs, ) -> np.ndarray: """ Draw specified text on the given image at the provided location. @@ -236,24 +285,24 @@ def draw_text( np.ndarray: Image with the text drawn on it. 
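With strict=True on the zips above (fine under the py310 floor), per-item style lists that do not match the number of boxes fail loudly instead of silently truncating; scalars still broadcast through the prepare_* helpers. Sketch, assuming Box accepts plain XYXY 4-tuples as the _Box alias suggests:

    import numpy as np
    from capybara.vision.visualization.draw import draw_boxes

    canvas = np.zeros((128, 128, 3), dtype=np.uint8)
    boxes = [(10, 10, 50, 50), (60, 60, 110, 110)]  # XYXY pixel coordinates

    out = draw_boxes(canvas, boxes, colors=(0, 255, 0), thicknesses=2)  # broadcast
    out = draw_boxes(out, boxes, colors=[(255, 0, 0), (0, 0, 255)])     # per-box
    # draw_boxes(out, boxes, colors=[(255, 0, 0)])  # length mismatch -> ValueError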
""" img = prepare_img(img) - img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - - draw = ImageDraw.Draw(img) - font_path = DEFAULT_FONT_PATH if font_path is None else font_path - font = ImageFont.truetype(str(font_path), size=text_size) + color = prepare_color(color) + pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - _, top, _, bottom = font.getbbox(text) - _, offset = font.getmask2(text) - text_height = bottom - top + draw = ImageDraw.Draw(pil_img) + font = _load_font(font_path, size=text_size) - offset_y = int(0.5 * (font.size - text_height) - offset[1]) + offset_y = 0 + try: + _left, top, _right, _bottom = font.getbbox(text) + offset_y = -int(top) + except Exception: + offset_y = 0 loc = prepare_point(location) loc = (loc[0], loc[1] + offset_y) - kwargs.update({'fill': (color[2], color[1], color[0])}) + kwargs.update({"fill": (color[2], color[1], color[0])}) draw.text(loc, text, font=font, **kwargs) - img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - - return img + out = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + return out def draw_line( @@ -262,11 +311,11 @@ def draw_line( pt2: _Point, color: _Color = (0, 255, 0), thickness: _Thickness = 1, - style: str = 'dotted', + style: str = "dotted", gap: int = 20, inplace: bool = False, ): - ''' + """ Draw a line on the image. Args: @@ -294,24 +343,39 @@ def draw_line( Returns: np.ndarray: Image with the drawn line. - ''' + """ img = img.copy() if not inplace else img img = prepare_img(img) pt1 = prepare_point(pt1) pt2 = prepare_point(pt2) - dist = ((pt1[0] - pt2[0])**2 + (pt1[1] - pt2[1])**2)**.5 + color = prepare_color(color) + thickness = prepare_thickness(thickness) + gap = int(gap) + if gap <= 0: + raise ValueError("gap must be > 0.") + dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5 + if dist == 0: + cv2.circle( + img, + pt1, + radius=max(1, abs(thickness)), + color=color, + thickness=-1, + lineType=cv2.LINE_AA, + ) + return img pts = [] for i in np.arange(0, dist, gap): r = i / dist - x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5) - y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5) + x = int((pt1[0] * (1 - r) + pt2[0] * r) + 0.5) + y = int((pt1[1] * (1 - r) + pt2[1] * r) + 0.5) p = (x, y) pts.append(p) - if style == 'dotted': + if style == "dotted": for p in pts: - cv2.circle(img, p, thickness, color, -1) - elif style == 'line': + cv2.circle(img, p, radius=thickness, color=color, thickness=-1) + elif style == "line": s = pts[0] e = pts[0] i = 0 @@ -319,7 +383,7 @@ def draw_line( s = e e = p if i % 2 == 1: - cv2.line(img, s, e, color, thickness) + cv2.line(img, s, e, color=color, thickness=thickness) i += 1 else: raise ValueError(f"Unknown style: {style}") @@ -333,7 +397,7 @@ def draw_point( color: _Color = (0, 255, 0), thickness: _Thickness = -1, ) -> np.ndarray: - ''' + """ Draw a point on the image. Args: @@ -350,15 +414,22 @@ def draw_point( Returns: np.ndarray: Image with the drawn point. 
- ''' + """ is_gray_img = img.ndim == 2 img = prepare_img(img) point = prepare_point(point) color = prepare_color(color) + thickness = prepare_thickness(thickness) h, w = img.shape[:2] size = 1 + (np.sqrt(h * w) * 0.002 * scale).round().astype(int).item() - img = cv2.circle(img, point, radius=size, color=color, - lineType=cv2.LINE_AA, thickness=thickness) + img = cv2.circle( + img, + point, + radius=size, + color=color, + lineType=cv2.LINE_AA, + thickness=thickness, + ) img = img[..., 0] if is_gray_img else img return img @@ -366,11 +437,11 @@ def draw_point( def draw_points( img: np.ndarray, points: _Points, - scales: _Scales = 1., + scales: _Scales = 1.0, colors: _Colors = (0, 255, 0), thicknesses: _Thicknesses = -1, ) -> np.ndarray: - ''' + """ Draw multiple points on the image. Args: @@ -387,14 +458,14 @@ def draw_points( Returns: np.ndarray: Image with the drawn points. - ''' + """ img = prepare_img(img).copy() points = prepare_points(points) colors = prepare_colors(colors, len(points)) thicknesses = prepare_thicknesses(thicknesses, len(points)) scales = prepare_scales(scales, len(points)) - for p, s, c, t in zip(points, scales, colors, thicknesses): + for p, s, c, t in zip(points, scales, colors, thicknesses, strict=True): img = draw_point(img, p, s, c, t) return img @@ -403,10 +474,10 @@ def draw_points( def draw_keypoints( img: np.ndarray, keypoints: _Keypoints, - scale: _Scale = 1., - thickness: _Thickness = -1 + scale: _Scale = 1.0, + thickness: _Thickness = -1, ) -> np.ndarray: - ''' + """ Draw keypoints on the image. Args: @@ -421,7 +492,7 @@ def draw_keypoints( Returns: np.ndarray: Image with the drawn keypoints. - ''' + """ img = prepare_img(img) keypoints = prepare_keypoints(keypoints) @@ -433,7 +504,7 @@ def draw_keypoints( scale = prepare_scale(scale) thickness = prepare_thickness(thickness) points = keypoints.numpy()[..., :2] - for p, c in zip(points, colors): + for p, c in zip(points, colors, strict=True): img = draw_point(img, p, scale, c, thickness) return img @@ -441,10 +512,10 @@ def draw_keypoints( def draw_keypoints_list( img: np.ndarray, keypoints_list: _KeypointsList, - scales: _Scales = 1., - thicknesses: _Thicknesses = -1 + scales: _Scales = 1.0, + thicknesses: _Thicknesses = -1, ) -> np.ndarray: - ''' + """ Draw keypoints list on the image. Args: @@ -459,45 +530,60 @@ def draw_keypoints_list( Returns: np.ndarray: Image with the drawn keypoints list. 
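The same broadcast-or-match rule applies to points: scalar scales, colors, and thicknesses fan out to every point, while explicit lists must match len(points) now that the zips are strict. Sketch:

    import numpy as np
    from capybara.vision.visualization.draw import draw_points

    canvas = np.zeros((100, 100, 3), dtype=np.uint8)
    pts = [(20, 20), (50, 50), (80, 80)]

    out = draw_points(canvas, pts, scales=1.5, colors=(0, 255, 0))  # broadcast
    out = draw_points(out, pts, colors=[(255, 0, 0)] * len(pts))    # per-point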
- ''' + """ img = prepare_img(img) keypoints_list = prepare_keypoints_list(keypoints_list) scales = prepare_scales(scales, len(keypoints_list)) thicknesses = prepare_thicknesses(thicknesses, len(keypoints_list)) - for ps, s, t in zip(keypoints_list, scales, thicknesses): + for ps, s, t in zip(keypoints_list, scales, thicknesses, strict=True): img = draw_keypoints(img, ps, s, t) return img -def generate_colors_from_cmap(n: int, scheme: str) -> List[tuple]: - cm = matplotlib.cm.get_cmap(scheme) - return [cm(i/n)[:-1] for i in range(n)] +def generate_colors_from_cmap( + n: int, scheme: str +) -> list[tuple[float, float, float]]: + cm = matplotlib.colormaps.get_cmap(scheme) + rgb_colors = [] + for i in range(n): + rgba = cm(i / n) + rgb_colors.append((float(rgba[0]), float(rgba[1]), float(rgba[2]))) + return rgb_colors -def generate_triadic_colors(n: int) -> List[tuple]: +def generate_triadic_colors(n: int) -> list[tuple[float, float, float]]: base_hue = np.random.rand() - return [matplotlib.colors.hsv_to_rgb(((base_hue + i / 3.0) % 1, 1, 1)) for i in range(n)] + return [ + tuple(mpl_colors.hsv_to_rgb(((base_hue + i / 3.0) % 1, 1, 1))) + for i in range(n) + ] -def generate_analogous_colors(n: int) -> List[tuple]: +def generate_analogous_colors(n: int) -> list[tuple[float, float, float]]: base_hue = np.random.rand() step = 0.05 - return [matplotlib.colors.hsv_to_rgb(((base_hue + i * step) % 1, 1, 1)) for i in range(n)] + return [ + tuple(mpl_colors.hsv_to_rgb(((base_hue + i * step) % 1, 1, 1))) + for i in range(n) + ] -def generate_square_colors(n: int) -> List[tuple]: +def generate_square_colors(n: int) -> list[tuple[float, float, float]]: base_hue = np.random.rand() - return [matplotlib.colors.hsv_to_rgb(((base_hue + i / 4.0) % 1, 1, 1)) for i in range(n)] + return [ + tuple(mpl_colors.hsv_to_rgb(((base_hue + i / 4.0) % 1, 1, 1))) + for i in range(n) + ] -def generate_colors(n: int, scheme: str = 'hsv') -> List[tuple]: +def generate_colors(n: int, scheme: str = "hsv") -> list[tuple[int, int, int]]: """ Generates n different colors based on the chosen color scheme. """ color_generators = { - 'triadic': generate_triadic_colors, - 'analogous': generate_analogous_colors, - 'square': generate_square_colors + "triadic": generate_triadic_colors, + "analogous": generate_analogous_colors, + "square": generate_square_colors, } if scheme in color_generators: @@ -507,19 +593,27 @@ def generate_colors(n: int, scheme: str = 'hsv') -> List[tuple]: colors = generate_colors_from_cmap(n, scheme) except ValueError: print( - f"Color scheme '{scheme}' not recognized. Returning empty list.") + f"Color scheme '{scheme}' not recognized. Returning empty list." + ) colors = [] - return [tuple(int(c * 255) for c in color) for color in colors] + return [ + ( + int(color[0] * 255), + int(color[1] * 255), + int(color[2] * 255), + ) + for color in colors + ] def draw_mask( img: np.ndarray, mask: np.ndarray, colormap: int = cv2.COLORMAP_JET, - weight: Tuple[float, float] = (0.5, 0.5), + weight: tuple[float, float] = (0.5, 0.5), gamma: float = 0, - min_max_normalize: bool = False + min_max_normalize: bool = False, ) -> np.ndarray: """ Draw the mask on the image. 
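generate_colors resolves the three named schemes itself and otherwise treats scheme as a matplotlib colormap name via matplotlib.colormaps.get_cmap; unknown names print a notice and return an empty list rather than raising. Sketch of the contract:

    from capybara.vision.visualization.draw import generate_colors

    assert len(generate_colors(5, scheme="triadic")) == 5
    assert len(generate_colors(5, scheme="viridis")) == 5     # any matplotlib cmap
    assert generate_colors(5, scheme="no-such-scheme") == []  # prints a notice

    # Every entry is a 0-255 integer triple, ready for the OpenCV drawing calls.
    assert all(len(c) == 3 for c in generate_colors(4))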
@@ -543,15 +637,16 @@ """ # Ensure the input image has 3 channels - if img.ndim == 2: - img = np.stack([img] * 3, axis=-1) - else: - img = img.copy() # Avoid modifying the original image + img = np.stack([img] * 3, axis=-1) if img.ndim == 2 else img.copy() # Normalize mask if required if min_max_normalize: mask = mask.astype(np.float32) - mask = (mask - mask.min()) / (mask.max() - mask.min()) + denom = float(mask.max() - mask.min()) + if denom > 0: + mask = (mask - mask.min()) / denom + else: + mask = np.zeros_like(mask, dtype=np.float32) mask = (mask * 255).astype(np.uint8) else: mask = mask.astype(np.uint8) # Ensure mask is uint8 for color mapping @@ -579,7 +674,7 @@ def _vdc(n: int, base: int = 2) -> float: return vdc -@functools.lru_cache(maxsize=None) +@functools.cache def distinct_color(idx: int) -> _Color: """Generate a perceptually distinct BGR color for class *idx*. @@ -588,9 +683,9 @@ between close indices. 2. Saturation / Value: cycle every 20 / 10 ids to avoid hue-only clashes. """ - hue = _vdc(idx + 1) # (0,1) ─ long-distance hue jumps + hue = _vdc(idx + 1) # (0,1) ─ long-distance hue jumps sat_cycle = (0.65, 0.80, 0.50) # ┐ cyclic variation - val_cycle = (1.00, 0.90, 0.80) # ┘ brighten/darken + val_cycle = (1.00, 0.90, 0.80) # ┘ brighten/darken s = sat_cycle[(idx // 20) % len(sat_cycle)] v = val_cycle[(idx // 10) % len(val_cycle)] r, g, b = colorsys.hsv_to_rgb(hue, s, v) @@ -609,11 +704,11 @@ def draw_detection( img: np.ndarray, box: _Box, label: str, - score: Union[float, None] = None, + score: float | None = None, color: _Color | None = None, thickness: _Thickness | None = None, text_color: _Color = (255, 255, 255), - font_path: Union[str, Path] | None = None, + font_path: str | Path | None = None, text_size: int | None = None, box_alpha: float = 1.0, text_bg_alpha: float = 0.6, @@ -652,36 +747,49 @@ draw_color = distinct_color(idx) else: draw_color = color + draw_color = prepare_color(draw_color) if thickness is None: # proportional to image diagonal - diag = (canvas.shape[0]**2 + canvas.shape[1]**2)**0.5 + diag = (canvas.shape[0] ** 2 + canvas.shape[1] ** 2) ** 0.5 line_thickness = max(1, int(diag * 0.002 + 0.5)) else: line_thickness = thickness + line_thickness = prepare_thickness(line_thickness) # 3. Draw box (with optional transparency) if box_alpha >= 1.0: - cv2.rectangle(canvas, (x1, y1), (x2, y2), - draw_color, line_thickness, cv2.LINE_AA) + cv2.rectangle( + canvas, + (x1, y1), + (x2, y2), + color=draw_color, + thickness=line_thickness, + lineType=cv2.LINE_AA, + ) else: overlay = canvas.copy() - cv2.rectangle(overlay, (x1, y1), (x2, y2), - draw_color, line_thickness, cv2.LINE_AA) + cv2.rectangle( + overlay, + (x1, y1), + (x2, y2), + color=draw_color, + thickness=line_thickness, + lineType=cv2.LINE_AA, + ) canvas = cv2.addWeighted(overlay, box_alpha, canvas, 1 - box_alpha, 0) # 4. Prepare label text - text = f"{label} {score*100:.1f}%" if score is not None else label + text = f"{label} {score * 100:.1f}%" if score is not None else label # auto font size (~10% of box height, min 12) if text_size is None: text_size = max(12, int((y2 - y1) * 0.10)) # 5. 
Measure text size with PIL - font_file = DEFAULT_FONT_PATH if font_path is None else Path(font_path) pil_img = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)) draw = ImageDraw.Draw(pil_img, "RGBA") - font = ImageFont.truetype(str(font_file), size=text_size) + font = _load_font(font_path, size=text_size) bbox = draw.textbbox((0, 0), text, font=font) text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1] @@ -701,14 +809,14 @@ def draw_detection( y0, y1_ = sorted([bg_y0, bg_y1]) # 7. Draw semi-transparent background + text_color = prepare_color(text_color) draw.rectangle( [(x0, y0), (x1_, y1_)], fill=(*draw_color[::-1], int(text_bg_alpha * 255)), ) # 8. Draw text (PIL expecting RGB) - draw.text((x0 + pad, y0 + pad), text, - font=font, fill=text_color[::-1]) + draw.text((x0 + pad, y0 + pad), text, font=font, fill=text_color[::-1]) # 9. Convert back to BGR OpenCV annotated = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) @@ -718,13 +826,13 @@ def draw_detection( def draw_detections( img: np.ndarray, boxes: _Boxes, - labels: List[str], - scores: List[float] = None, - colors: _Colors = None, - thicknesses: _Thicknesses = None, + labels: list[str], + scores: list[float] | None = None, + colors: _Colors | None = None, + thicknesses: _Thicknesses | None = None, text_colors: _Colors = (255, 255, 255), - font_path: Union[str, Path] = None, - text_sizes: List[int] = None, + font_path: str | Path | None = None, + text_sizes: list[int] | None = None, box_alpha: float = 1.0, text_bg_alpha: float = 0.6, ) -> np.ndarray: @@ -756,18 +864,17 @@ def draw_detections( if scores is not None and len(scores) != len(labels): raise ValueError("Number of scores must match number of labels") - if colors is not None: - colors = prepare_colors(colors, len(boxes)) - else: - colors = [None] * len(boxes) - - if thicknesses is not None: - thicknesses = prepare_thicknesses(thicknesses, len(boxes)) - else: - thicknesses = [None] * len(boxes) - - if text_colors is not None: - text_colors = prepare_colors(text_colors, len(boxes)) + colors_list = ( + prepare_colors(colors, len(boxes)) + if colors is not None + else [None] * len(boxes) + ) + thicknesses_list = ( + prepare_thicknesses(thicknesses, len(boxes)) + if thicknesses is not None + else [None] * len(boxes) + ) + text_colors_list = prepare_colors(text_colors, len(boxes)) if text_sizes is not None: text_sizes = [int(size) for size in text_sizes] @@ -775,10 +882,9 @@ def draw_detections( for i, box in enumerate(boxes): label = labels[i] score = scores[i] if scores is not None else None - color = colors[i] if colors is not None else None - thickness = thicknesses[i] if thicknesses is not None else None - text_color = text_colors[i] if text_colors is not None else ( - 255, 255, 255) + color = colors_list[i] + thickness = thicknesses_list[i] + text_color = text_colors_list[i] text_size = text_sizes[i] if text_sizes is not None else None canvas = draw_detection( @@ -792,7 +898,7 @@ def draw_detections( font_path=font_path, text_size=text_size, box_alpha=box_alpha, - text_bg_alpha=text_bg_alpha + text_bg_alpha=text_bg_alpha, ) return canvas diff --git a/capybara/vision/visualization/utils.py b/capybara/vision/visualization/utils.py index 378acb4..81beda7 100644 --- a/capybara/vision/visualization/utils.py +++ b/capybara/vision/visualization/utils.py @@ -1,51 +1,61 @@ -from typing import Any, List, Optional, Tuple, Union +from collections.abc import Sequence +from typing import Any, TypeAlias, cast import numpy as np -from ...structures import (Box, Boxes, BoxMode, 
Keypoints, KeypointsList, - Polygon, Polygons) +from ...structures import ( + Box, + Boxes, + BoxMode, + Keypoints, + KeypointsList, + Polygon, + Polygons, +) from ...structures.boxes import _Box, _Boxes, _Number from ...structures.keypoints import _Keypoints, _KeypointsList from ...structures.polygons import _Polygon, _Polygons -_Color = Union[int, List[int], Tuple[int, int, int], np.ndarray] -_Colors = Union[_Color, List[_Color], np.ndarray] -_Point = Union[List[int], Tuple[int, int], Tuple[int, int], np.ndarray] -_Points = Union[List[_Point], np.ndarray] -_Thickness = Union[_Number, np.ndarray] -_Thicknesses = Union[List[_Thickness], _Thickness, np.ndarray] -_Scale = Union[_Number, np.ndarray] -_Scales = Union[List[_Scale], _Number, np.ndarray] +_Color: TypeAlias = int | Sequence[int] | tuple[int, int, int] | np.ndarray +_Colors: TypeAlias = _Color | Sequence[_Color] | np.ndarray +_Point: TypeAlias = Sequence[int] | tuple[int, int] | np.ndarray +_Points: TypeAlias = Sequence[_Point] | np.ndarray +_Thickness: TypeAlias = _Number +_Thicknesses: TypeAlias = _Thickness | Sequence[_Thickness] | np.ndarray +_Scale: TypeAlias = _Number +_Scales: TypeAlias = _Scale | Sequence[_Scale] | np.ndarray __all__ = [ - 'is_numpy_img', - 'prepare_color', - 'prepare_colors', - 'prepare_img', - 'prepare_box', - 'prepare_boxes', - 'prepare_keypoints', - 'prepare_keypoints_list', - 'prepare_polygon', - 'prepare_polygons', - 'prepare_point', - 'prepare_points', - 'prepare_thickness', - 'prepare_thicknesses', - 'prepare_scale', - 'prepare_scales', + "is_numpy_img", + "prepare_box", + "prepare_boxes", + "prepare_color", + "prepare_colors", + "prepare_img", + "prepare_keypoints", + "prepare_keypoints_list", + "prepare_point", + "prepare_points", + "prepare_polygon", + "prepare_polygons", + "prepare_scale", + "prepare_scales", + "prepare_thickness", + "prepare_thicknesses", ] def is_numpy_img(x: Any) -> bool: if not isinstance(x, np.ndarray): return False - return (x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3])) + return x.ndim == 2 or (x.ndim == 3 and x.shape[-1] in [1, 3]) -def prepare_color(color: _Color, ind: int = None) -> Tuple[int, int, int]: - ''' +def prepare_color( + color: _Color, ind: int | None = None +) -> tuple[int, int, int]: + """ This function prepares the color input for opencv. Args: @@ -57,23 +67,32 @@ def prepare_color(color: _Color, ind: int = None) -> Tuple[int, int, int]: Returns: Tuple[int, int, int]: a tuple of 3 integers. - ''' + """ cond1 = isinstance(color, int) - cond2 = isinstance(color, (tuple, list)) and len(color) == 3 and \ - isinstance(color[0], int) and \ - isinstance(color[1], int) and \ - isinstance(color[2], int) - cond3 = isinstance(color, np.ndarray) and color.ndim == 1 and len(color) == 3 + cond2 = ( + isinstance(color, (tuple, list)) + and len(color) == 3 + and isinstance(color[0], int) + and isinstance(color[1], int) + and isinstance(color[2], int) + ) + cond3 = ( + isinstance(color, np.ndarray) and color.ndim == 1 and len(color) == 3 + ) if not (cond1 or cond2 or cond3): - i = '' if ind is None else f's[{ind}]' - raise TypeError(f'The input color{i} = {color} is invalid. Should be {_Color}') - c = (color, ) * 3 if cond1 else color + i = "" if ind is None else f"s[{ind}]" + raise TypeError( + f"The input color{i} = {color} is invalid. 
Should be {_Color}" + ) + c = (color,) * 3 if cond1 else color + c = tuple(np.array(c, dtype=int).tolist()) return c -def prepare_colors(colors: _Colors, length: Union[int] = None) -> List[Tuple[int, int, int]]: - ''' +def prepare_colors( + colors: _Colors, length: int | None = None +) -> list[tuple[int, int, int]]: + """ This function prepares the colors input for opencv. Args: @@ -82,21 +101,23 @@ Returns: List[Tuple[int, int, int]]: a list of tuples of 3 integers. - ''' - if isinstance(colors, (list, np.ndarray)) and not isinstance(colors[0], _Number): + """ + try: + c = prepare_color(cast(Any, colors), 0) + except TypeError: + if not isinstance(colors, (list, tuple, np.ndarray)): + raise if length is not None and len(colors) != length: - raise ValueError(f'The length of colors = {len(colors)} is not equal to the length = {length}.') - cs = [] - for i, color in enumerate(colors): - cs.append(prepare_color(color, i)) - else: - c = prepare_color(colors, 0) - cs = [c] * length - return cs + raise ValueError( + f"The length of colors = {len(colors)} is not equal to the length = {length}." + ) from None + return [prepare_color(color, i) for i, color in enumerate(colors)] + repeat = 1 if length is None else length + return [c] * repeat -def prepare_img(img: np.ndarray, ind: Optional[int] = None) -> np.ndarray: - ''' +def prepare_img(img: np.ndarray, ind: int | None = None) -> np.ndarray: + """ This function prepares the image input for opencv. Args: @@ -108,25 +129,25 @@ Returns: np.ndarray: a valid numpy image for opencv. - ''' + """ if is_numpy_img(img): if img.ndim == 2: img = img[..., None].repeat(3, axis=-1) elif img.ndim == 3 and img.shape[-1] == 1: img = img.repeat(3, axis=-1) else: - i = '' if ind is None else f's[{ind}]' - raise ValueError(f'The input image{i} is not invalid numpy image.') + i = "" if ind is None else f"s[{ind}]" + raise ValueError(f"The input image{i} is not a valid numpy image.") return img -def prepare_box( box: _Box, - ind: Optional[int] = None, - src_mode: Union[str, BoxMode] = BoxMode.XYXY, - dst_mode: Union[str, BoxMode] = BoxMode.XYXY, + ind: int | None = None, + src_mode: str | BoxMode = BoxMode.XYXY, + dst_mode: str | BoxMode = BoxMode.XYXY, ) -> Box: - ''' + """ This function prepares the box input to XYXY format. Args: @@ -137,23 +158,27 @@ Returns: Box: a valid Box instance. - ''' + """ try: is_normalized = box.is_normalized if isinstance(box, Box) else False src_mode = box.box_mode if isinstance(box, Box) else src_mode - box = Box(box, box_mode=src_mode, is_normalized=is_normalized).convert(dst_mode) - except: - i = "" if ind is None else f'es[{ind}]' - raise ValueError(f"The input box{i} is invalid value = {box}. Should be {_Box}") + box = Box(box, box_mode=src_mode, is_normalized=is_normalized).convert( + dst_mode + ) + except Exception as exc: + i = "" if ind is None else f"es[{ind}]" + raise ValueError( + f"The input box{i} is invalid value = {box}. Should be {_Box}" + ) from exc return box -def prepare_boxes( boxes: _Boxes, - src_mode: Union[str, BoxMode] = BoxMode.XYXY, - dst_mode: Union[str, BoxMode] = BoxMode.XYXY, + src_mode: str | BoxMode = BoxMode.XYXY, + dst_mode: str | BoxMode = BoxMode.XYXY, ) -> Boxes: - ''' + """ This function prepares the boxes input to XYXY format. Args: @@ -163,16 +188,23 @@ Returns: Boxes: a valid Boxes instance. 
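The rewritten prepare_colors above first tries its input as a single color and only then falls back to element-wise parsing, so a bare triple broadcasts to the requested length, and length=None now yields a one-element list instead of the old TypeError. Sketch:

    from capybara.vision.visualization.utils import prepare_color, prepare_colors

    assert prepare_color(7) == (7, 7, 7)                        # int -> gray triple
    assert prepare_colors((0, 255, 0), 3) == [(0, 255, 0)] * 3  # broadcast
    assert prepare_colors([(255, 0, 0), (0, 0, 255)], 2) == [(255, 0, 0), (0, 0, 255)]
    # prepare_colors([(255, 0, 0)], 2)  # raises ValueError: length mismatch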
- ''' + """ if isinstance(boxes, Boxes): boxes = boxes.convert(dst_mode) else: - boxes = Boxes([prepare_box(box, i, src_mode, dst_mode) for i, box in enumerate(boxes)]) + boxes = Boxes( + [ + prepare_box(box, i, src_mode, dst_mode) + for i, box in enumerate(boxes) + ] + ) return boxes -def prepare_keypoints(keypoints: _Keypoints, ind: Optional[int] = None) -> Keypoints: - ''' +def prepare_keypoints( + keypoints: _Keypoints, ind: int | None = None +) -> Keypoints: + """ This function prepares the keypoints input. Args: @@ -181,17 +213,21 @@ def prepare_keypoints(keypoints: _Keypoints, ind: Optional[int] = None) -> Keypo Returns: Keypoints: a valid Keypoints instance. - ''' + """ + if isinstance(keypoints, Keypoints): + return keypoints try: keypoints = Keypoints(keypoints) - except: - i = "" if ind is None else f'_list[{ind}]' - raise TypeError(f"The input keypoints{i} is invalid value = {keypoints}. Should be {_Keypoints}") + except Exception as exc: + i = "" if ind is None else f"_list[{ind}]" + raise TypeError( + f"The input keypoints{i} is invalid value = {keypoints}. Should be {_Keypoints}" + ) from exc return keypoints def prepare_keypoints_list(keypoints_list: _KeypointsList) -> KeypointsList: - ''' + """ This function prepares the keypoints list input. Args: @@ -199,78 +235,100 @@ def prepare_keypoints_list(keypoints_list: _KeypointsList) -> KeypointsList: Returns: KeypointsList: a valid KeypointsList instance. - ''' - if not isinstance(keypoints_list, KeypointsList): - keypoints_list = KeypointsList([prepare_keypoints(keypoints, i) for i, keypoints in enumerate(keypoints_list)]) + """ + if isinstance(keypoints_list, KeypointsList): + return keypoints_list + keypoints_list = KeypointsList( + [ + prepare_keypoints(keypoints, i) + for i, keypoints in enumerate(keypoints_list) + ] + ) return keypoints_list -def prepare_polygon(polygon: _Polygon, ind: Union[int] = None) -> Polygon: +def prepare_polygon(polygon: _Polygon, ind: int | None = None) -> Polygon: + if isinstance(polygon, Polygon): + return polygon try: polygon = Polygon(polygon) - except: - i = "" if ind is None else f's[{ind}]' - raise TypeError(f"The input polygon{i} is invalid value = {polygon}. Should be {_Polygon}") + except Exception as exc: + i = "" if ind is None else f"s[{ind}]" + raise TypeError( + f"The input polygon{i} is invalid value = {polygon}. 
Should be {_Polygon}" + ) from exc return polygon def prepare_polygons(polygons: _Polygons) -> Polygons: - if not isinstance(polygons, Polygons): - polygons = Polygons([prepare_polygon(polygon, i) for i, polygon in enumerate(polygons)]) + if isinstance(polygons, Polygons): + return polygons + polygons = Polygons( + [prepare_polygon(polygon, i) for i, polygon in enumerate(polygons)] + ) return polygons -def prepare_point(point: _Point, ind: Optional[int] = None) -> tuple: - cond1 = isinstance(point, (tuple, list)) and len(point) == 2 and \ - isinstance(point[0], _Number) and \ - isinstance(point[1], _Number) - cond2 = isinstance(point, np.ndarray) and point.ndim == 1 and len(point) == 2 +def prepare_point(point: _Point, ind: int | None = None) -> tuple: + cond1 = ( + isinstance(point, (tuple, list)) + and len(point) == 2 + and isinstance(point[0], _Number) + and isinstance(point[1], _Number) + ) + cond2 = ( + isinstance(point, np.ndarray) and point.ndim == 1 and len(point) == 2 + ) if not (cond1 or cond2): - i = '' if ind is None else f's[{ind}]' - raise TypeError(f'The input point{i} is invalid.') + i = "" if ind is None else f"s[{ind}]" + raise TypeError(f"The input point{i} is invalid.") return tuple(np.array(point, dtype=int).tolist()) -def prepare_points(points: _Points) -> List[_Point]: +def prepare_points(points: _Points) -> list[_Point]: ps = [] for i, point in enumerate(points): ps.append(prepare_point(point, i)) return ps -def prepare_thickness(thickness: _Thickness, ind: int = None) -> int: +def prepare_thickness(thickness: _Thickness, ind: int | None = None) -> int: if not isinstance(thickness, _Number) or thickness < -1: - i = '' if ind is None else f's[{ind}]' - raise ValueError(f'The thickness[{i}] = {thickness} is not correct. \n') - thickness = np.array(thickness, dtype='int').tolist() - return thickness + i = "" if ind is None else f"s[{ind}]" + raise ValueError(f"The thickness[{i}] = {thickness} is not correct. \n") + value = np.array(thickness, dtype="int").tolist() + return int(value) -def prepare_thicknesses(thicknesses: _Thicknesses, length: Optional[int] = None) -> List[int]: - if isinstance(thicknesses, (list, np.ndarray)): +def prepare_thicknesses( + thicknesses: _Thicknesses, length: int | None = None +) -> list[int]: + if isinstance(thicknesses, _Number): + thickness = prepare_thickness(thicknesses, 0) + repeat = 1 if length is None else length + cs = [thickness] * repeat + else: cs = [] for i, thickness in enumerate(thicknesses): cs.append(prepare_thickness(thickness, i)) - else: - thickness = prepare_thickness(thicknesses, 0) - cs = [thickness] * length return cs -def prepare_scale(scale: _Scale, ind: int = None) -> float: +def prepare_scale(scale: _Scale, ind: int | None = None) -> float: if not isinstance(scale, _Number) or scale < -1: - i = '' if ind is None else f's[{ind}]' - raise ValueError(f'The scale[{i}] = {scale} is not correct. \n') - scale = np.array(scale, dtype=float).tolist() - return scale + i = "" if ind is None else f"s[{ind}]" + raise ValueError(f"The scale[{i}] = {scale} is not correct. 
\n") + value = np.array(scale, dtype=float).tolist() + return float(value) -def prepare_scales(scales: _Scales, length: Optional[int] = None) -> List[float]: - if isinstance(scales, (list, np.ndarray)): +def prepare_scales(scales: _Scales, length: int | None = None) -> list[float]: + if isinstance(scales, _Number): + scale = prepare_scale(scales, 0) + repeat = 1 if length is None else length + cs = [scale] * repeat + else: cs = [] for i, scale in enumerate(scales): cs.append(prepare_scale(scale, i)) - else: - scale = prepare_scale(scales, 0) - cs = [scale] * length return cs diff --git a/docker/Dockerfile b/docker/Dockerfile index 0ed3011..cea6e05 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -21,12 +21,12 @@ RUN ln -s /usr/bin/python3 /usr/bin/python && \ python -m pip install --no-cache-dir -U pip setuptools wheel COPY . /usr/local/Capybara -RUN cd /usr/local/Capybara && python setup.py bdist_wheel && \ - python -m pip install dist/*.whl && rm -rf /usr/local/Capybara +RUN python -m pip install --no-cache-dir /usr/local/Capybara && \ + rm -rf /usr/local/Capybara # Preload data RUN python -c "import capybara" WORKDIR /code -CMD ["bash"] \ No newline at end of file +CMD ["bash"] diff --git a/pyproject.toml b/pyproject.toml index 58ac44e..b196535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61", "wheel"] +requires = ["setuptools>=68", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -16,16 +16,16 @@ classifiers = [ "Intended Audience :: Science/Research", "Operating System :: OS Independent", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules" ] dependencies = [ "dacite", - "psutil", "requests", - "onnx", - "colored", "numpy", "pdf2image", "ujson", @@ -34,20 +34,52 @@ dependencies = [ "pybase64", "PyTurboJPEG", "dill", - "networkx", "natsort", - "flask", "shapely", "piexif", - "matplotlib", "opencv-python>=4.12.0.88", - "onnxslim", "beautifulsoup4", - "onnxruntime==1.22.0; platform_system == 'Darwin'", - "onnxruntime_gpu==1.22.0; platform_system == 'Linux'", "pillow-heif" ] +[project.optional-dependencies] +# Inference backends (opt-in; keep `pip install .` minimal) +onnxruntime = [ + "onnxruntime>=1.22.0,<2", + "onnx>=1.18.0", + "onnxslim>=0.1.0", +] +onnxruntime-gpu = [ + "onnxruntime-gpu>=1.22.0,<2", + "onnx>=1.18.0", + "onnxslim>=0.1.0", +] +openvino = [ + "openvino>=2024.0.0", +] +torchscript = [ + "torch>=2.0", +] +all = [ + "onnxruntime>=1.22.0,<2", + "onnx>=1.18.0", + "onnxslim>=0.1.0", + "openvino>=2024.0.0", + "torch>=2.0", +] + +# Feature extras (non-core runtime) +ipcam = [ + "flask>=2.0", +] +system = [ + "psutil", +] +visualization = [ + "matplotlib", + "pillow", +] + [project.urls] Homepage = "https://docsaid.org/en/docs/capybara/" Repository = "https://github.com/DocsaidLab/Capybara" @@ -58,7 +90,68 @@ include-package-data = true [tool.setuptools.packages.find] include = ["capybara*"] -exclude = ["demo", "tests", "docs", ".github", "docker", "wheelhouse"] +exclude = [ + "demo", + "tests", + "docs", + ".github", + "docker", + "wheelhouse" +] [tool.setuptools.dynamic] -version = { attr = "capybara.__version__" } \ No newline at end of file +version = { attr = "capybara.__version__" } + +[tool.pytest.ini_options] +pythonpath = 
[ + ".", +] +testpaths = [ + "tests", +] +addopts = "--ignore=tmp" + +[tool.ruff] +target-version = "py310" +line-length = 80 +extend-exclude = [ + ".git", + ".mypy_cache", + ".pytest_cache", + "__pycache__", + ".venv", + "venv", + "build", + "dist", + ".tox", + ".eggs", + "*.egg-info", +] + +[tool.ruff.lint] +select = [ + "E", + "F", + "W", + "B", + "UP", + "N", + "I", + "C4", + "SIM", + "RUF", +] +ignore = [ + "E203", + "E501", +] + +[tool.ruff.lint.isort] +known-first-party = ["capybara", "tests"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = true diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..cf282ef --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,7 @@ +{ + "include": ["capybara", "tests"], + "exclude": ["tmp", "capybara/cpuinfo.py"], + "pythonVersion": "3.10", + "typeCheckingMode": "basic", + "reportMissingTypeStubs": false +} diff --git a/setup.py b/setup.py deleted file mode 100644 index 6b40b52..0000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -from setuptools import setup - -if __name__ == '__main__': - setup() diff --git a/tests/onnxengine/test_engine_stubbed.py b/tests/onnxengine/test_engine_stubbed.py new file mode 100644 index 0000000..fae9953 --- /dev/null +++ b/tests/onnxengine/test_engine_stubbed.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import importlib +import sys +import types +from enum import Enum +from typing import Any + +import numpy as np +import pytest + + +@pytest.fixture() +def onnx_engine_module(monkeypatch): + class GraphOptimizationLevel(Enum): + ORT_DISABLE_ALL = "disable" + ORT_ENABLE_BASIC = "basic" + ORT_ENABLE_EXTENDED = "extended" + ORT_ENABLE_ALL = "all" + + class ExecutionMode(Enum): + ORT_SEQUENTIAL = "seq" + ORT_PARALLEL = "par" + + modelmeta_holder: dict[str, dict[str, object] | None] = {"map": None} + + class SessionOptions: + def __init__(self): + self.enable_profiling = False + self.graph_optimization_level = None + self.execution_mode = None + self.intra_op_num_threads = None + self.inter_op_num_threads = None + self.log_severity_level = None + self.config_entries: dict[str, str] = {} + + def add_session_config_entry(self, key: str, value: str): + self.config_entries[key] = value + + class RunOptions: + def __init__(self): + self.entries: dict[str, str] = {} + + def add_run_config_entry(self, key: str, value: str): + self.entries[key] = value + + class FakeIOBinding: + def __init__(self, session): + self.session = session + self.output_storage: dict[str, np.ndarray] = {} + self.bound_outputs: list[str] = [] + self.bound_inputs: dict[str, np.ndarray] = {} + + def clear_binding_inputs(self): + self.bound_inputs.clear() + + def clear_binding_outputs(self): + self.bound_outputs.clear() + self.output_storage = {} + + def bind_cpu_input(self, name: str, array: np.ndarray): + self.bound_inputs[name] = array + + def bind_output(self, name: str): + self.bound_outputs.append(name) + + def copy_outputs_to_cpu(self) -> list[np.ndarray]: + return [self.output_storage[name] for name in self.bound_outputs] + + class FakeSession: + def __init__( + self, model_path, sess_options, providers, provider_options + ): + self.model_path = model_path + self.sess_options = sess_options + self._providers = providers + self._provider_options = provider_options + self._inputs = [ + types.SimpleNamespace( + name="input", type="tensor(float)", shape=[1, "dyn", 4] + ) + ] + self._outputs = [ + 
types.SimpleNamespace( + name="output", type="tensor(float)", shape=[1, 4] + ) + ] + self._binding = FakeIOBinding(self) + self.last_run_options = None + + def get_providers(self): + return self._providers + + def get_provider_options(self): + return self._provider_options + + def get_modelmeta(self): + meta_map = modelmeta_holder["map"] + if meta_map is None: + raise RuntimeError("no metadata configured") + return types.SimpleNamespace(custom_metadata_map=meta_map) + + def get_outputs(self): + return self._outputs + + def get_inputs(self): + return self._inputs + + def io_binding(self): + return self._binding + + def run(self, output_names, feed, run_options=None): + self.last_run_options = run_options + return [np.ones((1, 4), dtype=np.float32) for _ in output_names] + + def run_with_iobinding(self, binding, run_options=None): + self.last_run_options = run_options + for idx, name in enumerate(binding.bound_outputs): + binding.output_storage[name] = np.full( + (1, 4), float(idx), dtype=np.float32 + ) + + fake_ort: Any = types.ModuleType("onnxruntime") + fake_ort.GraphOptimizationLevel = GraphOptimizationLevel + fake_ort.ExecutionMode = ExecutionMode + fake_ort.SessionOptions = SessionOptions + fake_ort.RunOptions = RunOptions + fake_ort.InferenceSession = FakeSession + fake_ort.get_available_providers = lambda: ["CPUExecutionProvider"] + + monkeypatch.setitem(sys.modules, "onnxruntime", fake_ort) + module = importlib.reload( + importlib.import_module("capybara.onnxengine.engine") + ) + module_any: Any = module + module_any._test_modelmeta_holder = modelmeta_holder + yield module + + +def test_onnx_engine_with_stubbed_runtime(onnx_engine_module): + engine_config_cls = onnx_engine_module.EngineConfig + onnx_engine_cls = onnx_engine_module.ONNXEngine + + config = engine_config_cls( + graph_optimization="basic", + execution_mode="parallel", + intra_op_num_threads=2, + inter_op_num_threads=1, + log_severity_level=0, + session_config_entries={"session.use": "1"}, + provider_options={ + "CUDAExecutionProvider": {"arena_extend_strategy": "manual"} + }, + fallback_to_cpu=True, + enable_io_binding=True, + run_config_entries={"run.tag": "demo"}, + enable_profiling=True, + ) + + engine = onnx_engine_cls( + model_path="model.onnx", + backend="cuda", + session_option={"custom": "value"}, + provider_option={"do_copy_in_default_stream": False}, + config=config, + ) + + feed = {"input": np.ones((1, 4), dtype=np.float64)} + outputs = engine(**feed) + assert set(outputs) == {"output"} + assert outputs["output"].dtype == np.float32 + + summary = engine.summary() + assert summary["model"] == "model.onnx" + assert summary["inputs"][0]["shape"][1] is None + + stats = engine.benchmark(feed, repeat=3, warmup=1) + assert stats["repeat"] == 3 + assert "mean" in stats["latency_ms"] + + +def test_onnx_engine_benchmark_validates_repeat_and_warmup(onnx_engine_module): + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + feed = {"input": np.ones((1, 4), dtype=np.float32)} + + with pytest.raises(ValueError, match="repeat must be >= 1"): + engine.benchmark(feed, repeat=0) + + with pytest.raises(ValueError, match="warmup must be >= 0"): + engine.benchmark(feed, warmup=-1) + + +def test_onnx_engine_accepts_wrapped_feed_dict(onnx_engine_module): + """__call__ supports passing a single mapping payload as a kwarg.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + + feed = {"input": 
np.ones((1, 4), dtype=np.float32)} + outputs = engine(payload=feed) + assert outputs["output"].shape == (1, 4) + + +def test_onnx_engine_run_method_uses_mapping(onnx_engine_module): + """run(feed) is a stable public API for inference.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + + outputs = engine.run({"input": np.ones((1, 4), dtype=np.float32)}) + assert outputs["output"].dtype == np.float32 + + +def test_onnx_engine_iobinding_without_run_options(onnx_engine_module): + """When run_config_entries is empty, IO binding runs without RunOptions.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine_config_cls = onnx_engine_module.EngineConfig + + config = engine_config_cls(enable_io_binding=True, run_config_entries=None) + engine = onnx_engine_cls( + model_path="model.onnx", backend="cpu", config=config + ) + + out = engine.run({"input": np.ones((1, 4), dtype=np.float32)}) + assert engine._session.last_run_options is None + assert np.all(out["output"] == 0.0) + + +def test_onnx_engine_session_option_overrides_attribute(onnx_engine_module): + """Known SessionOptions attributes are set directly instead of config entries.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + + engine = onnx_engine_cls( + model_path="model.onnx", + backend="cpu", + session_option={"enable_profiling": True}, + ) + + assert engine._session.sess_options.enable_profiling is True + + +def test_onnx_engine_accepts_enum_config_values(onnx_engine_module): + """Enum values pass through without string normalization.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine_config_cls = onnx_engine_module.EngineConfig + + config = engine_config_cls( + graph_optimization=onnx_engine_module.ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, + execution_mode=onnx_engine_module.ort.ExecutionMode.ORT_PARALLEL, + ) + + engine = onnx_engine_cls( + model_path="model.onnx", backend="cpu", config=config + ) + assert ( + engine._session.sess_options.graph_optimization_level + == onnx_engine_module.ort.GraphOptimizationLevel.ORT_ENABLE_BASIC + ) + assert ( + engine._session.sess_options.execution_mode + == onnx_engine_module.ort.ExecutionMode.ORT_PARALLEL + ) + + +def test_onnx_engine_extract_metadata_parses_json(onnx_engine_module): + """Custom metadata JSON payloads are parsed for easier downstream use.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + onnx_engine_module._test_modelmeta_holder["map"] = { + "author": '{"name": "bob"}', + "note": "plain", + "count": 3, + } + + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + + assert engine.metadata == { + "author": {"name": "bob"}, + "note": "plain", + "count": 3, + } + + +def test_onnx_engine_run_uses_run_options_without_iobinding(onnx_engine_module): + """run_config_entries should build RunOptions even when IO binding is disabled.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine_config_cls = onnx_engine_module.EngineConfig + + config = engine_config_cls( + enable_io_binding=False, run_config_entries={"k": "v"} + ) + engine = onnx_engine_cls( + model_path="model.onnx", + backend="cpu", + config=config, + ) + + out = engine.run({"input": np.ones((1, 4), dtype=np.float32)}) + assert out["output"].shape == (1, 4) + assert engine._session.last_run_options is not None + + +def test_onnx_engine_converts_outputs_via_toarray(onnx_engine_module): + """Some onnxruntime outputs expose a toarray() helper instead of ndarray.""" + onnx_engine_cls = onnx_engine_module.ONNXEngine + 
engine_config_cls = onnx_engine_module.EngineConfig + + config = engine_config_cls(enable_io_binding=False) + engine = onnx_engine_cls( + model_path="model.onnx", backend="cpu", config=config + ) + + class _FakeOrtValue: + def __init__(self, value: float) -> None: + self.value = float(value) + + def toarray(self) -> np.ndarray: + return np.full((1, 4), self.value, dtype=np.float32) + + engine._session.run = lambda *_args, **_kwargs: [_FakeOrtValue(7.0)] # type: ignore[method-assign] + + out = engine.run({"input": np.ones((1, 4), dtype=np.float32)}) + assert np.all(out["output"] == 7.0) + + +def test_onnx_engine_missing_required_input_raises_keyerror(onnx_engine_module): + onnx_engine_cls = onnx_engine_module.ONNXEngine + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + + with pytest.raises(KeyError, match="Missing required input"): + engine.run({}) + + +def test_onnx_engine_extract_metadata_returns_none_for_non_mapping( + onnx_engine_module, +): + onnx_engine_cls = onnx_engine_module.ONNXEngine + onnx_engine_module._test_modelmeta_holder["map"] = 123 + + engine = onnx_engine_cls(model_path="model.onnx", backend="cpu") + assert engine.metadata is None diff --git a/tests/onnxengine/test_utils_and_metadata.py b/tests/onnxengine/test_utils_and_metadata.py new file mode 100644 index 0000000..f684d1a --- /dev/null +++ b/tests/onnxengine/test_utils_and_metadata.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import types + +import onnx +import pytest +from onnx import TensorProto, helper + +from capybara.onnxengine import metadata, utils + + +@pytest.fixture() +def simple_onnx(tmp_path): + input_info = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, [1, 3] + ) + output_info = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [1, 3] + ) + node = helper.make_node("Identity", ["input"], ["output"]) + graph = helper.make_graph([node], "test-graph", [input_info], [output_info]) + model = helper.make_model(graph) + path = tmp_path / "simple.onnx" + onnx.save(model, path) + return path + + +def test_get_input_and_output_infos(simple_onnx): + inputs = utils.get_onnx_input_infos(simple_onnx) + outputs = utils.get_onnx_output_infos(simple_onnx) + assert inputs["input"]["shape"] == [1, 3] + assert outputs["output"]["shape"] == [1, 3] + assert str(inputs["input"]["dtype"]) == "float32" + + +def test_get_input_and_output_infos_accept_model_proto(simple_onnx): + model = onnx.load(simple_onnx) + inputs = utils.get_onnx_input_infos(model) + outputs = utils.get_onnx_output_infos(model) + assert inputs["input"]["shape"] == [1, 3] + assert outputs["output"]["shape"] == [1, 3] + + +def test_make_onnx_dynamic_axes_overrides_dims( + simple_onnx, monkeypatch, tmp_path +): + monkeypatch.setattr( + utils, + "onnxslim", + types.SimpleNamespace(simplify=lambda model: (model, True)), + ) + out_path = tmp_path / "dynamic.onnx" + utils.make_onnx_dynamic_axes( + model_fpath=simple_onnx, + output_fpath=out_path, + input_dims={"input": {0: "batch"}}, + output_dims={"output": {0: "batch"}}, + opset_version=18, + ) + model = onnx.load(out_path) + assert ( + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param == "batch" + ) + assert ( + model.graph.output[0].type.tensor_type.shape.dim[0].dim_param == "batch" + ) + + +def test_make_onnx_dynamic_axes_adds_default_opset_when_missing( + monkeypatch, tmp_path +): + input_info = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, [1, 3] + ) + output_info = helper.make_tensor_value_info( + "output", 
TensorProto.FLOAT, [1, 3] + ) + node = helper.make_node("Identity", ["input"], ["output"]) + graph = helper.make_graph( + [node], "no-default-opset", [input_info], [output_info] + ) + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid(domain="ai.onnx.ml", version=1)], + ) + + src = tmp_path / "src.onnx" + out_path = tmp_path / "dynamic.onnx" + onnx.save(model, src) + + monkeypatch.setattr( + utils, + "onnxslim", + types.SimpleNamespace(simplify=lambda model: (model, True)), + ) + utils.make_onnx_dynamic_axes( + model_fpath=src, + output_fpath=out_path, + input_dims={"input": {0: "batch"}}, + output_dims={"output": {0: "batch"}}, + opset_version=18, + ) + updated = onnx.load(out_path) + assert any( + opset.domain == "" and opset.version == 18 + for opset in updated.opset_import + ) + + +def test_make_onnx_dynamic_axes_uses_current_opset_when_version_is_none( + monkeypatch, tmp_path +): + input_info = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, [1, 3] + ) + output_info = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [1, 3] + ) + node = helper.make_node("Identity", ["input"], ["output"]) + graph = helper.make_graph( + [node], "no-default-opset", [input_info], [output_info] + ) + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid(domain="ai.onnx.ml", version=1)], + ) + + src = tmp_path / "src.onnx" + out_path = tmp_path / "dynamic.onnx" + onnx.save(model, src) + + monkeypatch.setattr(onnx.defs, "onnx_opset_version", lambda: 17) + monkeypatch.setattr( + utils, + "onnxslim", + types.SimpleNamespace(simplify=lambda model: (model, True)), + ) + utils.make_onnx_dynamic_axes( + model_fpath=src, + output_fpath=out_path, + input_dims={"input": {0: "batch"}}, + output_dims={"output": {0: "batch"}}, + opset_version=None, + ) + updated = onnx.load(out_path) + assert any( + opset.domain == "" and opset.version == 17 + for opset in updated.opset_import + ) + + +def test_make_onnx_dynamic_axes_accepts_simplify_returning_model( + simple_onnx, monkeypatch, tmp_path +): + monkeypatch.setattr( + utils, + "onnxslim", + types.SimpleNamespace(simplify=lambda model: model), + ) + out_path = tmp_path / "dynamic.onnx" + utils.make_onnx_dynamic_axes( + model_fpath=simple_onnx, + output_fpath=out_path, + input_dims={"input": {0: "batch"}}, + output_dims={"output": {0: "batch"}}, + opset_version=18, + ) + assert (tmp_path / "dynamic.onnx").exists() + + +def test_make_onnx_dynamic_axes_rejects_reshape_nodes(tmp_path): + input_info = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, [1, 3] + ) + shape_init = helper.make_tensor( + "shape", TensorProto.INT64, dims=[2], vals=[1, 3] + ) + output_info = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [1, 3] + ) + node = helper.make_node("Reshape", ["input", "shape"], ["output"]) + graph = helper.make_graph( + [node], + "reshape-graph", + [input_info], + [output_info], + initializer=[shape_init], + ) + model = helper.make_model(graph) + src = tmp_path / "reshape.onnx" + onnx.save(model, src) + + with pytest.raises(ValueError, match="Reshape cannot be trasformed"): + utils.make_onnx_dynamic_axes( + model_fpath=src, + output_fpath=tmp_path / "out.onnx", + input_dims={"input": {0: "batch"}}, + output_dims={"output": {0: "batch"}}, + opset_version=18, + ) + + +@pytest.fixture(autouse=True) +def fake_ort(monkeypatch): + class FakeSession: + def __init__(self, path, providers=None): + self._path = path + + def get_modelmeta(self): + model = onnx.load(self._path) + mapping = {p.key: 
p.value for p in model.metadata_props} + return types.SimpleNamespace(custom_metadata_map=mapping) + + monkeypatch.setattr( + metadata, "ort", types.SimpleNamespace(InferenceSession=FakeSession) + ) + + +def test_metadata_roundtrip(simple_onnx, tmp_path): + out_path = tmp_path / "meta.onnx" + metadata.write_metadata_into_onnx( + simple_onnx, out_path, author={"name": "angizero"} + ) + parsed = metadata.parse_metadata_from_onnx(out_path) + assert parsed["author"]["name"] == "angizero" + + raw = metadata.get_onnx_metadata(out_path) + assert "author" in raw + + +def test_parse_metadata_preserves_non_string_values(monkeypatch, simple_onnx): + class FakeSession: + def __init__(self, path, providers=None): + self._path = path + + def get_modelmeta(self): + return types.SimpleNamespace( + custom_metadata_map={"raw": 123, "json": "1"} + ) + + monkeypatch.setattr( + metadata, "ort", types.SimpleNamespace(InferenceSession=FakeSession) + ) + parsed = metadata.parse_metadata_from_onnx(simple_onnx) + assert parsed["raw"] == 123 + assert parsed["json"] == 1 diff --git a/tests/onnxruntime/test_engine.py b/tests/onnxruntime/test_engine.py deleted file mode 100644 index 67e114a..0000000 --- a/tests/onnxruntime/test_engine.py +++ /dev/null @@ -1,39 +0,0 @@ -import numpy as np -import pytest - -from capybara import Backend, ONNXEngine, get_curdir, get_recommended_backend - - -def test_ONNXEngine_CPU(): - model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx" - engine = ONNXEngine(model_path, backend="cpu") - for i in range(5): - xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")} - outs = engine(**xs) - if i: - assert not np.allclose(outs["output"], prev_outs["output"]) - prev_outs = outs - - -@pytest.mark.skipif(get_recommended_backend() != Backend.cuda, reason="Linux with GPU only") -def test_ONNXEngine_CUDA(): - model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx" - engine = ONNXEngine(model_path, backend=get_recommended_backend()) - for i in range(5): - xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")} - outs = engine(**xs) - if i: - assert not np.allclose(outs["output"], prev_outs["output"]) - prev_outs = outs - - -@pytest.mark.skipif(get_recommended_backend() != "Darwin", reason="Mac only") -def test_ONNXEngine_COREML(): - model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx" - engine = ONNXEngine(model_path, backend="coreml") - for i in range(5): - xs = {"input": np.random.randn(32, 3, 224, 224).astype("float32")} - outs = engine(**xs) - if i: - assert not np.allclose(outs["output"], prev_outs["output"]) - prev_outs = outs diff --git a/tests/onnxruntime/test_engine_io_binding.py b/tests/onnxruntime/test_engine_io_binding.py deleted file mode 100644 index 796d8a3..0000000 --- a/tests/onnxruntime/test_engine_io_binding.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -import pytest - -from capybara import Backend, ONNXEngineIOBinding, get_curdir, get_recommended_backend - - -@pytest.mark.skipif(get_recommended_backend() != Backend.cuda, reason="Linux with GPU only") -def test_ONNXEngineIOBinding_CUDAonly(): - model_path = get_curdir(__file__).parent / "resources/model_dynamic-axes.onnx" - input_initializer = {"input": np.random.randn(32, 3, 448, 448).astype("float32")} - engine = ONNXEngineIOBinding(model_path, input_initializer) - for i in range(5): - xs = {"input": np.random.randn(32, 3, 448, 448).astype("float32")} - outs = engine(**xs) - if i: - assert not np.allclose(outs["output"], 
prev_outs["output"]) - prev_outs = outs diff --git a/tests/onnxruntime/test_tools.py b/tests/onnxruntime/test_tools.py deleted file mode 100644 index 23f6016..0000000 --- a/tests/onnxruntime/test_tools.py +++ /dev/null @@ -1,39 +0,0 @@ -import numpy as np - -from capybara import ( - ONNXEngine, - get_onnx_input_infos, - get_onnx_output_infos, - make_onnx_dynamic_axes, -) - - -def test_get_onnx_input_infos(): - model_path = "tests/resources/model_shape=224x224.onnx" - input_infos = get_onnx_input_infos(model_path) - assert input_infos == {"input": {"shape": [1, 3, 224, 224], "dtype": "float32"}} - - -def test_get_onnx_output_infos(): - model_path = "tests/resources/model_shape=224x224.onnx" - output_infos = get_onnx_output_infos(model_path) - assert output_infos == {"output": {"shape": [1, 64, 56, 56], "dtype": "float32"}} - - -def test_make_onnx_dynamic_axes(): - model_path = "tests/resources/model_shape=224x224.onnx" - input_infos = get_onnx_input_infos(model_path) - output_infos = get_onnx_output_infos(model_path) - input_dims = {k: {0: "b", 2: "h", 3: "w"} for k in input_infos.keys()} - output_dims = {k: {0: "b", 2: "h", 3: "w"} for k in output_infos.keys()} - new_model_path = "/tmp/model_dynamic-axes.onnx" - make_onnx_dynamic_axes( - model_path, - new_model_path, - input_dims=input_dims, - output_dims=output_dims, - ) - xs = {"input": np.random.randn(32, 3, 320, 320).astype("float32")} - engine = ONNXEngine(new_model_path, session_option={"log_severity_level": 1}, backend="cpu") - outs = engine(**xs) - assert outs["output"].shape == (32, 64, 80, 80) diff --git a/tests/openvinoengine/test_openvino_engine_stubbed.py b/tests/openvinoengine/test_openvino_engine_stubbed.py new file mode 100644 index 0000000..2ed1d5b --- /dev/null +++ b/tests/openvinoengine/test_openvino_engine_stubbed.py @@ -0,0 +1,601 @@ +from __future__ import annotations + +import importlib +import queue +import sys +import threading +import types +import warnings +from typing import Any + +import numpy as np +import pytest + + +@pytest.fixture() +def openvino_engine_module(monkeypatch): + class FakeType: + f32 = object() + i64 = object() + boolean = object() + + class FakeDim: + def __init__(self, value): + self.value = value + self.is_static = value is not None + + def get_length(self): + if self.value is None: + raise ValueError("dynamic dim") + return self.value + + class FakePort: + def __init__(self, name, element_type, shape): + self._name = name + self._type = element_type + self._shape = shape + + def get_any_name(self): + return self._name + + def get_element_type(self): + return self._type + + def get_partial_shape(self): + dims = [] + for value in self._shape: + if isinstance(value, int): + dims.append(FakeDim(value)) + else: + dims.append(FakeDim(None)) + return dims + + class FakeTensor: + def __init__(self, value): + self.data = np.full((1, 2), value, dtype=np.float32) + + class FakeInferRequest: + def __init__(self, outputs): + self._outputs = outputs + + def infer(self, feed): + self._last_feed = feed + + def get_tensor(self, port): + idx = self._outputs.index(port) + return FakeTensor(idx + 1) + + class FakeCompiledModel: + def __init__(self, input_shapes=None): + shape = (1, None, 3) + if input_shapes: + shape = input_shapes.get("input", shape) + self.inputs = [ + FakePort("input", FakeType.f32, shape), + ] + self.outputs = [ + FakePort("output", FakeType.f32, (1, 2)), + ] + + def create_infer_request(self): + return FakeInferRequest(self.outputs) + + class FakeAsyncInferQueue: + def __init__(self, 
compiled_model, jobs): + self._compiled_model = compiled_model + self._jobs = int(jobs) + self._callback = None + + def set_callback(self, callback): + self._callback = callback + + def start_async(self, feed, userdata): + req = self._compiled_model.create_infer_request() + req.infer(feed) + if self._callback is not None: + self._callback(req, userdata) + + def wait_all(self): + return None + + class FakeCore: + def __init__(self): + self._properties = {} + self._last_reshape = None + + def set_property(self, props): + self._properties.update(props) + + def read_model(self, path): + class FakeModel: + def __init__(self, p): + self.path = p + self.reshape_map = {} + + def reshape(self, mapping): + self.reshape_map = mapping + + return FakeModel(path) + + def compile_model(self, model, device, properties): + self._properties["device"] = device + self._properties.update(properties) + self._last_reshape = getattr(model, "reshape_map", {}) + return FakeCompiledModel(self._last_reshape) + + fake_runtime: Any = types.ModuleType("openvino.runtime") + fake_runtime.Type = FakeType + fake_runtime.Core = FakeCore + fake_runtime.AsyncInferQueue = FakeAsyncInferQueue + fake_pkg: Any = types.ModuleType("openvino") + fake_pkg.runtime = fake_runtime + + monkeypatch.setitem(sys.modules, "openvino", fake_pkg) + monkeypatch.setitem(sys.modules, "openvino.runtime", fake_runtime) + module = importlib.reload( + importlib.import_module("capybara.openvinoengine.engine") + ) + yield module + + +def test_openvino_engine_runs_with_stub(openvino_engine_module, tmp_path): + config_cls = openvino_engine_module.OpenVINOConfig + engine_cls = openvino_engine_module.OpenVINOEngine + device_enum = openvino_engine_module.OpenVINODevice + + cfg = config_cls( + compile_properties={"PERF_HINT": "THROUGHPUT"}, + core_properties={"LOG_LEVEL": "ERROR"}, + cache_dir=tmp_path / "cache", + num_streams=2, + num_threads=4, + ) + + engine = engine_cls( + model_path="model.xml", + device=device_enum.cpu, + config=cfg, + ) + + feed = {"input": np.ones((1, 3), dtype=np.float32)} + outputs = engine(**feed) + assert outputs["output"].shape == (1, 2) + + summary = engine.summary() + assert summary["device"] == "CPU" + assert summary["inputs"][0]["shape"][1] is None + + +def test_openvino_engine_accepts_input_shapes(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + device_enum = openvino_engine_module.OpenVINODevice + + engine = engine_cls( + model_path="model.xml", + device=device_enum.npu, + input_shapes={"input": (2, 3, 5)}, + ) + + summary = engine.summary() + assert summary["inputs"][0]["shape"] == [2, 3, 5] + + +def test_openvino_engine_async_queue(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + with engine.create_async_queue(num_requests=2) as q: + fut = q.submit({"input": np.ones((1, 3), dtype=np.float32)}) + outputs = fut.result(timeout=1) + assert outputs["output"].shape == (1, 2) + + +def test_openvino_engine_async_queue_auto_requests(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + with engine.create_async_queue(num_requests=0) as q: + fut = q.submit({"input": np.ones((1, 3), dtype=np.float32)}) + outputs = fut.result(timeout=1) + assert outputs["output"].shape == (1, 2) + + +def test_openvino_engine_benchmark_async(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = 
engine_cls(model_path="model.xml", device="CPU") + stats = engine.benchmark_async( + {"input": np.ones((1, 3), dtype=np.float32)}, + repeat=10, + warmup=1, + num_requests=2, + ) + assert stats["num_requests"] == 2 + + +def test_openvino_engine_async_queue_auto_request_id(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + completion = queue.Queue() + with engine.create_async_queue(num_requests=2) as q: + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + completion_queue=completion, + ) + outputs = fut.result(timeout=1) + req_id, event_outputs = completion.get(timeout=1) + + assert getattr(fut, "request_id", None) is not None + assert req_id == fut.request_id + assert event_outputs is not outputs + + +def test_openvino_engine_async_queue_preserves_request_id( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + completion = queue.Queue() + with engine.create_async_queue(num_requests=2) as q: + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-123", + completion_queue=completion, + ) + outputs = fut.result(timeout=1) + req_id, event_outputs = completion.get(timeout=1) + + assert fut.request_id == "req-123" + assert req_id == "req-123" + assert event_outputs is not outputs + + +def test_openvino_engine_async_queue_completion_queue(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + completion = queue.Queue() + with engine.create_async_queue(num_requests=2) as q: + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-1", + completion_queue=completion, + ) + outputs = fut.result(timeout=1) + req_id, event_outputs = completion.get(timeout=1) + + assert req_id == "req-1" + assert event_outputs is not outputs + outputs["mutated"] = np.zeros((1,), dtype=np.float32) + assert "mutated" not in event_outputs + assert outputs["output"].shape == (1, 2) + assert event_outputs["output"].shape == (1, 2) + + +def test_openvino_engine_async_queue_completion_queue_full_does_not_block( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + completion = queue.Queue(maxsize=1) + completion.put(("sentinel", {})) + + result: dict[str, object] = {} + + def submit_request(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + with engine.create_async_queue(num_requests=2) as q: + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-1", + completion_queue=completion, + ) + result["outputs"] = fut.result(timeout=1) + + thread = threading.Thread(target=submit_request, daemon=True) + thread.start() + thread.join(timeout=0.5) + + if thread.is_alive(): + completion.get_nowait() + thread.join(timeout=0.5) + pytest.fail("q.submit() blocked when completion_queue was full") + + assert completion.qsize() == 1 + assert completion.get_nowait()[0] == "sentinel" + assert isinstance(result.get("outputs"), dict) + + +def test_openvino_engine_async_queue_completion_queue_full_emits_warning( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + completion = queue.Queue(maxsize=1) + completion.put(("sentinel", {})) + + with 
engine.create_async_queue(num_requests=2) as q: + with pytest.warns(RuntimeWarning, match="completion_queue is full"): + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-1", + completion_queue=completion, + ) + outputs = fut.result(timeout=1) + + assert outputs["output"].shape == (1, 2) + assert completion.qsize() == 1 + assert completion.get_nowait()[0] == "sentinel" + + +def test_openvino_device_from_any_accepts_strings_and_rejects_unknown( + openvino_engine_module, +): + device_enum = openvino_engine_module.OpenVINODevice + + assert device_enum.from_any(device_enum.cpu) is device_enum.cpu + assert device_enum.from_any("cpu") is device_enum.cpu + + with pytest.raises(ValueError, match="Unsupported OpenVINO device"): + device_enum.from_any("tpu") + + +def test_openvino_engine_accepts_wrapped_feed_dict(openvino_engine_module): + """__call__ supports passing a single mapping payload as a kwarg.""" + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + feed = {"input": np.ones((1, 3), dtype=np.float32)} + outputs = engine(payload=feed) + assert outputs["output"].shape == (1, 2) + + +def test_openvino_engine_run_replaces_failed_request(openvino_engine_module): + """Infer failures should not leave a broken request in the pool.""" + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + feed = {"input": np.ones((1, 3), dtype=np.float32)} + + failing_request = engine._request_pool.get_nowait() + + def boom(_prepared): + raise RuntimeError("infer boom") + + failing_request.infer = boom # type: ignore[assignment] + engine._request_pool.put(failing_request) + + with pytest.raises(RuntimeError, match="infer boom"): + engine.run(feed) + + replacement = engine._request_pool.get_nowait() + assert replacement is not failing_request + engine._request_pool.put(replacement) + + outputs = engine.run(feed) + assert outputs["output"].shape == (1, 2) + + +def test_openvino_engine_benchmark_sync(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls(model_path="model.xml", device="CPU") + stats = engine.benchmark( + {"input": np.ones((1, 3), dtype=np.float32)}, + repeat=3, + warmup=1, + ) + + assert stats["repeat"] == 3 + assert "latency_ms" in stats + + +def test_openvino_engine_benchmark_validates_repeat_and_warmup( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + feed = {"input": np.ones((1, 3), dtype=np.float32)} + + with pytest.raises(ValueError, match="repeat must be >= 1"): + engine.benchmark(feed, repeat=0) + + with pytest.raises(ValueError, match="warmup must be >= 0"): + engine.benchmark(feed, warmup=-1) + + with pytest.raises(ValueError, match="repeat must be >= 1"): + engine.benchmark_async(feed, repeat=0) + + with pytest.raises(ValueError, match="warmup must be >= 0"): + engine.benchmark_async(feed, warmup=-1) + + +def test_openvino_engine_num_requests_validation(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + config_cls = openvino_engine_module.OpenVINOConfig + + engine = engine_cls( + model_path="model.xml", + device="CPU", + config=config_cls(num_requests=0), + ) + assert engine._request_pool.maxsize == 1 + + with pytest.raises(ValueError, match="num_requests must be >= 0"): + engine_cls( + model_path="model.xml", + device="CPU", + config=config_cls(num_requests=-1), + ) 
+ + with pytest.raises(ValueError, match="num_requests must be >= 0"): + engine.create_async_queue(num_requests=-1) + + +def test_openvino_engine_input_shapes_reject_none_dims(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + + with pytest.raises(ValueError, match="must use concrete dimensions"): + engine_cls( + model_path="model.xml", + device="CPU", + input_shapes={"input": (1, None, 3)}, + ) + + +def test_openvino_engine_requires_model_reshape_when_input_shapes( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + + class NoReshapeCore: + def read_model(self, path): + return object() + + def compile_model(self, model, device, properties): # pragma: no cover + raise AssertionError( + "compile_model should not run when reshape fails" + ) + + with pytest.raises(RuntimeError, match="does not support reshape"): + engine_cls( + model_path="model.xml", + device="CPU", + core=NoReshapeCore(), + input_shapes={"input": (1, 3, 5)}, + ) + + +def test_openvino_engine_prepare_feed_validates_and_casts( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + with pytest.raises(KeyError, match="Missing required input"): + engine.run({"missing": np.ones((1, 3), dtype=np.float32)}) + + feed = {"input": np.ones((1, 3), dtype=np.float64)} + engine.run(feed) + req = engine._request_pool.get_nowait() + assert req._last_feed["input"].dtype == np.float32 # type: ignore[attr-defined] + engine._request_pool.put(req) + + +def test_openvino_partial_shape_to_tuple_accepts_int_dims( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + assert engine._partial_shape_to_tuple([1, 2, "dyn"]) == (1, 2, None) + + +def test_openvino_build_type_map_handles_missing_type(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + assert engine._build_type_map(types.SimpleNamespace()) == {} + + +def test_openvino_async_queue_requires_asyncinferqueue(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + engine._ov = types.SimpleNamespace() # no AsyncInferQueue + + with pytest.raises(RuntimeError, match="AsyncInferQueue"): + engine.create_async_queue(num_requests=2) + + +def test_openvino_async_queue_submit_after_close_raises(openvino_engine_module): + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + q = engine.create_async_queue(num_requests=1) + q.close() + q.close() # idempotent + with pytest.raises(RuntimeError, match="Async queue is closed"): + q.submit({"input": np.ones((1, 3), dtype=np.float32)}) + + +def test_openvino_async_queue_completion_queue_put_fallback( + openvino_engine_module, +): + """Fallback to completion_queue.put(block=False) when put_nowait is absent.""" + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + class PutOnlyQueue: + def __init__(self): + self.items = [] + + def put(self, item, *, block=False): + self.items.append((item, block)) + + completion = PutOnlyQueue() + with engine.create_async_queue(num_requests=1) as q: + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-1", + completion_queue=completion, + ) + fut.result(timeout=1) + + 
assert completion.items[0][0][0] == "req-1" + assert completion.items[0][1] is False + + +def test_openvino_async_queue_completion_queue_signature_mismatch_warns( + openvino_engine_module, +): + """Completion queues without non-blocking put semantics should warn once.""" + engine_cls = openvino_engine_module.OpenVINOEngine + engine = engine_cls(model_path="model.xml", device="CPU") + + class BadQueue: + def put(self, item): + self.item = item + + completion = BadQueue() + with engine.create_async_queue(num_requests=1) as q: + with pytest.warns( + RuntimeWarning, + match="does not support non-blocking put", + ): + fut = q.submit( + {"input": np.ones((1, 3), dtype=np.float32)}, + request_id="req-1", + completion_queue=completion, + ) + fut.result(timeout=1) + + +def test_openvino_engine_request_pool_respects_num_requests( + openvino_engine_module, +): + config_cls = openvino_engine_module.OpenVINOConfig + engine_cls = openvino_engine_module.OpenVINOEngine + + cfg = config_cls(num_requests=3) + engine = engine_cls(model_path="model.xml", device="CPU", config=cfg) + + assert engine._request_pool.maxsize == 3 + + +def test_openvino_engine_input_shapes_accepts_non_sequence_values( + openvino_engine_module, +): + engine_cls = openvino_engine_module.OpenVINOEngine + + engine = engine_cls( + model_path="model.xml", + device="CPU", + input_shapes={"input": {"dim": 3}}, + ) + + assert engine._core._last_reshape["input"] == {"dim": 3} diff --git a/tests/resources/make_onnx.py b/tests/resources/make_onnx.py index 52e6dac..cb4a5a4 100644 --- a/tests/resources/make_onnx.py +++ b/tests/resources/make_onnx.py @@ -4,6 +4,7 @@ try: import torch + model = torch.nn.Sequential( torch.nn.Conv2d(3, 32, 3, 2, 1), torch.nn.BatchNorm2d(32), @@ -12,20 +13,24 @@ torch.nn.BatchNorm2d(64), torch.nn.ReLU(), ) + dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export( model, - torch.randn(1, 3, 224, 224), - cur_folder/'model_shape=224x224.onnx', - input_names=['input'], - output_names=['output'], + (dummy_input,), + cur_folder / "model_shape=224x224.onnx", + input_names=["input"], + output_names=["output"], ) torch.onnx.export( model, - torch.randn(1, 3, 224, 224), - cur_folder/'model_dynamic-axes.onnx', - input_names=['input'], - output_names=['output'], - dynamic_axes={'input': {0: 'b', 2: 'h', 3: 'w'}, 'output': {0: 'b', 2: 'h', 3: 'w'}}, + (dummy_input,), + cur_folder / "model_dynamic-axes.onnx", + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {0: "b", 2: "h", 3: "w"}, + "output": {0: "b", 2: "h", 3: "w"}, + }, ) except ImportError: print("PyTorch is not installed.") diff --git a/tests/structures/test_boxes.py b/tests/structures/test_boxes.py index 438cc2a..6c188af 100644 --- a/tests/structures/test_boxes.py +++ b/tests/structures/test_boxes.py @@ -1,35 +1,39 @@ +from typing import Any + import numpy as np import pytest from capybara import Box, Boxes, BoxMode -def test_invalid_input_type(): +def test_box_invalid_input_type(): with pytest.raises(TypeError): - Box("invalid_input") + invalid_input: Any = "invalid_input" + Box(invalid_input) -def test_invalid_input_shape(): +def test_box_invalid_input_shape(): with pytest.raises(TypeError): - Box([1, 2, 3, 4, 5]) # 長度為5而非4,不符合預期的box格式 + Box([1, 2, 3, 4, 5]) # 長度為5而非4, 不符合預期的box格式 -def test_normalized_array(): +def test_box_accepts_is_normalized_flag(): array = np.array([0.1, 0.2, 0.3, 0.4]) box = Box(array, is_normalized=True) assert box.is_normalized is True -def test_invalid_box_mode(): +def test_box_invalid_box_mode(): with 
pytest.raises(KeyError): array = np.array([1, 2, 3, 4]) Box(array, box_mode="invalid_mode") -def test_array_conversion(): +def test_box_array_conversion(): array = [1, 2, 3, 4] box = Box(array) - assert np.allclose(box._array, np.array(array, dtype='float32')) + assert np.allclose(box._array, np.array(array, dtype="float32")) + # Test Box initialization @@ -39,13 +43,17 @@ def test_box_init(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) assert isinstance(box, Box), "Initialization of Box failed." + # Test conversion of Box format def test_box_convert(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) converted_box = box.convert(BoxMode.XYWH) - assert np.allclose(converted_box.numpy(), np.array([50, 50, 50, 50])), "Box conversion failed." + assert np.allclose(converted_box.numpy(), np.array([50, 50, 50, 50])), ( + "Box conversion failed." + ) + # Test calculation of area of Box @@ -54,13 +62,17 @@ def test_box_area(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) assert box.area == 2500, "Box area calculation failed." + # Test Box.copy() method def test_box_copy(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) copied_box = box.copy() - assert copied_box is not box and (copied_box._array == box._array).all(), "Box copy failed." + assert copied_box is not box and (copied_box._array == box._array).all(), ( + "Box copy failed." + ) + # Test Box conversion to numpy array @@ -69,7 +81,9 @@ def test_box_numpy(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) arr = box.numpy() assert isinstance(arr, np.ndarray) and np.allclose( - arr, np.array([50, 50, 100, 100])), "Box to numpy conversion failed." + arr, np.array([50, 50, 100, 100]) + ), "Box to numpy conversion failed." + # Test Box normalization @@ -77,7 +91,10 @@ def test_box_numpy(): def test_box_normalize(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) normalized_box = box.normalize(200, 200) - assert np.allclose(normalized_box.numpy(), np.array([0.25, 0.25, 0.5, 0.5])), "Box normalization failed." + assert np.allclose( + normalized_box.numpy(), np.array([0.25, 0.25, 0.5, 0.5]) + ), "Box normalization failed." + # Test Box denormalization @@ -85,7 +102,10 @@ def test_box_normalize(): def test_box_denormalize(): box = Box((0.25, 0.25, 0.5, 0.5), box_mode=BoxMode.XYXY, is_normalized=True) denormalized_box = box.denormalize(200, 200) - assert np.allclose(denormalized_box.numpy(), np.array([50, 50, 100, 100])), "Box denormalization failed." + assert np.allclose( + denormalized_box.numpy(), np.array([50, 50, 100, 100]) + ), "Box denormalization failed." + # Test Box clipping @@ -93,7 +113,23 @@ def test_box_denormalize(): def test_box_clip(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) clipped_box = box.clip(60, 60, 90, 90) - assert np.allclose(clipped_box.numpy(), np.array([60, 60, 90, 90])), "Box clipping failed." + assert np.allclose(clipped_box.numpy(), np.array([60, 60, 90, 90])), ( + "Box clipping failed." 
+ ) + + +def test_box_clip_preserves_box_mode_for_xywh_inputs(): + box = Box((10, 20, 5, 5), box_mode=BoxMode.XYWH) + clipped_box = box.clip(0, 0, 12, 30) + assert clipped_box.box_mode == BoxMode.XYWH + assert np.allclose(clipped_box.numpy(), np.array([10, 20, 2, 5])) + + +def test_box_convert_preserves_is_normalized_flag(): + box = Box((0.1, 0.2, 0.3, 0.4), box_mode=BoxMode.XYXY, is_normalized=True) + converted = box.convert(BoxMode.XYWH) + assert converted.is_normalized is True + # Test Box shifting @@ -101,7 +137,10 @@ def test_box_clip(): def test_box_shift(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) shifted_box = box.shift(10, -10) - assert np.allclose(shifted_box.numpy(), np.array([60, 40, 110, 90])), "Box shifting failed." + assert np.allclose(shifted_box.numpy(), np.array([60, 40, 110, 90])), ( + "Box shifting failed." + ) + # Test Box scaling @@ -109,7 +148,10 @@ def test_box_shift(): def test_box_scale(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) scaled_box = box.scale(dsize=(20, 0)) - assert np.allclose(scaled_box.numpy(), np.array([40, 50, 110, 100])), "Box scaling failed." + assert np.allclose(scaled_box.numpy(), np.array([40, 50, 110, 100])), ( + "Box scaling failed." + ) + # Test Box to_list @@ -118,127 +160,187 @@ def test_box_to_list(): box = Box((50, 50, 50, 50), box_mode=BoxMode.XYXY) assert box.to_list() == [50, 50, 50, 50], "Boxes tolist failed." + # Test Box to_polygon def test_box_to_polygon(): box = Box((50, 50, 100, 100), box_mode=BoxMode.XYXY) polygon = box.to_polygon() - assert np.allclose(polygon.numpy(), np.array( - [[50, 50], [100, 50], [100, 100], [50, 100]])), "Box convert_to_polygon failed." + assert np.allclose( + polygon.numpy(), np.array([[50, 50], [100, 50], [100, 100], [50, 100]]) + ), "Box convert_to_polygon failed." + # Test Boxes initialization -def test_invalid_input_type(): +def test_boxes_invalid_input_type(): with pytest.raises(TypeError): - Boxes("invalid_input") + invalid_input: Any = "invalid_input" + Boxes(invalid_input) -def test_invalid_input_shape(): +def test_boxes_invalid_input_shape(): with pytest.raises(TypeError): Boxes([[1, 2, 3, 4, 5]]) -def test_normalized_array(): +def test_boxes_accepts_is_normalized_flag(): array = np.array([0.1, 0.2, 0.3, 0.4]) box = Boxes([array], is_normalized=True) assert box.is_normalized is True -def test_invalid_box_mode(): +def test_boxes_invalid_box_mode(): with pytest.raises(KeyError): array = np.array([1, 2, 3, 4]) - Box(array, box_mode="invalid_mode") + Boxes([array], box_mode="invalid_mode") -def test_array_conversion(): +def test_boxes_array_conversion(): array = [[1, 2, 3, 4]] box = Boxes(array) - assert np.allclose(box._array, np.array(array, dtype='float32')) + assert np.allclose(box._array, np.array(array, dtype="float32")) def test_boxes_init(): # Create boxes in XYXY format - boxes = Boxes([(50, 50, 100, 100), [60, 60, 120, 120]], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), [60, 60, 120, 120]], box_mode=BoxMode.XYXY + ) assert isinstance(boxes, Boxes), "Initialization of Boxes failed." + # Test conversion of Boxes format def test_boxes_convert(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) converted_boxes = boxes.convert(BoxMode.XYWH) - assert np.allclose(converted_boxes.numpy(), np.array( - [[50, 50, 50, 50], [60, 60, 60, 60]])), "Boxes conversion failed." 
+ assert np.allclose( + converted_boxes.numpy(), np.array([[50, 50, 50, 50], [60, 60, 60, 60]]) + ), "Boxes conversion failed." + # Test calculation of area of Boxes def test_boxes_area(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) - assert np.allclose(boxes.area, np.array([2500, 3600])), "Boxes area calculation failed." + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) + assert np.allclose(boxes.area, np.array([2500, 3600])), ( + "Boxes area calculation failed." + ) + # Test Boxes.copy() method def test_boxes_copy(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) copied_boxes = boxes.copy() - assert copied_boxes is not boxes and (copied_boxes._array == boxes._array).all(), "Boxes copy failed." + assert ( + copied_boxes is not boxes + and (copied_boxes._array == boxes._array).all() + ), "Boxes copy failed." + # Test Boxes conversion to numpy array def test_boxes_numpy(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) arr = boxes.numpy() - assert isinstance(arr, np.ndarray) and np.allclose(arr, np.array( - [(50, 50, 100, 100), (60, 60, 120, 120)])), "Boxes to numpy conversion failed." + assert isinstance(arr, np.ndarray) and np.allclose( + arr, np.array([(50, 50, 100, 100), (60, 60, 120, 120)]) + ), "Boxes to numpy conversion failed." + # Test Boxes normalization def test_boxes_normalize(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) normalized_boxes = boxes.normalize(200, 200) - assert np.allclose(normalized_boxes.numpy(), np.array( - [[0.25, 0.25, 0.5, 0.5], [0.3, 0.3, 0.6, 0.6]])), "Boxes normalization failed." + assert np.allclose( + normalized_boxes.numpy(), + np.array([[0.25, 0.25, 0.5, 0.5], [0.3, 0.3, 0.6, 0.6]]), + ), "Boxes normalization failed." + # Test Boxes denormalization def test_boxes_denormalize(): - boxes = Boxes([(0.25, 0.25, 0.5, 0.5), (0.3, 0.3, 0.6, 0.6)], box_mode=BoxMode.XYXY, is_normalized=True) + boxes = Boxes( + [(0.25, 0.25, 0.5, 0.5), (0.3, 0.3, 0.6, 0.6)], + box_mode=BoxMode.XYXY, + is_normalized=True, + ) denormalized_boxes = boxes.denormalize(200, 200) - assert np.allclose(denormalized_boxes.numpy(), np.array( - [(50, 50, 100, 100), (60, 60, 120, 120)])), "Boxes denormalization failed." + assert np.allclose( + denormalized_boxes.numpy(), + np.array([(50, 50, 100, 100), (60, 60, 120, 120)]), + ), "Boxes denormalization failed." + # Test Boxes clipping def test_boxes_clip(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) clipped_boxes = boxes.clip(60, 60, 90, 90) - assert np.allclose(clipped_boxes.numpy(), np.array([(60, 60, 90, 90), (60, 60, 90, 90)])), "Boxes clipping failed." + assert np.allclose( + clipped_boxes.numpy(), np.array([(60, 60, 90, 90), (60, 60, 90, 90)]) + ), "Boxes clipping failed." 
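The normalize/denormalize pair of tests above pins down a round-trip invariant that is worth stating directly. A minimal sketch, assuming only the public Boxes API these tests already exercise (illustration only, not part of the patch):

import numpy as np
from capybara import Boxes, BoxMode

# Pixel-space boxes in XYXY layout, as in the tests above.
boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY)

# normalize(w, h) maps pixel coordinates into [0, 1]; denormalize inverts it.
norm = boxes.normalize(200, 200)
back = norm.denormalize(200, 200)
assert np.allclose(back.numpy(), boxes.numpy())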
+ # Test Boxes shifting def test_boxes_shift(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) shifted_boxes = boxes.shift(10, -10) - assert np.allclose(shifted_boxes.numpy(), np.array( - [(60, 40, 110, 90), (70, 50, 130, 110)])), "Boxes shifting failed." + assert np.allclose( + shifted_boxes.numpy(), np.array([(60, 40, 110, 90), (70, 50, 130, 110)]) + ), "Boxes shifting failed." + # Test Boxes scaling def test_boxes_scale(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) scaled_boxes = boxes.scale(dsize=(20, 0)) - assert np.allclose(scaled_boxes.numpy(), np.array( - [(40, 50, 110, 100), (50, 60, 130, 120)])), "Boxes scaling failed." + assert np.allclose( + scaled_boxes.numpy(), np.array([(40, 50, 110, 100), (50, 60, 130, 120)]) + ), "Boxes scaling failed." + + +def test_boxes_scale_supports_fy_with_single_box(): + """Regression: fy scaling should not index the 4th row for small N.""" + boxes = Boxes([(0, 0, 10, 10)], box_mode=BoxMode.XYWH) + scaled = boxes.scale(fy=2.0).convert(BoxMode.XYWH).numpy() + np.testing.assert_allclose( + scaled, np.array([[0, -5, 10, 20]], dtype=np.float32) + ) + # Test Boxes get_empty_index @@ -247,26 +349,183 @@ def test_boxes_get_empty_index(): boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) assert boxes.get_empty_index() == 0, "Boxes get_empty_index failed." + # Test Boxes drop_empty def test_boxes_drop_empty(): boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) boxes = boxes.drop_empty() - assert np.allclose(boxes.numpy(), np.array([(60, 60, 120, 120)])), "Boxes drop_empty failed." + assert np.allclose(boxes.numpy(), np.array([(60, 60, 120, 120)])), ( + "Boxes drop_empty failed." + ) + # Test Boxes tolist def test_boxes_tolist(): boxes = Boxes([(50, 50, 50, 50), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) - assert boxes.tolist() == [[50, 50, 50, 50], [60, 60, 120, 120]], "Boxes tolist failed." + assert boxes.tolist() == [[50, 50, 50, 50], [60, 60, 120, 120]], ( + "Boxes tolist failed." + ) + # Test Boxes to_polygons def test_boxes_to_polygons(): - boxes = Boxes([(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY) + boxes = Boxes( + [(50, 50, 100, 100), (60, 60, 120, 120)], box_mode=BoxMode.XYXY + ) polygons = boxes.to_polygons() - assert np.allclose(polygons.numpy(), np.array([[[50, 50], [100, 50], [100, 100], [50, 100]], [ - [60, 60], [120, 60], [120, 120], [60, 120]]])), "Boxes convert_to_polygons failed." + assert np.allclose( + polygons.numpy(), + np.array( + [ + [[50, 50], [100, 50], [100, 100], [50, 100]], + [[60, 60], [120, 60], [120, 120], [60, 120]], + ] + ), + ), "Boxes convert_to_polygons failed." 
+ + +def test_boxmode_convert_supports_cxcywh_and_align_code_int_and_invalid(): + xywh = np.array([10, 20, 30, 40], dtype=np.float32) + + cxcywh = BoxMode.convert(xywh, 1, "cxcywh") + np.testing.assert_allclose( + cxcywh, np.array([25, 40, 30, 40], dtype=np.float32) + ) + + xyxy = BoxMode.convert(cxcywh, BoxMode.CXCYWH, BoxMode.XYXY) + np.testing.assert_allclose( + xyxy, np.array([10, 20, 40, 60], dtype=np.float32) + ) + + back_xywh = BoxMode.convert(xyxy, BoxMode.XYXY, BoxMode.XYWH) + np.testing.assert_allclose(back_xywh, xywh) + + with pytest.raises(TypeError, match="not int, str, or BoxMode"): + invalid_box_mode: Any = object() + BoxMode.align_code(invalid_box_mode) + + +def test_box_repr_getitem_slice_and_eq_non_box(): + box = Box((1, 2, 3, 4), box_mode=BoxMode.XYXY) + assert "Box(" in repr(box) + + np.testing.assert_allclose(box[:2], np.array([1, 2], dtype=np.float32)) + assert (box == object()) is False + + +def test_box_invalid_numpy_array_shape_raises(): + with pytest.raises(TypeError, match="got shape"): + Box(np.zeros((2, 2), dtype=np.float32)) + + +def test_box_square_warns_on_normalize_denormalize_clip_nan_and_scale_fx_fy(): + box_xywh = Box((0, 0, 10, 5), box_mode=BoxMode.XYWH) + square = box_xywh.square().convert(BoxMode.XYWH).numpy() + np.testing.assert_allclose(square, np.array([2.5, 0.0, 5.0, 5.0])) + + with pytest.warns(UserWarning, match="forced to do normalization"): + _ = Box((0.1, 0.1, 0.2, 0.2), is_normalized=True).normalize(10, 10) + + with pytest.warns(UserWarning, match="forced to do denormalization"): + _ = Box((1, 2, 3, 4), is_normalized=False).denormalize(10, 10) + + with pytest.raises(ValueError, match="infinite or NaN"): + Box((np.nan, 0, 1, 1)).clip(0, 0, 10, 10) + + scaled = Box((10, 10, 10, 20), box_mode=BoxMode.XYWH).scale(fx=2.0, fy=0.5) + np.testing.assert_allclose( + scaled.convert(BoxMode.XYWH).numpy(), np.array([5, 15, 20, 10]) + ) + + +def test_box_to_polygon_rejects_non_positive_size_and_properties_work(): + with pytest.raises(ValueError, match="invaild value"): + Box((0, 0, -1, 2), box_mode=BoxMode.XYWH).to_polygon() + + box = Box((10, 20, 30, 40), box_mode=BoxMode.XYWH) + assert box.width == 30 + assert box.height == 40 + np.testing.assert_allclose( + box.left_top, np.array([10, 20], dtype=np.float32) + ) + np.testing.assert_allclose( + box.right_bottom, np.array([40, 60], dtype=np.float32) + ) + np.testing.assert_allclose( + box.left_bottom, np.array([10, 60], dtype=np.float32) + ) + np.testing.assert_allclose( + box.right_top, np.array([40, 20], dtype=np.float32) + ) + assert box.aspect_ratio == 30 / 40 + np.testing.assert_allclose(box.center, np.array([25, 40], dtype=np.float32)) + + +def test_boxes_repr_indexing_eq_and_constructor_from_boxes(): + boxes = Boxes([[0, 0, 10, 10], [20, 20, 30, 30]], box_mode=BoxMode.XYXY) + assert "Boxes(" in repr(boxes) + + assert isinstance(boxes[[1]], Boxes) + assert len(boxes[[1]]) == 1 + assert isinstance(boxes[:1], Boxes) + assert len(boxes[:1]) == 1 + + mask = np.array([True, False]) + assert len(boxes[mask]) == 1 + + with pytest.raises(TypeError, match="Boxes indices"): + _ = boxes["0"] # type: ignore[index] + + assert (boxes == object()) is False + + converted = Boxes(boxes, box_mode=BoxMode.XYWH) + assert converted.box_mode == BoxMode.XYWH + assert len(converted) == len(boxes) + + +def test_boxes_square_warns_clip_nan_scale_fx_and_to_polygons_invalid(): + boxes_xywh = Boxes([[0, 0, 10, 5], [10, 10, 6, 8]], box_mode=BoxMode.XYWH) + squared = boxes_xywh.square().convert(BoxMode.XYWH).numpy() + assert 
np.allclose(squared[:, 2], squared[:, 3]) + + with pytest.warns(UserWarning, match="forced to do normalization"): + _ = Boxes([[0.1, 0.1, 0.2, 0.2]], is_normalized=True).normalize(10, 10) + + with pytest.warns(UserWarning, match="forced to do denormalization"): + _ = Boxes([[1, 2, 3, 4]], is_normalized=False).denormalize(10, 10) + + with pytest.raises(ValueError, match="infinite or NaN"): + Boxes([[np.nan, 0, 1, 1]], box_mode=BoxMode.XYXY).clip(0, 0, 10, 10) + + scaled = Boxes([[10, 10, 10, 20]], box_mode=BoxMode.XYWH).scale(fx=2.0) + np.testing.assert_allclose( + scaled.convert(BoxMode.XYWH).numpy(), np.array([[5, 10, 20, 20]]) + ) + + with pytest.raises(ValueError, match="invaild value"): + Boxes([[0, 0, -1, 2]], box_mode=BoxMode.XYWH).to_polygons() + + +def test_boxes_properties_work(): + boxes = Boxes([[10, 20, 30, 40], [0, 0, 2, 4]], box_mode=BoxMode.XYWH) + np.testing.assert_allclose(boxes.width, np.array([30, 2], dtype=np.float32)) + np.testing.assert_allclose( + boxes.height, np.array([40, 4], dtype=np.float32) + ) + np.testing.assert_allclose( + boxes.left_top, np.array([[10, 20], [0, 0]], dtype=np.float32) + ) + np.testing.assert_allclose( + boxes.right_bottom, + np.array([[40, 60], [2, 4]], dtype=np.float32), + ) + np.testing.assert_allclose(boxes.aspect_ratio, np.array([0.75, 0.5])) + np.testing.assert_allclose( + boxes.center, np.array([[25, 40], [1, 2]], dtype=np.float32) + ) diff --git a/tests/structures/test_functionals.py b/tests/structures/test_functionals.py index 4f8816f..7d5d51b 100644 --- a/tests/structures/test_functionals.py +++ b/tests/structures/test_functionals.py @@ -1,67 +1,86 @@ import numpy as np import pytest -from capybara import (Boxes, Polygon, jaccard_index, pairwise_ioa, - pairwise_iou, polygon_iou) +from capybara import ( + Box, + Boxes, + Keypoints, + Polygon, + jaccard_index, + pairwise_ioa, + pairwise_iou, + polygon_iou, +) +from capybara.structures.functionals import ( + calc_angle, + is_inside_box, + pairwise_intersection, + poly_angle, +) test_functionals_error_param = [ ( pairwise_iou, ([(1, 2, 3, 4)], [(1, 2, 3, 4)]), TypeError, - 'Input type of boxes1 and boxes2 must be Boxes' + "Input type of boxes1 and boxes2 must be Boxes", ), ( pairwise_iou, [Boxes([(1, 1, 0, 2)], "XYWH"), Boxes([(1, 1, 0, 2)], "XYWH")], ValueError, - 'Some boxes in Boxes has invaild value' + "Some boxes in Boxes has invaild value", ), ( pairwise_ioa, ([(1, 2, 3, 4)], [(1, 2, 3, 4)]), TypeError, - 'Input type of boxes1 and boxes2 must be Boxes' + "Input type of boxes1 and boxes2 must be Boxes", ), ( pairwise_ioa, [Boxes([(1, 1, 0, 2)], "XYWH"), Boxes([(1, 1, 0, 2)], "XYWH")], ValueError, - 'Some boxes in Boxes has invaild value' + "Some boxes in Boxes has invaild value", ), ] -@pytest.mark.parametrize('fn, test_input, error, match', test_functionals_error_param) +@pytest.mark.parametrize( + "fn, test_input, error, match", test_functionals_error_param +) def test_functionals_error(fn, test_input, error, match): with pytest.raises(error, match=match): fn(*test_input) -test_pairwise_iou_param = [( - Boxes(np.array([[10, 10, 20, 20], [15, 15, 25, 25]]), "XYXY"), - Boxes( - np.array([[10, 10, 20, 20], [15, 15, 25, 25], [25, 25, 10, 10]]), "XYWH"), - np.array([ - [1 / 4, 1 / 28, 0], - [1 / 4, 4 / 25, 0] - ], dtype='float32') -)] +test_pairwise_iou_param = [ + ( + Boxes(np.array([[10, 10, 20, 20], [15, 15, 25, 25]]), "XYXY"), + Boxes( + np.array([[10, 10, 20, 20], [15, 15, 25, 25], [25, 25, 10, 10]]), + "XYWH", + ), + np.array([[1 / 4, 1 / 28, 0], [1 / 4, 4 / 25, 0]], 
dtype="float32"), + ) +] -@pytest.mark.parametrize('boxes1, boxes2, expected', test_pairwise_iou_param) +@pytest.mark.parametrize("boxes1, boxes2, expected", test_pairwise_iou_param) def test_pairwise_iou(boxes1, boxes2, expected): assert (pairwise_iou(boxes1, boxes2) == expected).all() -test_pairwise_ioa_param = [( - Boxes(np.array([[10, 10, 20, 20]]), "XYXY"), - Boxes(np.array([[15, 15, 20, 20], [20, 20, 10, 10]]), "XYWH"), - np.array([[1 / 16, 0]], dtype='float32') -)] +test_pairwise_ioa_param = [ + ( + Boxes(np.array([[10, 10, 20, 20]]), "XYXY"), + Boxes(np.array([[15, 15, 20, 20], [20, 20, 10, 10]]), "XYWH"), + np.array([[1 / 16, 0]], dtype="float32"), + ) +] -@pytest.mark.parametrize('boxes1, boxes2, expected', test_pairwise_ioa_param) +@pytest.mark.parametrize("boxes1, boxes2, expected", test_pairwise_ioa_param) def test_pairwise_ioa(boxes1, boxes2, expected): assert (pairwise_ioa(boxes1, boxes2) == expected).all() @@ -70,12 +89,12 @@ def test_pairwise_ioa(boxes1, boxes2, expected): ( Polygon(np.array([[0, 0], [0, 10], [10, 10], [10, 0]])), Polygon(np.array([[5, 5], [5, 15], [15, 15], [15, 5]])), - 25 / 175 + 25 / 175, ) ] -@pytest.mark.parametrize('poly1, poly2, expected', test_polygon_iou_param) +@pytest.mark.parametrize("poly1, poly2, expected", test_polygon_iou_param) def test_polygon_iou(poly1, poly2, expected): assert polygon_iou(poly1, poly2) == expected @@ -85,12 +104,14 @@ def test_polygon_iou(poly1, poly2, expected): np.array([[0, 0], [0, 10], [10, 10], [10, 0]]), np.array([[5, 5], [5, 15], [15, 15], [15, 5]]), (100, 100), - 25 / 175 + 25 / 175, ) ] -@pytest.mark.parametrize('pred_poly, gt_poly, img_size, expected', test_jaccard_index_param) +@pytest.mark.parametrize( + "pred_poly, gt_poly, img_size, expected", test_jaccard_index_param +) def test_jaccard_index(pred_poly, gt_poly, img_size, expected): assert jaccard_index(pred_poly, gt_poly, img_size) == expected @@ -101,19 +122,154 @@ def test_jaccard_index(pred_poly, gt_poly, img_size, expected): np.array([[5, 5], [5, 15], [15, 15], [15, 5]]), (100, 100), ValueError, - 'Input polygon must be 4-point polygon.' + "Input polygon must be 4-point polygon.", ), ( np.array([[0, 0], [0, 10], [10, 10], [10, 0]]), np.array([[5, 5], [5, 15], [15, 15], [15, 5], [5, 5]]), (100, 100), ValueError, - 'Input polygon must be 4-point polygon.' 
+ "Input polygon must be 4-point polygon.", ), ] -@pytest.mark.parametrize('pred_poly, gt_poly, img_size, error, match', test_jaccard_index_error_param) +@pytest.mark.parametrize( + "pred_poly, gt_poly, img_size, error, match", test_jaccard_index_error_param +) def test_jaccard_index_error(pred_poly, gt_poly, img_size, error, match): with pytest.raises(error, match=match): jaccard_index(pred_poly, gt_poly, img_size) + + +def test_pairwise_intersection_rejects_non_boxes(): + with pytest.raises(TypeError, match="must be Boxes"): + pairwise_intersection([(0, 0, 1, 1)], [(0, 0, 1, 1)]) # type: ignore[arg-type] + + +def test_jaccard_index_requires_image_size(): + pred = np.zeros((4, 2), dtype=np.float32) + gt = np.zeros((4, 2), dtype=np.float32) + with pytest.raises(ValueError, match="image size"): + jaccard_index(pred, gt, None) # type: ignore[arg-type] + + +def test_jaccard_index_returns_zero_when_shapely_raises(monkeypatch): + import capybara.structures.functionals as fn_mod + + monkeypatch.setattr( + fn_mod, + "ShapelyPolygon", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")), + ) + pred = np.zeros((4, 2), dtype=np.float32) + gt = np.zeros((4, 2), dtype=np.float32) + assert jaccard_index(pred, gt, (100, 100)) == 0 + + +def test_jaccard_index_clamps_intersection_area_close_to_min(monkeypatch): + import capybara.structures.functionals as fn_mod + + monkeypatch.setattr( + fn_mod.cv2, "getPerspectiveTransform", lambda *_: np.eye(3) + ) + monkeypatch.setattr(fn_mod.cv2, "perspectiveTransform", lambda pts, _m: pts) + + class _FakePoly: + def __init__( + self, *, area: float, intersection_area: float | None = None + ) -> None: + self._area = area + self._intersection_area = intersection_area + + @property + def area(self) -> float: + return self._area + + def __and__(self, _other): + assert self._intersection_area is not None + return _FakePoly(area=self._intersection_area) + + target = _FakePoly(area=1.0, intersection_area=1.00000000005) + pred = _FakePoly(area=1.0) + factory = iter([target, pred]) + monkeypatch.setattr( + fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory) + ) + + pred_poly = np.zeros((4, 2), dtype=np.float32) + gt_poly = np.zeros((4, 2), dtype=np.float32) + assert jaccard_index(pred_poly, gt_poly, (10, 10)) == pytest.approx(1.0) + + +def test_polygon_iou_rejects_non_polygon_and_returns_zero_on_errors( + monkeypatch, +): + with pytest.raises(TypeError, match="must be Polygon"): + polygon_iou("bad", Polygon(np.zeros((4, 2), dtype=np.float32))) # type: ignore[arg-type] + + import capybara.structures.functionals as fn_mod + + class _FakePoly: + def __init__( + self, + *, + area: float, + intersection_area: float | None = None, + raise_intersection: bool = False, + ) -> None: + self._area = area + self._intersection_area = intersection_area + self._raise_intersection = raise_intersection + + @property + def area(self) -> float: + return self._area + + def intersection(self, _other): + if self._raise_intersection: + raise ValueError("boom") + assert self._intersection_area is not None + return _FakePoly(area=self._intersection_area) + + poly1_shape = _FakePoly(area=2.0, intersection_area=1.00000000005) + poly2_shape = _FakePoly(area=1.0) + factory = iter([poly1_shape, poly2_shape]) + monkeypatch.setattr( + fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory) + ) + + poly1 = Polygon(np.zeros((4, 2), dtype=np.float32)) + poly2 = Polygon(np.zeros((4, 2), dtype=np.float32)) + assert polygon_iou(poly1, poly2) == pytest.approx(0.5) + + 
poly1_shape_err = _FakePoly(area=1.0, raise_intersection=True) + poly2_shape_err = _FakePoly(area=1.0) + factory_err = iter([poly1_shape_err, poly2_shape_err]) + monkeypatch.setattr( + fn_mod, "ShapelyPolygon", lambda *_args, **_kwargs: next(factory_err) + ) + assert polygon_iou(poly1, poly2) == 0 + + +def test_is_inside_box_calc_angle_and_poly_angle(): + box = Box((0, 0, 10, 10), box_mode="XYXY") + assert is_inside_box(Keypoints([(1, 1), (9, 9)]), box) + assert not is_inside_box(Keypoints([(1, 1), (11, 9)]), box) + + assert calc_angle(np.array([0, 1]), np.array([-1, 0])) == pytest.approx( + 90.0 + ) + assert calc_angle(np.array([0, 1]), np.array([1, 0])) == pytest.approx( + 270.0 + ) + + poly1 = Polygon( + np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) + ) + assert poly_angle(poly1) == pytest.approx(90.0) + + poly2 = Polygon( + np.array([[0, 0], [1, 0], [0, -1], [1, -1]], dtype=np.float32) + ) + assert poly_angle(poly1, poly2) == pytest.approx(270.0) diff --git a/tests/structures/test_keypoints.py b/tests/structures/test_keypoints.py index 20509c9..d967ee0 100644 --- a/tests/structures/test_keypoints.py +++ b/tests/structures/test_keypoints.py @@ -1,23 +1,29 @@ +from typing import Any + import numpy as np import pytest -from capybara import Box, Boxes, BoxMode, Keypoints, KeypointsList +from capybara import Keypoints, KeypointsList def test_invalid_input_type(): with pytest.raises(TypeError): - Keypoints("invalid_input") + invalid_input: Any = "invalid_input" + Keypoints(invalid_input) def test_invalid_input_shape(): with pytest.raises(ValueError): - Keypoints([(1, 2, 3, 4), (1, 2, 3, 4)]) + invalid_input: Any = [(1, 2, 3, 4), (1, 2, 3, 4)] + Keypoints(invalid_input) def test_keypoints_eat_itself(): keypoints1 = Keypoints([(1, 2), (3, 4)]) keypoints2 = Keypoints(keypoints1) - assert np.allclose(keypoints1.numpy(), keypoints2.numpy()), "Keypoints eat itself failed." + assert np.allclose(keypoints1.numpy(), keypoints2.numpy()), ( + "Keypoints eat itself failed." + ) def test_normalized_array(): @@ -29,91 +35,233 @@ def test_normalized_array(): def test_keypoints_numpy(): array = np.array([[1, 2], [3, 4]]) keypoints = Keypoints(array) - assert np.allclose(keypoints.numpy(), array), "Keypoints numpy conversion failed." + assert np.allclose(keypoints.numpy(), array), ( + "Keypoints numpy conversion failed." + ) def test_keypoints_copy(): keypoints = Keypoints([(1, 2), (3, 4)]) copied_keypoints = keypoints.copy() - assert np.allclose(keypoints.numpy(), copied_keypoints.numpy()), "Keypoints copy failed." + assert np.allclose(keypoints.numpy(), copied_keypoints.numpy()), ( + "Keypoints copy failed." + ) def test_keypoints_shift(): keypoints = Keypoints([(1, 2), (3, 4)]) shifted_keypoints = keypoints.shift(10, 10) - assert np.allclose(shifted_keypoints.numpy(), np.array( - [[11, 12], [13, 14]])), "Keypoints shift failed." + assert np.allclose( + shifted_keypoints.numpy(), np.array([[11, 12], [13, 14]]) + ), "Keypoints shift failed." def test_keypoints_scale(): keypoints = Keypoints([(1, 2), (3, 4)]) scaled_keypoints = keypoints.scale(10, 10) - assert np.allclose(scaled_keypoints.numpy(), np.array( - [[10, 20], [30, 40]])), "Keypoints scale failed." + assert np.allclose( + scaled_keypoints.numpy(), np.array([[10, 20], [30, 40]]) + ), "Keypoints scale failed." 
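+# Editor's sketch, not part of the original suite: the tests above use the
+# return values of shift() and scale(), so (assuming both stay chainable)
+# the two transforms should compose left to right.
+def test_keypoints_shift_then_scale_composes():
+    keypoints = Keypoints([(1, 2), (3, 4)])
+    transformed = keypoints.shift(1, 1).scale(2, 2)
+    assert np.allclose(transformed.numpy(), np.array([[4, 6], [8, 10]]))
+
+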
def test_keypoints_normalize(): keypoints = Keypoints([(1, 2), (3, 4)]) normalized_keypoints = keypoints.normalize(100, 100) - assert np.allclose(normalized_keypoints.numpy(), np.array( - [[0.01, 0.02], [0.03, 0.04]])), "Keypoints normalization failed." + assert np.allclose( + normalized_keypoints.numpy(), np.array([[0.01, 0.02], [0.03, 0.04]]) + ), "Keypoints normalization failed." def test_keypoints_denormalize(): keypoints = Keypoints([(0.01, 0.02), (0.03, 0.04)], is_normalized=True) denormalized_keypoints = keypoints.denormalize(100, 100) - assert np.allclose(denormalized_keypoints.numpy(), np.array( - [[1, 2], [3, 4]])), "Keypoints denormalization failed." + assert np.allclose( + denormalized_keypoints.numpy(), np.array([[1, 2], [3, 4]]) + ), "Keypoints denormalization failed." def test_keypoints_list_empty_input(): keypoints_list = KeypointsList([]) - np.testing.assert_allclose(keypoints_list.numpy(), np.array([], dtype='float32')) + np.testing.assert_allclose( + keypoints_list.numpy(), np.array([], dtype="float32") + ) def test_keypoints_list_numpy(): array = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) keypoints_list = KeypointsList(array) - assert np.allclose(keypoints_list.numpy(), array), "KeypointsList numpy conversion failed." + assert np.allclose(keypoints_list.numpy(), array), ( + "KeypointsList numpy conversion failed." + ) def test_keypoints_list_copy(): keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]) copied_keypoints_list = keypoints_list.copy() - assert np.allclose(keypoints_list.numpy(), copied_keypoints_list.numpy()), "KeypointsList copy failed." + assert np.allclose(keypoints_list.numpy(), copied_keypoints_list.numpy()), ( + "KeypointsList copy failed." + ) def test_keypoints_list_shift(): keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]) shifted_keypoints_list = keypoints_list.shift(10, 10) - assert np.allclose(shifted_keypoints_list.numpy(), np.array( - [[[11, 12], [13, 14]], [[15, 16], [17, 18]]])), "KeypointsList shift failed." + assert np.allclose( + shifted_keypoints_list.numpy(), + np.array([[[11, 12], [13, 14]], [[15, 16], [17, 18]]]), + ), "KeypointsList shift failed." def test_keypoints_list_scale(): keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]) scaled_keypoints_list = keypoints_list.scale(10, 10) - assert np.allclose(scaled_keypoints_list.numpy(), np.array( - [[[10, 20], [30, 40]], [[50, 60], [70, 80]]])), "KeypointsList scale failed." + assert np.allclose( + scaled_keypoints_list.numpy(), + np.array([[[10, 20], [30, 40]], [[50, 60], [70, 80]]]), + ), "KeypointsList scale failed." def test_keypoints_list_normalize(): keypoints_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]) normalized_keypoints_list = keypoints_list.normalize(100, 100) - assert np.allclose(normalized_keypoints_list.numpy(), np.array( - [[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, 0.08]]])), "KeypointsList normalization failed." + assert np.allclose( + normalized_keypoints_list.numpy(), + np.array([[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, 0.08]]]), + ), "KeypointsList normalization failed." 
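+# Editor's sketch, not part of the original suite: assuming normalize() marks
+# its result as normalized, denormalize() should invert it without a warning.
+def test_keypoints_list_normalize_then_denormalize_round_trip():
+    keypoints_list = KeypointsList([[(1, 2), (3, 4)]])
+    round_trip = keypoints_list.normalize(100, 100).denormalize(100, 100)
+    assert np.allclose(round_trip.numpy(), keypoints_list.numpy())
+
+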
def test_keypoints_list_denormalize(): - keypoints_list = KeypointsList([[(0.01, 0.02), (0.03, 0.04)], [(0.05, 0.06), (0.07, 0.08)]], is_normalized=True) + keypoints_list = KeypointsList( + [[(0.01, 0.02), (0.03, 0.04)], [(0.05, 0.06), (0.07, 0.08)]], + is_normalized=True, + ) denormalized_keypoints_list = keypoints_list.denormalize(100, 100) - assert np.allclose(denormalized_keypoints_list.numpy(), np.array( - [[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), "KeypointsList denormalization failed." + assert np.allclose( + denormalized_keypoints_list.numpy(), + np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), + ), "KeypointsList denormalization failed." def test_keypoints_list_cat(): keypoints_list1 = KeypointsList([[(1, 2), (3, 4)]]) keypoints_list2 = KeypointsList([[(5, 6), (7, 8)]]) cat_keypoints_list = KeypointsList.cat([keypoints_list1, keypoints_list2]) - assert np.allclose(cat_keypoints_list.numpy(), np.array( - [[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), "KeypointsList concatenation failed." + assert np.allclose( + cat_keypoints_list.numpy(), + np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), + ), "KeypointsList concatenation failed." + + +def test_keypoints_repr_eq_and_validation_and_warnings(): + kpts = Keypoints([(1, 2), (3, 4)]) + assert "Keypoints(" in repr(kpts) + assert (kpts == object()) is False + + with pytest.raises(ValueError, match="ndim"): + Keypoints(np.zeros((2, 2, 2), dtype=np.float32)) + + with pytest.raises(ValueError, match="labels"): + Keypoints(np.array([[1, 2, 3], [4, 5, -1]], dtype=np.float32)) + + with pytest.warns(UserWarning, match="forced to do normalization"): + Keypoints([(0.1, 0.2)], is_normalized=True).normalize(10, 10) + + with pytest.warns(UserWarning, match="forced to do denormalization"): + Keypoints([(1, 2)], is_normalized=False).denormalize(10, 10) + + +def test_keypoints_point_colors_can_be_updated(): + kpts = Keypoints([(1, 2), (3, 4)]) + colors = kpts.point_colors + assert len(colors) == 2 + assert all(isinstance(c, tuple) and len(c) == 3 for c in colors) + + kpts.point_colors = "viridis" + colors2 = kpts.point_colors + assert len(colors2) == 2 + + +def test_keypoints_list_getitem_setitem_repr_eq_and_point_colors(): + kpts_list = KeypointsList([[(1, 2), (3, 4)], [(5, 6), (7, 8)]]) + assert isinstance(kpts_list[0], Keypoints) + assert isinstance(kpts_list[:1], KeypointsList) + + with pytest.raises(TypeError, match="not a keypoint"): + kpts_list[0] = "bad" # type: ignore[assignment] + + kpts_list[0] = Keypoints([(9, 10), (11, 12)]) + np.testing.assert_allclose( + kpts_list[0].numpy(), np.array([[9, 10], [11, 12]], dtype=np.float32) + ) + + assert "KeypointsList(" in repr(kpts_list) + assert (kpts_list == object()) is False + + assert len(kpts_list.point_colors) == 2 + kpts_list.point_colors = "viridis" + assert len(kpts_list.point_colors) == 2 + + +def test_keypoints_list_validation_errors_and_warnings_and_empty_point_colors(): + kpts_list = KeypointsList([[(1, 2), (3, 4)]]) + kpts_list2 = KeypointsList(kpts_list) + np.testing.assert_allclose(kpts_list2.numpy(), kpts_list.numpy()) + + with pytest.raises(ValueError, match="ndim"): + KeypointsList(np.zeros((2, 2), dtype=np.float32)) + + with pytest.raises(ValueError, match="shape\\[-1\\]"): + KeypointsList(np.zeros((1, 2, 4), dtype=np.float32)) + + with pytest.raises(ValueError, match="labels"): + KeypointsList(np.array([[[1, 2, 3], [4, 5, 9]]], dtype=np.float32)) + + with pytest.warns(UserWarning, match="keypoints_list"): + KeypointsList([[(0.1, 0.2), (0.3, 0.4)]], is_normalized=True).normalize( + 10, 10 + ) + 
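+    # The symmetric case: denormalize() on data that is already absolute
+    # proceeds, but warns that it was forced to do so.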
+ with pytest.warns(UserWarning, match="forced to do denormalization"): + KeypointsList([[(1, 2), (3, 4)]], is_normalized=False).denormalize( + 10, 10 + ) + + assert KeypointsList([]).point_colors == [] + + +def test_keypoints_list_cat_validation_errors(): + with pytest.raises(TypeError, match="should be a list"): + KeypointsList.cat("not-a-list") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="is empty"): + KeypointsList.cat([]) + + with pytest.raises(TypeError, match="must be KeypointsList"): + KeypointsList.cat([KeypointsList([]), "bad"]) # type: ignore[list-item] + + +def test_keypoints_colormap_falls_back_without_matplotlib(monkeypatch): + import builtins + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "matplotlib": + raise ModuleNotFoundError("boom") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + kpts = Keypoints([(0, 0), (1, 1)]) + assert len(kpts.point_colors) == 2 + + +def test_keypoints_list_rejects_invalid_types(): + with pytest.raises(TypeError, match="Input array is not"): + KeypointsList(123) # type: ignore[arg-type] + + +def test_keypoints_list_set_point_colors_noops_when_empty(): + keypoints_list = KeypointsList([]) + keypoints_list.set_point_colors("rainbow") + assert keypoints_list.point_colors == [] diff --git a/tests/structures/test_polygons.py b/tests/structures/test_polygons.py index c4ddb89..0f9b325 100644 --- a/tests/structures/test_polygons.py +++ b/tests/structures/test_polygons.py @@ -1,3 +1,5 @@ +from typing import Any, cast + import numpy as np import pytest @@ -13,45 +15,49 @@ def assert_almost_equal(actual, expected, tolerance=1e-5): tl, bl, br, tr = (0, 0), (0, 1), (1, 1), (1, 0) -POINTS_SET = np.array([(tl, bl, br, tr), - (tl, bl, tr, br), - (tl, br, bl, tr), - (tl, br, tr, bl), - (tl, tr, bl, br), - (tl, tr, br, bl), - (bl, tl, br, tr), - (bl, tl, tr, br), - (bl, br, tl, tr), - (bl, br, tr, tl), - (bl, tr, tl, br), - (bl, tr, br, tl), - (br, tl, bl, tr), - (br, tl, tr, bl), - (br, bl, tl, tr), - (br, bl, tr, tl), - (br, tr, tl, bl), - (br, tr, bl, tl), - (tr, tl, bl, br), - (tr, tl, br, bl), - (tr, bl, tl, br), - (tr, bl, br, tl), - (tr, br, tl, bl), - (tr, br, bl, tl)]) +POINTS_SET = np.array( + [ + (tl, bl, br, tr), + (tl, bl, tr, br), + (tl, br, bl, tr), + (tl, br, tr, bl), + (tl, tr, bl, br), + (tl, tr, br, bl), + (bl, tl, br, tr), + (bl, tl, tr, br), + (bl, br, tl, tr), + (bl, br, tr, tl), + (bl, tr, tl, br), + (bl, tr, br, tl), + (br, tl, bl, tr), + (br, tl, tr, bl), + (br, bl, tl, tr), + (br, bl, tr, tl), + (br, tr, tl, bl), + (br, tr, bl, tl), + (tr, tl, bl, br), + (tr, tl, br, bl), + (tr, bl, tl, br), + (tr, bl, br, tl), + (tr, br, tl, bl), + (tr, br, bl, tl), + ] +) CLOCKWISE_PTS = np.array([tl, tr, br, bl]) COUNTER_CLOCKWISE_PTS = np.array([tl, bl, br, tr]) -@pytest.mark.parametrize("pts, expected", [ - (pts, CLOCKWISE_PTS) for pts in POINTS_SET -]) +@pytest.mark.parametrize( + "pts, expected", [(pts, CLOCKWISE_PTS) for pts in POINTS_SET] +) def test_order_points_clockwise(pts, expected): ordered_pts = order_points_clockwise(pts) np.testing.assert_allclose(ordered_pts, expected) -@pytest.mark.parametrize("pts, expected", [ - (pts, COUNTER_CLOCKWISE_PTS) for pts in POINTS_SET -]) +@pytest.mark.parametrize( + "pts, expected", [(pts, COUNTER_CLOCKWISE_PTS) for pts in POINTS_SET] +) def test_order_points_counter_clockwise(pts, expected): ordered_pts = order_points_clockwise(pts, inverse=True) 
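+    # With inverse=True the helper returns the same counter-clockwise ordering
+    # for every permutation of the four corners.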
np.testing.assert_allclose(ordered_pts, expected) @@ -80,7 +86,7 @@ def test_polygon_repr(): # Test __repr__ method array = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) poly = Polygon(array) - assert repr(poly) == f"Polygon({str(array)})" + assert repr(poly) == f"Polygon({array!s})" def test_polygon_len(): @@ -103,14 +109,14 @@ def test_polygon_normalized(): # Test initialization with is_normalized=True array = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) poly = Polygon(array, is_normalized=True) - assert_array_equal(poly._array, array.astype('float32')) + assert_array_equal(poly._array, array.astype("float32")) assert poly.is_normalized def test_polygon_invalid_array(): # Test initialization with invalid arrays with pytest.raises(TypeError): - invalid_array = "invalid" # Invalid type + invalid_array: Any = "invalid" # Invalid type Polygon(invalid_array) with pytest.raises(TypeError): @@ -118,15 +124,15 @@ def test_polygon_invalid_array(): Polygon(invalid_array) with pytest.raises(TypeError): - invalid_array = [1, 2, 3] # Invalid type + invalid_array: Any = [1, 2, 3] # Invalid type Polygon(invalid_array) with pytest.raises(TypeError): - invalid_array = [1, 2, 3, 4] # Invalid type + invalid_array: Any = [1, 2, 3, 4] # Invalid type Polygon(invalid_array) with pytest.raises(TypeError): - invalid_array = Polygon() # Invalid type (empty Polygon instance) + invalid_array = cast(Any, Polygon)() # Invalid type (missing args) Polygon(invalid_array) @@ -152,7 +158,8 @@ def test_polygon_normalize(): poly = Polygon(array) normalized_poly = poly.normalize(100.0, 200.0) expected_normalized_array = np.array( - [[0.1, 0.1], [0.3, 0.2], [0.5, 0.3]]).astype('float32') + [[0.1, 0.1], [0.3, 0.2], [0.5, 0.3]] + ).astype("float32") assert_array_equal(normalized_poly._array, expected_normalized_array) assert normalized_poly.is_normalized @@ -163,9 +170,11 @@ def test_polygon_denormalize(): poly = Polygon(normalized_array, is_normalized=True) denormalized_poly = poly.denormalize(100.0, 200.0) expected_denormalized_array = np.array( - [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]).astype('float32') + [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] + ).astype("float32") np.testing.assert_allclose( - denormalized_poly._array, expected_denormalized_array) + denormalized_poly._array, expected_denormalized_array + ) assert not denormalized_poly.is_normalized @@ -173,7 +182,8 @@ def test_polygon_denormalize_non_normalized(): # Test denormalize method for non-is_normalized Polygon array = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]) poly = Polygon(array) - denormalized_poly = poly.denormalize(100.0, 200.0) + with pytest.warns(UserWarning, match="Non-normalized polygon"): + denormalized_poly = poly.denormalize(100.0, 200.0) assert not np.array_equal(denormalized_poly._array, array) assert not denormalized_poly.is_normalized @@ -191,7 +201,8 @@ def test_polygon_clip(): # Test clipping outside the range clipped_poly = poly.clip(10, 20, 30, 40) expected_clipped_array = np.array( - [[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]]) + [[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]] + ) assert_array_equal(clipped_poly._array, expected_clipped_array) @@ -237,7 +248,7 @@ def test_polygon_scale_empty(): with pytest.raises(ValueError): array = np.array([]) poly = Polygon(array) - scaled_poly = poly.scale(1) + poly.scale(1) def test_polygon_to_convexhull(): @@ -268,7 +279,7 @@ def test_polygon_to_box(): poly = Polygon(array) # Test bounding box of the polygon in "xyxy" format - box = poly.to_box(box_mode='xyxy') + box = 
poly.to_box(box_mode="xyxy") expected_box = [10, 7, 23, 23] assert box.tolist() == expected_box @@ -323,27 +334,22 @@ def test_polygon_is_empty_with_threshold(): assert non_empty_poly.is_empty(threshold=3) is True -test_props_input = np.array([ - [5, 0], - [10, 5], - [5, 10], - [0, 5] -]) +test_props_input = np.array([[5, 0], [10, 5], [5, 10], [0, 5]]) test_props_params = [ ( "moments", { - 'm00': 50.0, - 'm10': 250.0, - 'm01': 250.0, - 'm20': 1458.3333333333333, - 'm11': 1250.0, - 'm02': 1458.3333333333333, - 'm30': 9375.0, - 'm21': 7291.666666666667, - 'm12': 7291.666666666667, - 'm03': 9375.0, + "m00": 50.0, + "m10": 250.0, + "m01": 250.0, + "m20": 1458.3333333333333, + "m11": 1250.0, + "m02": 1458.3333333333333, + "m30": 9375.0, + "m21": 7291.666666666667, + "m12": 7291.666666666667, + "m03": 9375.0, }, ), ("area", 50), @@ -359,7 +365,7 @@ def test_polygon_is_empty_with_threshold(): ] -@ pytest.mark.parametrize('prop, expected', test_props_params) +@pytest.mark.parametrize("prop, expected", test_props_params) def test_polygon_property(prop, expected): value = getattr(Polygon(test_props_input), prop) if isinstance(value, (int, float)): @@ -368,7 +374,7 @@ def test_polygon_property(prop, expected): for k, v in expected.items(): np.testing.assert_allclose(value[k], v, rtol=1e-4) elif isinstance(value, (list, tuple)): - for v, e in zip(value, expected): + for v, e in zip(value, expected, strict=True): np.testing.assert_allclose(v, e, rtol=1e-4) @@ -396,7 +402,7 @@ def test_polygons_init(): # Test initialization with invalid input with pytest.raises(TypeError): - invalid_input = "invalid" + invalid_input: Any = "invalid" polygons = Polygons(invalid_input) @@ -441,7 +447,7 @@ def test_polygons_getitem(): # Test invalid input with pytest.raises(TypeError): - invalid_input = 1.5 + invalid_input: Any = 1.5 polygons[invalid_input] @@ -465,10 +471,14 @@ def test_polygons_to_min_boxpoints(): polygons = Polygons(polygons_list) min_boxpoints_polygons = polygons.to_min_boxpoints() assert len(min_boxpoints_polygons) == 2 - assert_array_equal(min_boxpoints_polygons[0]._array, np.array( - [[5, 0], [10, 5], [5, 10], [0, 5]])) - assert_array_equal(min_boxpoints_polygons[1]._array, np.array( - [[50, 0], [100, 50], [50, 100], [0, 50]])) + assert_array_equal( + min_boxpoints_polygons[0]._array, + np.array([[5, 0], [10, 5], [5, 10], [0, 5]]), + ) + assert_array_equal( + min_boxpoints_polygons[1]._array, + np.array([[50, 0], [100, 50], [50, 100], [0, 50]]), + ) def test_polygons_to_convexhull(): @@ -480,10 +490,14 @@ def test_polygons_to_convexhull(): convexhull_polygons = polygons.to_convexhull() assert len(convexhull_polygons) == 2 - assert_array_equal(convexhull_polygons[0]._array, np.array( - [[5, 0], [10, 5], [5, 10], [0, 5]])) - assert_array_equal(convexhull_polygons[1]._array, np.array( - [[50, 0], [100, 50], [50, 100], [0, 50]])) + assert_array_equal( + convexhull_polygons[0]._array, + np.array([[5, 0], [10, 5], [5, 10], [0, 5]]), + ) + assert_array_equal( + convexhull_polygons[1]._array, + np.array([[50, 0], [100, 50], [50, 100], [0, 50]]), + ) def test_polygons_to_boxes(): @@ -536,11 +550,17 @@ def test_polygons_normalize(): w, h = 10.0, 10.0 normalized_polygons = polygons.normalize(w, h) assert len(normalized_polygons) == 2 - assert_array_equal(normalized_polygons[0]._array, np.array( - [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype='float32')) - assert_array_equal(normalized_polygons[1]._array, np.array( - [[0.7, 0.8], [0.9, 1.0]], dtype='float32')) - assert normalized_polygons.is_normalized # Check 
if the is_normalized flag is True + assert_array_equal( + normalized_polygons[0]._array, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype="float32"), + ) + assert_array_equal( + normalized_polygons[1]._array, + np.array([[0.7, 0.8], [0.9, 1.0]], dtype="float32"), + ) + assert ( + normalized_polygons.is_normalized + ) # Check if the is_normalized flag is True def test_polygons_denormalize(): @@ -553,10 +573,13 @@ def test_polygons_denormalize(): w, h = 10.0, 10.0 denormalized_polygons = polygons.denormalize(w, h) assert len(denormalized_polygons) == 2 - assert_array_equal(denormalized_polygons[0]._array, np.array( - [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) - assert_array_equal(denormalized_polygons[1]._array, np.array( - [[7.0, 8.0], [9.0, 10.0]])) + assert_array_equal( + denormalized_polygons[0]._array, + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + ) + assert_array_equal( + denormalized_polygons[1]._array, np.array([[7.0, 8.0], [9.0, 10.0]]) + ) # Check if the is_normalized flag is False assert not denormalized_polygons.is_normalized @@ -571,10 +594,14 @@ def test_polygons_scale(): # Test scaling with distance=1 and default join_style (mitre) scaled_polygons = polygons.scale(1) assert len(scaled_polygons) == 2 - assert_array_equal(scaled_polygons[0]._array, np.array( - [[9, 9], [9, 21], [21, 21], [21, 9]])) - assert_array_equal(scaled_polygons[1]._array, np.array( - [[9, 9], [9, 21], [21, 21], [21, 9]])) + assert_array_equal( + scaled_polygons[0]._array, + np.array([[9, 9], [9, 21], [21, 21], [21, 9]]), + ) + assert_array_equal( + scaled_polygons[1]._array, + np.array([[9, 9], [9, 21], [21, 21], [21, 9]]), + ) # Check if no empty polygons after scaling assert not scaled_polygons.is_empty().any() @@ -587,7 +614,6 @@ def test_polygons_numpy(): polygons = Polygons(polygons_list) non_flattened_numpy_array = polygons.numpy(flatten=False) - expected_non_flattened_array = np.array([array1, array2], dtype=object) for i in range(len(polygons)): assert_array_equal(non_flattened_numpy_array[i], polygons[i]._array) @@ -601,7 +627,9 @@ def test_polygons_to_list(): flattened_list = polygons.to_list(flatten=True) expected_flattened_list = [ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]] + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 9.0, 10.0], + ] assert flattened_list == expected_flattened_list non_flattened_list = polygons.to_list(flatten=False) @@ -617,26 +645,32 @@ def test_polygons_to_list(): test_ploygons_props_params = [ ( - "moments", [ + "moments", + [ { - 'm00': 50.0, - 'm10': 250.0, - 'm01': 250.0, + "m00": 50.0, + "m10": 250.0, + "m01": 250.0, }, { - 'm00': 50.0, - 'm10': 250.0, - 'm01': 250.0, + "m00": 50.0, + "m10": 250.0, + "m01": 250.0, }, - ] + ], ), ("area", np.array([50, 50])), ("arclength", np.array([28.28427, 28.28427])), ("centroid", np.array([(5, 5), (5, 5)])), ("boundingbox", np.array([(0, 0, 10, 10), (0, 0, 10, 10)])), ("min_circle", [((5.0, 5.0), 5.0), ((5.0, 5.0), 5.0)]), - ("min_box", [((5.0, 5.0), (7.07106, 7.07106), 45.0), - ((5.0, 5.0), (7.07106, 7.07106), 45.0)]), + ( + "min_box", + [ + ((5.0, 5.0), (7.07106, 7.07106), 45.0), + ((5.0, 5.0), (7.07106, 7.07106), 45.0), + ], + ), ("orientation", np.array([45, 45])), ("min_box_wh", np.array([(7.07106, 7.07106), (7.07106, 7.07106)])), ("extent", np.array([0.5, 0.5])), @@ -644,19 +678,199 @@ def test_polygons_to_list(): ] -@ pytest.mark.parametrize('prop, expected', test_ploygons_props_params) +@pytest.mark.parametrize("prop, expected", test_ploygons_props_params) def test_polygons_property(prop, expected): 
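+    # Dispatch on the property's return type: scalars and arrays compare
+    # directly, while list-valued properties are compared element by element.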
value = getattr(Polygons(test_ploygons_props_input), prop) if isinstance(value, (int, float, np.ndarray)): np.testing.assert_allclose(value, expected, rtol=1e-4) elif isinstance(value, list): - for v, e in zip(value, expected): + for v, e in zip(value, expected, strict=True): if isinstance(v, (list, tuple)): - for vv, ee in zip(v, e): + for vv, ee in zip(v, e, strict=True): np.testing.assert_allclose( - np.array(vv), np.array(ee), rtol=1e-4) + np.array(vv), np.array(ee), rtol=1e-4 + ) elif isinstance(v, dict): for key, val in e.items(): assert v[key] == val else: np.testing.assert_allclose(np.array(v), np.array(e), rtol=1e-4) + + +def test_order_points_clockwise_rejects_invalid_shape(): + with pytest.raises(ValueError, match=r"shape \(4, 2\)"): + order_points_clockwise(np.zeros((3, 2), dtype=np.float32)) + + +def test_polygon_accepts_nx1x2_contour_array_and_eq_non_polygon(): + contour = np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=np.float32) + poly = Polygon(contour) + assert poly.numpy().shape == (2, 2) + assert (poly == object()) is False + + +def test_polygon_warns_on_double_normalize_and_clip_rejects_nan(): + poly = Polygon( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + is_normalized=True, + ) + with pytest.warns(UserWarning, match="forced to do normalization"): + _ = poly.normalize(10.0, 10.0) + + with pytest.raises(ValueError, match="infinite or NaN"): + Polygon(np.array([[np.nan, 0.0], [1.0, 1.0], [2.0, 2.0]])).clip( + 0, 0, 10, 10 + ) + + +def test_polygon_is_empty_validates_threshold_type(): + poly = Polygon(np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])) + with pytest.raises(TypeError, match='expected "int"'): + poly.is_empty(threshold="3") # type: ignore[arg-type] + + +def test_polygons_init_from_polygons_repr_and_eq_edge_cases(): + polygons = Polygons( + [ + np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]), + np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 2.0]]), + ] + ) + polygons2 = Polygons(polygons) + assert polygons2 == polygons + assert "Polygons(" in repr(polygons2) + + assert (polygons == object()) is False + + polygons_small = Polygons([np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])]) + assert (polygons_small == polygons) is False + + +def test_polygons_warns_on_double_normalize_denormalize_and_supports_clip_shift_tolist(): + poly = Polygon( + np.array([[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]), + is_normalized=True, + ) + polygons = Polygons([poly], is_normalized=True) + with pytest.warns(UserWarning, match="forced to do normalization"): + _ = polygons.normalize(10.0, 10.0) + + with pytest.warns(UserWarning, match="forced to do denormalization"): + _ = Polygons([poly], is_normalized=False).denormalize(10.0, 10.0) + + clipped = polygons.clip(0, 0, 1, 1) + shifted = polygons.shift(1.0, -1.0) + assert isinstance(clipped, Polygons) + assert isinstance(shifted, Polygons) + assert clipped.tolist() == clipped.to_list() + + +def test_polygons_from_image_validates_type_and_filters_short_contours( + monkeypatch, +): + import capybara.structures.polygons as poly_mod + + with pytest.raises(TypeError, match=r"np\.ndarray"): + Polygons.from_image("not-an-array") # type: ignore[arg-type] + + def fake_find_contours(image, *, mode, method): + assert isinstance(image, np.ndarray) + assert isinstance(mode, int) + assert isinstance(method, int) + return ( + [ + np.zeros((1, 1, 2), dtype=np.int32), + np.array([[[1, 2]], [[3, 4]]], dtype=np.int32), + ], + None, + ) + + monkeypatch.setattr(poly_mod.cv2, "findContours", fake_find_contours) + polys = Polygons.from_image(np.zeros((10, 10), 
dtype=np.uint8)) + assert len(polys) == 1 + assert polys[0].numpy().shape == (2, 2) + + +def test_polygons_cat_validation_and_happy_path(): + with pytest.raises(TypeError, match="should be a list"): + Polygons.cat("bad") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="is empty"): + Polygons.cat([]) + + with pytest.raises(TypeError, match="must be Polygon"): + Polygons.cat([Polygons([]), "bad"]) # type: ignore[list-item] + + polys1 = Polygons([np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])]) + polys2 = Polygons([np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 2.0]])]) + cat = Polygons.cat([polys1, polys2]) + assert len(cat) == 2 + + +def test_polygon_scale_handles_multipolygon_and_empty_exterior(monkeypatch): + import capybara.structures.polygons as poly_mod + + class _FakeExterior: + def __init__(self, *, xy, is_empty: bool) -> None: + self.xy = xy + self.is_empty = is_empty + + class _FakeMultiPolygon: + def __init__(self, geoms) -> None: + self.geoms = geoms + + class _FakeShapelyPolygon: + def __init__( + self, + arr, + *, + area: float = 0.0, + exterior_empty: bool = False, + xy=None, + ) -> None: + self._area = area + self.exterior = _FakeExterior( + xy=xy or ([], []), is_empty=exterior_empty + ) + self._arr = np.array(arr, dtype=np.float32) + + @property + def area(self) -> float: + return self._area + + def buffer(self, *_args, **_kwargs): + p1 = _FakeShapelyPolygon( + self._arr, + area=1.0, + exterior_empty=False, + xy=([0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]), + ) + p2 = _FakeShapelyPolygon( + self._arr, + area=2.0, + exterior_empty=False, + xy=([0.0, 0.0, 2.0, 2.0], [0.0, 2.0, 2.0, 0.0]), + ) + return _FakeMultiPolygon([p1, p2]) + + monkeypatch.setattr(poly_mod, "_Polygon_shapely", _FakeShapelyPolygon) + monkeypatch.setattr(poly_mod, "MultiPolygon", _FakeMultiPolygon) + + poly = Polygon(np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]])) + scaled = poly.scale(1) + assert scaled.numpy().shape == poly.numpy().shape + + class _FakeShapelyPolygonEmptyExterior(_FakeShapelyPolygon): + def buffer(self, *_args, **_kwargs): + return _FakeShapelyPolygon( + self._arr, + area=1.0, + exterior_empty=True, + xy=([], []), + ) + + monkeypatch.setattr( + poly_mod, "_Polygon_shapely", _FakeShapelyPolygonEmptyExterior + ) + empty_scaled = poly.scale(1) + assert empty_scaled.is_empty() diff --git a/tests/test_hidden_bugs_regressions.py b/tests/test_hidden_bugs_regressions.py new file mode 100644 index 0000000..0374b1a --- /dev/null +++ b/tests/test_hidden_bugs_regressions.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +import numpy as np +import pytest + + +def test_pyproject_onnxruntime_gpu_extra_uses_hyphenated_package_name(): + text = (Path(__file__).resolve().parents[1] / "pyproject.toml").read_text( + encoding="utf-8" + ) + assert "onnxruntime-gpu>=1.22.0,<2" in text + assert "onnxruntime_gpu>=1.22.0,<2" not in text + + +def test_get_files_does_not_return_directories_when_suffix_none(tmp_path): + from capybara.utils.files_utils import get_files + + (tmp_path / "a.txt").write_text("a", encoding="utf-8") + (tmp_path / "sub").mkdir() + (tmp_path / "sub" / "b.txt").write_text("b", encoding="utf-8") + + paths = get_files(tmp_path, suffix=None, recursive=True, sort_path=False) + assert paths + assert all(Path(p).is_file() for p in paths) + + +def test_get_onnx_infos_preserve_symbolic_dim_params(tmp_path): + import onnx + from onnx import TensorProto, helper + + from capybara.onnxengine import utils 
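+
+    # Build a one-node Identity model whose I/O shapes use the symbolic
+    # dimension "batch", then check that both info helpers keep the name.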
+ + input_info = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, ["batch", 3] + ) + output_info = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, ["batch", 3] + ) + node = helper.make_node("Identity", ["input"], ["output"]) + graph = helper.make_graph([node], "sym-graph", [input_info], [output_info]) + model = helper.make_model(graph) + path = tmp_path / "sym.onnx" + onnx.save(model, path) + + inputs = utils.get_onnx_input_infos(path) + outputs = utils.get_onnx_output_infos(path) + + assert inputs["input"]["shape"][0] == "batch" + assert outputs["output"]["shape"][0] == "batch" + + +def test_pad_constant_value_fills_all_channels_for_rgba(): + from capybara.vision.functionals import pad + + img = np.zeros((10, 10, 4), dtype=np.uint8) + out = pad(img, pad_size=1, pad_value=128) + assert out.shape == (12, 12, 4) + assert out.dtype == np.uint8 + assert out[0, 0].tolist() == [128, 128, 128, 128] + + +def test_imrotate_constant_border_fills_all_channels_for_rgba(): + from capybara import BORDER + from capybara.vision.geometric import imrotate + + img = np.zeros((20, 20, 4), dtype=np.uint8) + out = imrotate( + img, + angle=45, + expand=False, + bordertype=BORDER.CONSTANT, + bordervalue=128, + ) + assert out.shape == img.shape + assert out[0, 0].tolist() == [128, 128, 128, 128] + + +def test_imrotate_preserves_dtype_for_float32_inputs(): + from capybara import BORDER + from capybara.vision.geometric import imrotate + + img = np.zeros((20, 20, 3), dtype=np.float32) + out = imrotate( + img, + angle=10, + expand=False, + bordertype=BORDER.CONSTANT, + bordervalue=0, + ) + assert out.dtype == np.float32 + + +def test_visualization_package_import_is_lazy(): + repo_root = Path(__file__).resolve().parents[1] + code = """ +import sys +sys.modules.pop("capybara.vision.visualization.draw", None) +sys.modules.pop("capybara.vision.visualization.utils", None) +sys.modules.pop("capybara.vision.visualization", None) +import capybara.vision.visualization as vis +assert "capybara.vision.visualization.draw" not in sys.modules +assert "capybara.vision.visualization.utils" not in sys.modules +""" + env = {**os.environ, "PYTHONPATH": str(repo_root)} + subprocess.run([sys.executable, "-c", code], env=env, check=True) + + +def test_draw_text_falls_back_when_font_files_missing(monkeypatch, tmp_path): + import capybara.vision.visualization.draw as draw_mod + + monkeypatch.setattr(draw_mod, "DEFAULT_FONT_PATH", tmp_path / "missing.ttf") + + img = np.full((40, 160, 3), 255, dtype=np.uint8) + out = draw_mod.draw_text( + img.copy(), + "hello", + location=(5, 5), + color=(0, 0, 255), + text_size=18, + font_path=tmp_path / "also_missing.ttf", + ) + assert out.shape == img.shape + assert not np.array_equal(out, img) + + +def test_draw_line_handles_zero_length_and_validates_gap(): + import capybara.vision.visualization.draw as draw_mod + + img = np.zeros((40, 40, 3), dtype=np.uint8) + out = draw_mod.draw_line( + img.copy(), + pt1=(10, 10), + pt2=(10, 10), + color=(0, 255, 0), + thickness=2, + style="line", + gap=8, + inplace=False, + ) + assert out.shape == img.shape + assert out.sum() > 0 + + with pytest.raises(ValueError, match="gap must be > 0"): + draw_mod.draw_line(img.copy(), (0, 0), (10, 10), gap=0) + + +def test_draw_mask_minmax_normalize_constant_mask_is_safe(): + import capybara.vision.visualization.draw as draw_mod + + img = np.zeros((20, 30, 3), dtype=np.uint8) + mask = np.full((20, 30), 7, dtype=np.uint8) + + with np.errstate(divide="raise", invalid="raise"): + out = draw_mod.draw_mask(img, 
mask, min_max_normalize=True) + assert out.shape == img.shape diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 062055c..a89d1e1 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,82 +1,75 @@ from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Dict, List +from typing import Any, ClassVar import numpy as np import pytest import capybara as cb -from capybara import DataclassCopyMixin, DataclassToJsonMixin, EnumCheckMixin, dict_to_jsonable - -MockImage = np.zeros((5, 5, 3), dtype="uint8") -base64png_Image = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAIAAAACDbGyAAAADElEQVQIHWNgoC4AAABQAAFhFZyBAAAAAElFTkSuQmCC" -base64npy_Image = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" +from capybara import ( + DataclassCopyMixin, + DataclassToJsonMixin, + EnumCheckMixin, + dict_to_jsonable, +) + +MOCK_IMAGE = np.zeros((5, 5, 3), dtype="uint8") +BASE64_PNG_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAIAAAACDbGyAAAADElEQVQIHWNgoC4AAABQAAFhFZyBAAAAAElFTkSuQmCC" +BASE64_NPY_IMAGE = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" data = [ ( - dict( - box=cb.Box((0, 0, 1, 1)), - boxes=cb.Boxes([(0, 0, 1, 1)]), - keypoints=cb.Keypoints([(0, 1), (1, 0)]), - keypoints_list=cb.KeypointsList([[(0, 1), (1, 0)], [(0, 1), (2, 0)]]), - polygon=cb.Polygon([(0, 0), (1, 0), (1, 1)]), - polygons=cb.Polygons([[(0, 0), (1, 0), (1, 1)]]), - np_bool=np.bool_(True), - np_float=np.float64(1), - np_number=np.array(1), - np_array=np.array([1, 2]), - image=MockImage, - dict=dict(box=cb.Box((0, 0, 1, 1))), - str="test", - int=1, - float=0.6, - tuple=(1, 1), - pow=1e10, - ), - dict( - image=lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG), - ), - dict( - box=[0, 0, 1, 1], - boxes=[[0, 0, 1, 1]], - keypoints=[[0, 1], [1, 0]], - keypoints_list=[[[0, 1], [1, 0]], [[0, 1], [2, 0]]], - polygon=[[0, 0], [1, 0], [1, 1]], - polygons=[[[0, 0], [1, 0], [1, 1]]], - np_bool=True, - np_float=1.0, - np_number=1, - np_array=[1, 2], - image=base64png_Image, - dict=dict(box=[0, 0, 1, 1]), - str="test", - int=1, - float=0.6, - tuple=[1, 1], - pow=1e10, - ), + { + "box": cb.Box((0, 0, 1, 1)), + "boxes": cb.Boxes([(0, 0, 1, 1)]), + "keypoints": cb.Keypoints([(0, 1), (1, 0)]), + "keypoints_list": cb.KeypointsList( + [[(0, 1), (1, 0)], [(0, 1), (2, 0)]] + ), + "polygon": cb.Polygon([(0, 0), (1, 0), (1, 1)]), + "polygons": cb.Polygons([[(0, 0), (1, 0), (1, 1)]]), + "np_bool": np.bool_(True), + "np_float": np.float64(1), + "np_number": np.array(1), + "np_array": np.array([1, 2]), + "image": MOCK_IMAGE, + "dict": {"box": cb.Box((0, 0, 1, 1))}, + "str": "test", + "int": 1, + "float": 0.6, + "tuple": (1, 1), + "pow": 1e10, + }, + {"image": lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG)}, + { + "box": [0, 0, 1, 1], + "boxes": [[0, 0, 1, 1]], + 
"keypoints": [[0, 1], [1, 0]], + "keypoints_list": [[[0, 1], [1, 0]], [[0, 1], [2, 0]]], + "polygon": [[0, 0], [1, 0], [1, 1]], + "polygons": [[[0, 0], [1, 0], [1, 1]]], + "np_bool": True, + "np_float": 1.0, + "np_number": 1, + "np_array": [1, 2], + "image": BASE64_PNG_IMAGE, + "dict": {"box": [0, 0, 1, 1]}, + "str": "test", + "int": 1, + "float": 0.6, + "tuple": [1, 1], + "pow": 1e10, + }, ), ( - dict( - image=MockImage, - ), - dict( - image=lambda x: cb.npy_to_b64str(x), - ), - dict( - image=base64npy_Image, - ), + {"image": MOCK_IMAGE}, + {"image": lambda x: cb.npy_to_b64str(x)}, + {"image": BASE64_NPY_IMAGE}, ), ( - dict( - images=[dict(image=MockImage)], - ), - dict( - image=lambda x: cb.npy_to_b64str(x), - ), - dict( - images=[dict(image=base64npy_Image)], - ), + {"images": [{"image": MOCK_IMAGE}]}, + {"image": lambda x: cb.npy_to_b64str(x)}, + {"images": [{"image": BASE64_NPY_IMAGE}]}, ), ] @@ -87,6 +80,8 @@ def test_dict_to_jsonable(x, jsonable_func, expected): class TestEnum(EnumCheckMixin, Enum): + __test__ = False + FIRST = 1 SECOND = "two" @@ -112,8 +107,10 @@ def test_obj_to_enum_with_invalid_int(self): @dataclass class TestDataclass(DataclassCopyMixin): + __test__ = False + int_field: int - list_field: List[Any] + list_field: list[Any] class TestDataclassCopyMixin: @@ -131,12 +128,19 @@ def test_deep_copy(self, test_dataclass_instance): deepcopy_instance = deepcopy(test_dataclass_instance) assert deepcopy_instance is not test_dataclass_instance assert deepcopy_instance.int_field == test_dataclass_instance.int_field - assert deepcopy_instance.list_field is not test_dataclass_instance.list_field - assert deepcopy_instance.list_field == test_dataclass_instance.list_field + assert ( + deepcopy_instance.list_field + is not test_dataclass_instance.list_field + ) + assert ( + deepcopy_instance.list_field == test_dataclass_instance.list_field + ) @dataclass class TestDataclass2(DataclassToJsonMixin): + __test__ = False + box: cb.Box boxes: cb.Boxes keypoints: cb.Keypoints @@ -157,7 +161,7 @@ class TestDataclass2(DataclassToJsonMixin): py_pow: float # based on mixin - jsonable_func = { + jsonable_func: ClassVar[dict[str, Any]] = { "image": lambda x: cb.img_to_b64str(x, cb.IMGTYP.PNG), "np_array_to_b64str": lambda x: cb.npy_to_b64str(x), } @@ -174,16 +178,18 @@ def test_dataclass_instance(self): box=cb.Box((0, 0, 1, 1)), boxes=cb.Boxes([(0, 0, 1, 1)]), keypoints=cb.Keypoints([(0, 1), (1, 0)]), - keypoints_list=cb.KeypointsList([[(0, 1), (1, 0)], [(0, 1), (2, 0)]]), + keypoints_list=cb.KeypointsList( + [[(0, 1), (1, 0)], [(0, 1), (2, 0)]] + ), polygon=cb.Polygon([(0, 0), (1, 0), (1, 1)]), polygons=cb.Polygons([[(0, 0), (1, 0), (1, 1)]]), np_bool=np.bool_(True), np_float=np.float64(1), np_number=np.array(1), np_array=np.array([1, 2]), - image=MockImage, + image=MOCK_IMAGE, np_array_to_b64str=np_array_to_b64str, - py_dict=dict(box=cb.Box((0, 0, 1, 1))), + py_dict={"box": cb.Box((0, 0, 1, 1))}, py_str="test", py_int=1, py_float=0.6, @@ -194,26 +200,49 @@ def test_dataclass_instance(self): @pytest.fixture def test_expected(self): - return dict( - box=[0, 0, 1, 1], - boxes=[[0, 0, 1, 1]], - keypoints=[[0, 1], [1, 0]], - keypoints_list=[[[0, 1], [1, 0]], [[0, 1], [2, 0]]], - polygon=[[0, 0], [1, 0], [1, 1]], - polygons=[[[0, 0], [1, 0], [1, 1]]], - np_bool=True, - np_float=1.0, - np_number=1, - np_array=[1, 2], - image=base64png_Image, - np_array_to_b64str=np_array_to_b64str_b64, - py_dict=dict(box=[0, 0, 1, 1]), - py_str="test", - py_int=1, - py_float=0.6, - py_tuple=[1, 1], - py_pow=1e10, - 
)
+        return {
+            "box": [0, 0, 1, 1],
+            "boxes": [[0, 0, 1, 1]],
+            "keypoints": [[0, 1], [1, 0]],
+            "keypoints_list": [[[0, 1], [1, 0]], [[0, 1], [2, 0]]],
+            "polygon": [[0, 0], [1, 0], [1, 1]],
+            "polygons": [[[0, 0], [1, 0], [1, 1]]],
+            "np_bool": True,
+            "np_float": 1.0,
+            "np_number": 1,
+            "np_array": [1, 2],
+            "image": BASE64_PNG_IMAGE,
+            "np_array_to_b64str": np_array_to_b64str_b64,
+            "py_dict": {"box": [0, 0, 1, 1]},
+            "py_str": "test",
+            "py_int": 1,
+            "py_float": 0.6,
+            "py_tuple": [1, 1],
+            "py_pow": 1e10,
+        }

     def test_be_jsonable(self, test_dataclass_instance, test_expected):
         assert test_dataclass_instance.be_jsonable() == test_expected
+
+
+def test_dict_to_jsonable_converts_enum_to_name():
+    class LocalEnum(Enum):
+        A = 1
+
+    out = dict_to_jsonable({"e": LocalEnum.A})
+    assert out["e"] == "A"
+
+
+def test_dict_to_jsonable_warns_when_output_is_not_jsonable():
+    with pytest.warns(UserWarning, match="not JSON serializable"):
+        out = dict_to_jsonable({"bad": {1, 2}})
+    assert out["bad"] == {1, 2}
+
+
+def test_dataclass_copy_mixin_requires_dataclass_instance():
+    class NotADataclass(DataclassCopyMixin):
+        def __init__(self) -> None:
+            self.x = 1
+
+    with pytest.raises(TypeError, match="not a dataclass"):
+        NotADataclass().__copy__() diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 0000000..0bd5637 --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,306 @@ +import builtins +import sys +import types +from typing import Any + +import pytest + +from capybara import runtime as runtime_module +from capybara.runtime import Backend, Runtime + + +def test_backend_from_any_respects_runtime_boundaries(): + cuda = Backend.from_any("cuda", runtime="onnx") + assert cuda.name == "cuda" + + with pytest.raises(ValueError): + Backend.from_any("cuda", runtime="openvino") + + +def test_runtime_from_any_accepts_runtime_instances(): + rt = Runtime.onnx + assert Runtime.from_any(rt) is rt + + +def test_runtime_normalize_backend_defaults(): + rt = Runtime.from_any("onnx") + backend = rt.normalize_backend(None) + + assert backend.name == rt.default_backend_name + assert [b.name for b in rt.available_backends()] == list(rt.backend_names) + + +def test_runtime_accepts_backend_instances(): + rt = Runtime.from_any("openvino") + backend = rt.normalize_backend(Backend.ov_gpu) + + assert backend.device == "GPU" + assert backend.runtime == rt.name + + +def test_auto_backend_prefers_cuda_over_tensorrt(monkeypatch): + rt = Runtime.from_any("onnx") + monkeypatch.setattr( + "capybara.runtime._get_available_onnx_providers", + lambda: { + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + }, + ) + + # When both TensorRT and CUDA providers are visible, prefer CUDA by default.
+ assert rt.auto_backend_name() == "cuda" + + +def test_auto_backend_prefers_rtx(monkeypatch): + rt = Runtime.from_any("onnx") + monkeypatch.setattr( + "capybara.runtime._get_available_onnx_providers", + lambda: {"NvTensorRTRTXExecutionProvider"}, + ) + + assert rt.auto_backend_name() == "tensorrt_rtx" + + +def test_auto_backend_falls_back_to_cpu(monkeypatch): + rt = Runtime.from_any("onnx") + monkeypatch.setattr( + "capybara.runtime._get_available_onnx_providers", + lambda: {"CPUExecutionProvider"}, + ) + + assert rt.auto_backend_name() == "cpu" + + +def test_auto_backend_pt_prefers_cuda(monkeypatch): + rt = Runtime.from_any("pt") + monkeypatch.setattr( + runtime_module, + "_get_torch_capabilities", + lambda: (True, True), + ) + assert rt.auto_backend_name() == "cuda" + + +def test_auto_backend_pt_defaults_to_cpu(monkeypatch): + rt = Runtime.from_any("pt") + monkeypatch.setattr( + runtime_module, + "_get_torch_capabilities", + lambda: (False, False), + ) + assert rt.auto_backend_name() == rt.default_backend_name + + +def test_auto_backend_openvino_uses_default_when_no_devices(monkeypatch): + rt = Runtime.from_any("openvino") + + monkeypatch.setattr( + runtime_module, + "_get_openvino_devices", + lambda: set(), + ) + + assert rt.auto_backend_name() == rt.default_backend_name + + +def test_auto_backend_openvino_prefers_gpu(monkeypatch): + rt = Runtime.from_any("openvino") + + monkeypatch.setattr( + runtime_module, + "_get_openvino_devices", + lambda: {"GPU.0", "CPU"}, + ) + + assert rt.auto_backend_name() == "gpu" + + +def test_auto_backend_openvino_prefers_npu(monkeypatch): + rt = Runtime.from_any("openvino") + + monkeypatch.setattr( + runtime_module, + "_get_openvino_devices", + lambda: {"NPU", "CPU"}, + ) + + assert rt.auto_backend_name() == "npu" + + +def test_auto_backend_returns_default_for_unknown_runtime(): + runtime_key = "custom_auto" + backend_name = "alpha" + Backend(name=backend_name, runtime_key=runtime_key) + try: + rt = Runtime( + name=runtime_key, + backend_names=(backend_name,), + default_backend_name=backend_name, + ) + assert rt.auto_backend_name() == rt.default_backend_name + finally: + Runtime._REGISTRY.pop(runtime_key, None) + namespace = Backend._REGISTRY.get(runtime_key, {}) + namespace.pop(backend_name, None) + if not namespace: + Backend._REGISTRY.pop(runtime_key, None) + + +def test_backend_registration_rejects_duplicates(): + runtime_key = "temp_runtime" + name = "temp_backend" + Backend(name=name, runtime_key=runtime_key) + try: + with pytest.raises(ValueError): + Backend(name=name, runtime_key=runtime_key) + finally: + namespace = Backend._REGISTRY.get(runtime_key, {}) + namespace.pop(name, None) + if not namespace: + Backend._REGISTRY.pop(runtime_key, None) + + +def test_backend_from_any_requires_runtime_when_many_registered(): + with pytest.raises(ValueError): + Backend.from_any("cpu") + + +def test_backend_instance_must_match_runtime(): + cuda = Backend.from_any("cuda", runtime="onnx") + with pytest.raises(ValueError): + Backend.from_any(cuda, runtime="openvino") + + +def test_runtime_duplicate_registration_rejected(): + with pytest.raises(ValueError): + Runtime( + name="onnx", + backend_names=("cpu",), + default_backend_name="cpu", + ) + + +def test_runtime_unknown_backend_reference(): + with pytest.raises(ValueError): + Runtime( + name="ghost", + backend_names=("missing",), + default_backend_name="missing", + ) + + +def test_runtime_default_backend_must_be_known(): + runtime_key = "custom_runtime" + backend_name = "alpha" + Backend(name=backend_name, 
runtime_key=runtime_key) + try: + with pytest.raises(ValueError): + Runtime( + name=runtime_key, + backend_names=(backend_name,), + default_backend_name="beta", + ) + finally: + namespace = Backend._REGISTRY.get(runtime_key, {}) + namespace.pop(backend_name, None) + if not namespace: + Backend._REGISTRY.pop(runtime_key, None) + + +def test_get_available_onnx_providers_handles_import_failure(monkeypatch): + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "onnxruntime": + raise ModuleNotFoundError("boom") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert runtime_module._get_available_onnx_providers() == set() + + +def test_get_available_onnx_providers_reads_module(monkeypatch): + module = types.SimpleNamespace( + get_available_providers=lambda: [ + "CUDAExecutionProvider", + "CPUExecutionProvider", + ] + ) + monkeypatch.setitem(sys.modules, "onnxruntime", module) + try: + providers = runtime_module._get_available_onnx_providers() + assert providers == {"CUDAExecutionProvider", "CPUExecutionProvider"} + finally: + monkeypatch.delitem(sys.modules, "onnxruntime", raising=False) + + +def test_get_torch_capabilities_handles_import_failure(monkeypatch): + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "torch": + raise ModuleNotFoundError("boom") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert runtime_module._get_torch_capabilities() == (False, False) + + +def test_get_torch_capabilities_reads_cuda_available(monkeypatch): + fake_torch: Any = types.ModuleType("torch") + + class FakeCuda: + @staticmethod + def is_available() -> bool: + return True + + fake_torch.cuda = FakeCuda + monkeypatch.setitem(sys.modules, "torch", fake_torch) + try: + assert runtime_module._get_torch_capabilities() == (True, True) + finally: + monkeypatch.delitem(sys.modules, "torch", raising=False) + + +def test_get_openvino_devices_handles_import_failure(monkeypatch): + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name.startswith("openvino"): + raise ModuleNotFoundError("boom") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert runtime_module._get_openvino_devices() == set() + + +def test_get_openvino_devices_reads_core_devices(monkeypatch): + fake_ov_runtime: Any = types.ModuleType("openvino.runtime") + + class FakeCore: + available_devices = ("GPU.0", "CPU") + + fake_ov_runtime.Core = FakeCore + + fake_ov: Any = types.ModuleType("openvino") + fake_ov.runtime = fake_ov_runtime + + monkeypatch.setitem(sys.modules, "openvino", fake_ov) + monkeypatch.setitem(sys.modules, "openvino.runtime", fake_ov_runtime) + try: + assert runtime_module._get_openvino_devices() == {"GPU.0", "CPU"} + finally: + monkeypatch.delitem(sys.modules, "openvino.runtime", raising=False) + monkeypatch.delitem(sys.modules, "openvino", raising=False) + + +def test_backend_from_any_infers_runtime_when_single(monkeypatch): + monkeypatch.setattr(Backend, "_REGISTRY", {}) + Backend(name="alpha", runtime_key="solo") + backend = Backend.from_any("alpha") + assert backend.runtime == "solo" diff --git a/tests/test_torchengine.py b/tests/test_torchengine.py new file mode 100644 index 0000000..8472684 --- /dev/null +++ b/tests/test_torchengine.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +import sys +import types +from types import 
SimpleNamespace +from typing import cast + +import numpy as np +import pytest + +from capybara.torchengine import TorchEngine, TorchEngineConfig +from capybara.torchengine import engine as engine_module + + +class _DummyContext: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb): + return False + + +class _FakeDevice: + def __init__(self, spec: str | _FakeDevice = "cpu"): + if isinstance(spec, _FakeDevice): + self.type = spec.type + self._spec = spec._spec + else: + spec_str = str(spec) + self._spec = spec_str + self.type = "cuda" if spec_str.startswith("cuda") else "cpu" + + def __str__(self) -> str: + return self._spec + + +class _FakeTensor: + def __init__(self, array: np.ndarray, dtype, device: _FakeDevice): + self._array = np.asarray(array, dtype=np.float32) + self.dtype = dtype + self.device = device + + def to(self, target): + if isinstance(target, _FakeDevice): + self.device = target + elif isinstance(target, str): + self.device = _FakeDevice(target) + else: + self.dtype = target + return self + + def detach(self): + return self + + def contiguous(self): + return self + + def numpy(self) -> np.ndarray: + return np.array(self._array, copy=True) + + +class _FakeTorchModel: + def __init__(self, torch_ref): + self._torch = torch_ref + self.dtype = torch_ref.float32 + self.device = _FakeDevice("cpu") + self.eval_called = False + self.calls = 0 + + def eval(self): + self.eval_called = True + return self + + def to(self, device=None, dtype=None): + if isinstance(device, _FakeDevice): + self.device = device + if dtype in (self._torch.float16, self._torch.float32): + self.dtype = dtype + return self + + def half(self): + self.dtype = self._torch.float16 + return self + + def float(self): + self.dtype = self._torch.float32 + return self + + def __call__(self, *tensors): + self.calls += 1 + outputs = [] + for idx in range(2): + arr = np.full( + (1, 2, 2, 2), fill_value=self.calls + idx, dtype=np.float32 + ) + outputs.append(_FakeTensor(arr, self.dtype, self.device)) + return tuple(outputs) + + +class _FakeTorchModule: + def __init__(self): + class _FakeDType: + pass + + self.dtype = _FakeDType + self.float16 = _FakeDType() + self.float32 = _FakeDType() + self.cuda = SimpleNamespace( + synchronize=lambda device: self._sync_calls.append(str(device)), + is_available=lambda: True, + ) + self._sync_calls: list[str] = [] + self._loaded_paths: list[str] = [] + self.jit = SimpleNamespace(load=self._load) + + def _load(self, path, map_location=None): + self._loaded_paths.append(str(path)) + model = _FakeTorchModel(self) + model.to(map_location) + return model + + def device(self, spec): + return _FakeDevice(spec) + + def no_grad(self): + return _DummyContext() + + def inference_mode(self): + return _DummyContext() + + def from_numpy(self, array): + return _FakeTensor( + np.asarray(array, dtype=np.float32), + self.float32, + _FakeDevice("cpu"), + ) + + def is_tensor(self, obj): + return isinstance(obj, _FakeTensor) + + +@pytest.fixture +def fake_torch(monkeypatch): + stub = _FakeTorchModule() + monkeypatch.setattr(engine_module, "_lazy_import_torch", lambda: stub) + return stub + + +def test_lazy_import_torch_prefers_sysmodules(monkeypatch): + fake = types.ModuleType("torch") + monkeypatch.setitem(sys.modules, "torch", fake) + assert engine_module._lazy_import_torch() is fake + + +def test_torch_engine_formats_outputs(fake_torch, tmp_path): + model_path = tmp_path / "fake.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine( + model_path, + device="cuda:1", + 
output_names=("feat_s8", "feat_s16"), + ) + inputs = { + "image": np.zeros((1, 3, 4, 4), dtype=np.float32), + } + outputs = engine.run(inputs) + + assert set(outputs.keys()) == {"feat_s8", "feat_s16"} + for value in outputs.values(): + assert value.dtype == np.float32 + assert value.shape == (1, 2, 2, 2) + + # Ensure TorchEngine selected the CUDA device and casted dtype. + assert engine.device.type == "cuda" + assert engine.dtype == fake_torch.float32 + assert fake_torch._loaded_paths == [str(model_path)] + + +def test_torch_engine_call_accepts_kwargs_feed(fake_torch, tmp_path): + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + + engine = TorchEngine( + model_path, + device="cpu", + output_names=("feat_s8", "feat_s16"), + ) + outputs = engine(image=np.zeros((1, 3, 4, 4), dtype=np.float32)) + + assert set(outputs.keys()) == {"feat_s8", "feat_s16"} + + +def test_torch_engine_auto_dtype_selects_fp16_when_name_and_cuda( + fake_torch, tmp_path +): + """Auto dtype picks fp16 for '*fp16*' model names when running on CUDA.""" + model_path = tmp_path / "demo_fp16.pt" + model_path.write_bytes(b"torchscript") + + engine = TorchEngine(model_path, device="cuda:0") + + assert engine.device.type == "cuda" + assert engine.dtype == fake_torch.float16 + assert engine._model.dtype == fake_torch.float16 + + +def test_torch_engine_explicit_fp32_dtype(fake_torch, tmp_path): + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + + config = TorchEngineConfig(dtype="fp32") + engine = TorchEngine(model_path, device="cpu", config=config) + + assert engine.dtype == fake_torch.float32 + assert engine._model.dtype == fake_torch.float32 + + +def test_torch_engine_prepare_feed_requires_mapping(fake_torch, tmp_path): + """Run rejects non-mapping feeds to avoid accidental positional ordering bugs.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + with pytest.raises(TypeError, match="feed must be a mapping"): + engine.run(["not", "a", "mapping"]) # type: ignore[arg-type] + + +def test_torch_engine_output_names_mismatch_raises(fake_torch, tmp_path): + """Output key schema must match model outputs to prevent silent mislabeling.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu", output_names=("only_one",)) + + with pytest.raises(ValueError, match="model produced 2 outputs"): + engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}) + + +def test_torch_engine_formats_mapping_and_tensor_outputs(fake_torch, tmp_path): + """TorchEngine supports dict outputs and single tensor outputs.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cuda:0") + + def _mapping_model(*_tensors): + return { + "feat": _FakeTensor( + np.ones((1, 2, 2, 2)), + fake_torch.float16, + cast(_FakeDevice, engine.device), + ) + } + + engine._model = _mapping_model # type: ignore[assignment] + out = engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}) + assert set(out) == {"feat"} + assert out["feat"].dtype == np.float32 + + def _tensor_model(*_tensors): + return _FakeTensor( + np.ones((1, 2, 2, 2)), + fake_torch.float16, + cast(_FakeDevice, engine.device), + ) + + engine._model = _tensor_model # type: ignore[assignment] + out = engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}) + assert set(out) == {"output"} + + +def 
test_torch_engine_benchmark_honors_cuda_sync_override(fake_torch, tmp_path): + """Benchmark uses synchronize() only when CUDA + cuda_sync is enabled.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cuda:1") + + stats = engine.benchmark( + {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}, + repeat=2, + warmup=1, + cuda_sync=True, + ) + + assert stats["repeat"] == 2 + assert stats["warmup"] == 1 + assert "latency_ms" in stats + assert len(fake_torch._sync_calls) == 5 + + +def test_torch_engine_benchmark_validates_repeat_and_warmup( + fake_torch, tmp_path +): + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + feed = {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)} + + with pytest.raises(ValueError, match="repeat must be >= 1"): + engine.benchmark(feed, repeat=0) + + with pytest.raises(ValueError, match="warmup must be >= 0"): + engine.benchmark(feed, warmup=-1) + + +def test_torch_engine_call_accepts_wrapped_mapping_and_generates_names( + fake_torch, tmp_path +): + """__call__ supports passing a mapping payload and auto-generating output names.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + outputs = engine( + payload={"image": np.zeros((1, 3, 4, 4), dtype=np.float32)} + ) + assert set(outputs) == {"output_0", "output_1"} + + +def test_torch_engine_multiple_inputs_use_positional_forward( + fake_torch, tmp_path +): + """Multiple inputs are forwarded positionally into the TorchScript model.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + outputs = engine.run( + { + "a": np.zeros((1, 3, 4, 4), dtype=np.float32), + "b": np.zeros((1, 3, 4, 4), dtype=np.float32), + } + ) + assert set(outputs) == {"output_0", "output_1"} + + +def test_torch_engine_summary_reports_core_fields(fake_torch, tmp_path): + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + summary = engine.summary() + assert summary["model"] == str(model_path) + assert summary["device"] == "cpu" + + +def test_torch_engine_benchmark_respects_config_cuda_sync_default( + fake_torch, tmp_path +): + """cuda_sync defaults to config when override is omitted.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine( + model_path, + device="cuda:0", + config=TorchEngineConfig(cuda_sync=False), + ) + + engine.benchmark( + {"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}, + repeat=1, + warmup=0, + ) + assert fake_torch._sync_calls == [] + + +def test_torch_engine_accepts_preconstructed_tensors( + fake_torch, monkeypatch, tmp_path +): + """Existing torch tensors should pass through without from_numpy conversion.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + tensor = fake_torch.from_numpy(np.zeros((1, 3, 4, 4), dtype=np.float32)) + monkeypatch.setattr( + fake_torch, + "from_numpy", + lambda *_args, **_kwargs: pytest.fail( + "from_numpy should not be called" + ), + ) + + outputs = engine.run({"image": tensor}) + assert set(outputs) == {"output_0", "output_1"} + + +def test_torch_engine_device_instance_short_circuits_normalization( + fake_torch, monkeypatch, tmp_path +): + """Passing an existing torch.device object should 
be preserved.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + monkeypatch.setattr(fake_torch, "device", _FakeDevice) + + device = _FakeDevice("cuda:0") + engine = TorchEngine(model_path, device=device) + assert engine.device is device + + +def test_torch_engine_dtype_string_and_custom_dtype_handling( + fake_torch, monkeypatch, tmp_path +): + """dtype supports explicit strings, custom torch.dtype instances, and errors.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + + fake_dtype_type = type("dtype", (), {}) + monkeypatch.setattr(fake_torch, "dtype", fake_dtype_type, raising=False) + + engine = TorchEngine( + model_path, + device="cpu", + config=TorchEngineConfig(dtype="fp16"), + ) + assert engine.dtype == fake_torch.float16 + + custom_dtype = fake_dtype_type() + engine = TorchEngine( + model_path, + device="cpu", + config=TorchEngineConfig(dtype=custom_dtype), + ) + assert engine.dtype is custom_dtype + + with pytest.raises(ValueError, match="Unsupported dtype specification"): + TorchEngine( + model_path, + device="cpu", + config=TorchEngineConfig(dtype="weird"), + ) + + +def test_torch_engine_rejects_unsupported_outputs(fake_torch, tmp_path): + """Unexpected model outputs should raise a clear error.""" + model_path = tmp_path / "demo.pt" + model_path.write_bytes(b"torchscript") + engine = TorchEngine(model_path, device="cpu") + + engine._model = lambda *_args: 123 # type: ignore[assignment] + with pytest.raises(TypeError, match="Unsupported TorchScript output"): + engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}) + + engine._model = lambda *_args: {"feat": "bad"} # type: ignore[assignment] + with pytest.raises(TypeError, match=r"Model outputs must be torch\.Tensor"): + engine.run({"image": np.zeros((1, 3, 4, 4), dtype=np.float32)}) diff --git a/tests/utils/test_custom_path.py b/tests/utils/test_custom_path.py new file mode 100644 index 0000000..1ed8212 --- /dev/null +++ b/tests/utils/test_custom_path.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from capybara.utils.custom_path import rm_path + + +def test_rm_path_removes_non_empty_directories(tmp_path): + folder = tmp_path / "nested" + folder.mkdir() + (folder / "file.txt").write_text("hello", encoding="utf-8") + + rm_path(folder) + assert not folder.exists() + + +def test_rm_path_removes_files(tmp_path): + file_path = tmp_path / "file.txt" + file_path.write_text("hello", encoding="utf-8") + + rm_path(file_path) + assert not file_path.exists() diff --git a/tests/utils/test_custom_tqdm.py b/tests/utils/test_custom_tqdm.py new file mode 100644 index 0000000..ea30b9c --- /dev/null +++ b/tests/utils/test_custom_tqdm.py @@ -0,0 +1,15 @@ +from capybara.utils.custom_tqdm import Tqdm + + +def test_custom_tqdm_respects_explicit_total_and_infers_total(): + bar = Tqdm(range(3), total=10, disable=True) + try: + assert bar.total == 10 + finally: + bar.close() + + inferred = Tqdm([1, 2, 3], disable=True) + try: + assert inferred.total == 3 + finally: + inferred.close() diff --git a/tests/utils/test_download_from_google.py b/tests/utils/test_download_from_google.py new file mode 100644 index 0000000..7b424c8 --- /dev/null +++ b/tests/utils/test_download_from_google.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +import pytest + +import capybara.utils.utils as utils_mod + + +class _FakeResponse: + def __init__( + self, + *, + headers: dict[str, str] | None = 
None, + cookies: dict[str, str] | None = None, + text: str = "", + chunks: list[bytes] | None = None, + iter_raises: Exception | None = None, + ) -> None: + self.headers = headers or {} + self.cookies = cookies or {} + self.text = text + self._chunks = chunks or [] + self._iter_raises = iter_raises + + def iter_content(self, *, chunk_size: int) -> Iterator[bytes]: + if self._iter_raises is not None: + raise self._iter_raises + yield from self._chunks + + +class _FakeSession: + def __init__(self, responses: list[_FakeResponse]) -> None: + self._responses = list(responses) + self.calls: list[tuple[str, dict[str, Any] | None]] = [] + + def get(self, url: str, params: dict[str, Any] | None = None, stream=True): + self.calls.append((url, params)) + if not self._responses: + raise AssertionError("No more fake responses configured") + return self._responses.pop(0) + + +class _DummyTqdm: + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.total = kwargs.get("total", 0) + self.updated = 0 + + def update(self, n: int) -> None: + self.updated += n + + def __enter__(self) -> _DummyTqdm: + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +def _install_fakes(monkeypatch, session: _FakeSession) -> None: + monkeypatch.setattr(utils_mod.requests, "Session", lambda: session) + monkeypatch.setattr(utils_mod, "tqdm", _DummyTqdm) + + +def test_download_from_google_direct_content_disposition(tmp_path, monkeypatch): + session = _FakeSession( + [ + _FakeResponse( + headers={ + "content-disposition": "attachment; filename=x.bin", + "content-length": "3", + }, + chunks=[b"abc"], + ) + ] + ) + _install_fakes(monkeypatch, session) + + out = utils_mod.download_from_google( + file_id="id", + file_name="x.bin", + target=tmp_path, + ) + assert out == Path(tmp_path) / "x.bin" + assert out.read_bytes() == b"abc" + assert session.calls[0][0] == "https://docs.google.com/uc" + + +def test_download_from_google_uses_cookie_confirm_token(tmp_path, monkeypatch): + session = _FakeSession( + [ + _FakeResponse(cookies={"download_warning_foo": "TOKEN"}, text=""), + _FakeResponse( + headers={ + "content-disposition": "attachment", + "content-length": "1", + }, + chunks=[b"x"], + ), + ] + ) + _install_fakes(monkeypatch, session) + + out = utils_mod.download_from_google("id", "y.bin", target=tmp_path) + assert out.read_bytes() == b"x" + assert session.calls[1][1] is not None + assert session.calls[1][1]["confirm"] == "TOKEN" + + +def test_download_from_google_parses_html_download_form(tmp_path, monkeypatch): + html = """ + +
    <html><body> +    <form id="download-form" action="https://example.com/download" method="get"> +    <input type="hidden" name="confirm" value="t"> +    </form> +    </body></html>
+ + """ + session = _FakeSession( + [ + _FakeResponse(text=html), + _FakeResponse( + headers={ + "content-disposition": "attachment", + "content-length": "2", + }, + chunks=[b"hi"], + ), + ] + ) + _install_fakes(monkeypatch, session) + + out = utils_mod.download_from_google("id", "z.bin", target=tmp_path) + assert out.read_bytes() == b"hi" + assert session.calls[1][0] == "https://example.com/download" + + +def test_download_from_google_parses_confirm_param(tmp_path, monkeypatch): + session = _FakeSession( + [ + _FakeResponse(text="... confirm=ABC123 ..."), + _FakeResponse( + headers={ + "content-disposition": "attachment", + "content-length": "1", + }, + chunks=[b"1"], + ), + ] + ) + _install_fakes(monkeypatch, session) + + out = utils_mod.download_from_google("id", "c.bin", target=tmp_path) + assert out.read_bytes() == b"1" + assert session.calls[1][1] is not None + assert session.calls[1][1]["confirm"] == "ABC123" + + +def test_download_from_google_raises_when_no_link_found(tmp_path, monkeypatch): + session = _FakeSession([_FakeResponse(text="")]) + _install_fakes(monkeypatch, session) + + with pytest.raises(Exception, match="無法在回應中找到下載連結"): + utils_mod.download_from_google("id", "x.bin", target=tmp_path) + + +def test_download_from_google_wraps_streaming_errors(tmp_path, monkeypatch): + session = _FakeSession( + [ + _FakeResponse( + headers={ + "content-disposition": "attachment", + "content-length": "1", + }, + iter_raises=ValueError("boom"), + ) + ] + ) + _install_fakes(monkeypatch, session) + + with pytest.raises(RuntimeError, match="File download failed"): + utils_mod.download_from_google("id", "x.bin", target=tmp_path) diff --git a/tests/utils/test_files_utils.py b/tests/utils/test_files_utils.py new file mode 100644 index 0000000..dfaf0c6 --- /dev/null +++ b/tests/utils/test_files_utils.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path + +import numpy as np +import pytest + +from capybara.utils import custom_path +from capybara.utils import files_utils as futils + + +def test_rm_path_removes_file_and_directory(tmp_path: Path): + f = tmp_path / "a.txt" + f.write_text("x", encoding="utf-8") + assert f.exists() + custom_path.rm_path(f) + assert not f.exists() + + d = tmp_path / "empty_dir" + d.mkdir() + assert d.exists() + custom_path.rm_path(d) + assert not d.exists() + + +def test_copy_path_copies_and_validates_source(tmp_path: Path): + src = tmp_path / "src.txt" + src.write_text("hello", encoding="utf-8") + dst = tmp_path / "dst.txt" + custom_path.copy_path(src, dst) + assert dst.read_text(encoding="utf-8") == "hello" + + with pytest.raises(ValueError, match="invaild"): + custom_path.copy_path(tmp_path / "missing.txt", dst) + + +def test_gen_md5_and_img_to_md5(tmp_path: Path): + p = tmp_path / "blob.bin" + payload = b"capybara" + p.write_bytes(payload) + + assert futils.gen_md5(p) == hashlib.md5(payload).hexdigest() + + img = np.arange(12, dtype=np.uint8).reshape(3, 4) + assert futils.img_to_md5(img) == hashlib.md5(img.tobytes()).hexdigest() + + with pytest.raises(TypeError, match="numpy array"): + futils.img_to_md5("not-an-array") # type: ignore[arg-type] + + +def test_dump_and_load_json_yaml_and_pickle(tmp_path: Path, monkeypatch): + obj = {"a": 1, "b": [1, 2, 3]} + + json_path = tmp_path / "x.json" + futils.dump_json(obj, json_path) + assert futils.load_json(json_path) == obj + + yaml_path = tmp_path / "x.yaml" + futils.dump_yaml(obj, yaml_path) + assert futils.load_yaml(yaml_path) == obj + + pkl_path = tmp_path / "x.pkl" + 
futils.dump_pickle(obj, pkl_path) + assert futils.load_pickle(pkl_path) == obj + + # Default-path behavior should not pollute repo root. + monkeypatch.chdir(tmp_path) + futils.dump_json(obj, path=None) + assert (tmp_path / "tmp.json").exists() + futils.dump_yaml(obj, path=None) + assert (tmp_path / "tmp.yaml").exists() + + +def test_get_files_filters_suffix_and_options(tmp_path: Path): + # Create a few files with varying cases. + (tmp_path / "a.txt").write_text("a", encoding="utf-8") + (tmp_path / "b.TXT").write_text("b", encoding="utf-8") + (tmp_path / "c.jpg").write_text("c", encoding="utf-8") + sub = tmp_path / "sub" + sub.mkdir() + (sub / "d.txt").write_text("d", encoding="utf-8") + + with pytest.raises(FileNotFoundError): + futils.get_files(tmp_path / "missing") + + with pytest.raises(TypeError, match="suffix must be"): + futils.get_files(tmp_path, suffix=123) # type: ignore[arg-type] + + files = futils.get_files(tmp_path, suffix=".txt", recursive=True) + assert all(Path(p).suffix.lower() == ".txt" for p in files) + assert len(files) == 3 + + files_no_case = futils.get_files( + tmp_path, + suffix=[".txt"], + recursive=False, + ignore_letter_case=False, + sort_path=False, + return_pathlib=False, + ) + assert all(isinstance(p, str) for p in files_no_case) + assert len(files_no_case) == 1 diff --git a/tests/utils/test_powerdict.py b/tests/utils/test_powerdict.py index ae3679d..b27f55a 100644 --- a/tests/utils/test_powerdict.py +++ b/tests/utils/test_powerdict.py @@ -1,161 +1,172 @@ +from typing import Any, cast + import pytest from capybara import PowerDict -power_dict_attr_data = [ - ( - { - 'key': 'value', - 'name': 'mock_name' - }, - ['key', 'name'] - ) -] +def test_powerdict_init_accepts_none_and_kwargs(): + assert PowerDict() == {} + assert PowerDict(None) == {} + pd = PowerDict(None, new="value") + assert pd == {"new": "value"} + assert pd.new == "value" -@pytest.mark.parametrize('x, match', power_dict_attr_data) -def test_power_dict_attr(x, match): - test_dict = PowerDict(x) - for attr in match: - assert hasattr(test_dict, attr) + pd = PowerDict({"a": 1}, b=2) + assert pd == {"a": 1, "b": 2} + assert pd.a == 1 + assert pd.b == 2 -power_dict_freeze_melt_data = [ - ( - { - 'key': 'value', - }, - ['key'] - ), - ( - { - 'PowerDict': PowerDict({'A': 1}), - }, - ['PowerDict'] - ), - ( - { - 'list': [1, 2, 3], - 'tuple': (1, 2, 3) - }, - ['list', 'tuple'] - ), - ( +def test_powerdict_attribute_and_item_access_are_kept_in_sync(): + pd = PowerDict() + pd.alpha = 1 + assert pd["alpha"] == 1 + + pd["beta"] = 2 + assert pd.beta == 2 + + del pd["alpha"] + assert "alpha" not in pd + assert not hasattr(pd, "alpha") + + pd.gamma = 3 + del pd.gamma + assert "gamma" not in pd + + +def test_powerdict_recursively_wraps_nested_mappings_and_sequences(): + pd = PowerDict( { - 'PowerDict_in_list': [PowerDict({'A': 1}), PowerDict({'B': 2})], - 'PowerDict_in_tuple': (PowerDict({'A': 1}), PowerDict({'B': 2})) - }, - ['PowerDict_in_list', 'PowerDict_in_tuple'] + "cfg": {"x": 1}, + "items": [{"y": 2}, {"z": 3}], + "numbers_tuple": (1, 2, 3), + } ) -] + assert isinstance(pd.cfg, PowerDict) + assert pd.cfg.x == 1 + + assert isinstance(pd.items, list) + assert [type(x) for x in pd.items] == [PowerDict, PowerDict] + assert pd.items[0].y == 2 + assert pd.items[1].z == 3 + + # Tuples are normalized to lists to keep internal mutation rules simple. 
+ assert pd.numbers_tuple == [1, 2, 3] + + +def test_powerdict_freeze_blocks_mutation_and_melt_restores(): + pd = PowerDict({"a": 1, "nested": {"b": 2}}) + pd.freeze() + + with pytest.raises(ValueError, match="PowerDict is frozen"): + pd.a = 10 + with pytest.raises(ValueError, match="PowerDict is frozen"): + pd["a"] = 10 + with pytest.raises(ValueError, match="PowerDict is frozen"): + del pd.a + with pytest.raises(ValueError, match="PowerDict is frozen"): + del pd["a"] + with pytest.raises(ValueError, match="PowerDict is frozen"): + pd.update({"c": 3}) + with pytest.raises(ValueError, match="PowerDict is frozen"): + pd.pop("a") + + # Nested PowerDict is also frozen. + with pytest.raises(ValueError, match="PowerDict is frozen"): + pd.nested.b = 20 + + pd.melt() + pd.a = 10 + pd.update({"c": 3}) + assert pd == {"a": 10, "nested": {"b": 2}, "c": 3} + + +def test_powerdict_pop_behaves_like_dict_pop(): + pd = PowerDict({"a": 1}) + assert pd.pop("a") == 1 + assert "a" not in pd -@pytest.mark.parametrize('x, match', power_dict_freeze_melt_data) -def test_power_dict_freeze_melt(x, match): - test_dict = PowerDict(x) - test_dict.freeze() - for attr in match: - try: - test_dict[attr] = None - assert False - except ValueError: - pass + assert pd.pop("missing", None) is None - test_dict.melt() - for attr in match: - test_dict[attr] = None + with pytest.raises(KeyError): + pd.pop("missing") -power_dict_init_data = [ - { - 'd': None, - 'kwargs': None, - 'match': {} - }, - { - 'd': None, - 'kwargs': {'new': 'update'}, - 'match': {'kwargs': {'new': 'update'}} - } -] +def test_powerdict_reserved_frozen_key_behaviors(): + pd = PowerDict() + with pytest.raises(KeyError, match="_frozen"): + pd["_frozen"] = True + with pytest.raises(KeyError, match="_frozen"): + del pd["_frozen"] + with pytest.raises(KeyError, match="_frozen"): + del pd._frozen -@pytest.mark.parametrize('test_data', power_dict_init_data) -def test_dict_init(test_data: dict): - if test_data['kwargs'] is not None: - assert PowerDict(d=test_data['d'], kwargs=test_data['kwargs']) == test_data['match'] - else: - assert PowerDict(d=test_data['d']) == test_data['match'] +def test_powerdict_missing_attr_raises_attribute_error(): + pd = PowerDict() + with pytest.raises(AttributeError): + _ = pd.missing -power_dict_set_data = [ - ({'int': 1}, 'int', 2, {'int': 2}), - ({'list': [1, 2, 3]}, 'list', [4, 5, 6], {'list': [4, 5, 6]}), - ({'tuple': (1, 2, 3)}, 'tuple', (7, 8, 9), {'tuple': [7, 8, 9]}) -] +def test_powerdict_serialization_helpers(tmp_path): + pd = PowerDict({"a": 1, "nested": {"b": 2}}) + json_path = tmp_path / "x.json" + yaml_path = tmp_path / "x.yaml" + pkl_path = tmp_path / "x.pkl" + txt_path = tmp_path / "x.txt" -@pytest.mark.parametrize('x, new_key, new_value, match', power_dict_set_data) -def test_power_dict_set(x, new_key, new_value, match): - test_dict = PowerDict(x) - test_dict[new_key] = new_value - assert test_dict == match + assert pd.to_json(json_path) is None + assert json_path.exists() + assert PowerDict.load_json(json_path) == pd + assert pd.to_yaml(yaml_path) is None + assert yaml_path.exists() + assert PowerDict.load_yaml(yaml_path) == pd -power_dict_set_raises_data = [ - ({'int': 1}, 'int', 2, ValueError, "PowerDict is frozen. 'int' cannot be set."), - ({'list': [1, 2, 3]}, 'list', [3, 4], ValueError, "PowerDict is frozen. 'list' cannot be set."), - ({'tuple': (1, 2, 3)}, 'tuple', (3, 4), ValueError, "PowerDict is frozen. 
'tuple' cannot be set.") -] + assert pd.to_pickle(pkl_path) is None + assert pkl_path.exists() + assert PowerDict.load_pickle(pkl_path) == pd + pd.to_txt(txt_path) + assert "nested" in txt_path.read_text(encoding="utf-8") -@pytest.mark.parametrize('x, new_key, new_value, error, match', power_dict_set_raises_data) -def test_power_dict_set_raises(x, new_key, new_value, error, match): - test_dict = PowerDict(x) - test_dict.freeze() - with pytest.raises(error, match=match): - test_dict[new_key] = new_value +def test_powerdict_deepcopy_is_blocked_when_frozen(): + from copy import deepcopy -power_dict_update_data = [ - ({'a': 1, 'b': 2}, {'b': 4, 'c': 3}, {'a': 1, 'b': 4, 'c': 3}), - ({'a': 1, 'b': 2}, {'c': [1, 2]}, {'a': 1, 'b': 2, 'c': [1, 2]}), - ({'a': 1, 'b': 2}, {'c': (1, 2)}, {'a': 1, 'b': 2, 'c': [1, 2]}) -] + pd = PowerDict({"a": 1}) + pd.freeze() + with pytest.raises(Warning, match="cannot be copy"): + _ = deepcopy(pd) -@pytest.mark.parametrize('x, e, match', power_dict_update_data) -def test_power_dict_update(x, e, match): - test_dict = PowerDict(x) - test_dict.update(e) - assert test_dict == match +def test_powerdict_update_without_args_and_deepcopy_when_unfrozen(): + from copy import deepcopy + pd = PowerDict({"a": 1}) + pd.update() + assert pd == {"a": 1} -power_dict_pop_data = [ - ({'a': 1, 'b': 2, 'c': 3}, 'b', {'a': 1, 'c': 3}), - ({'a': 1, 'b': 2, 'c': [1, 2]}, 'b', {'a': 1, 'c': [1, 2]}), - ({'a': 1, 'b': 2, 'c': (1, 2)}, 'b', {'a': 1, 'c': [1, 2]}), -] + clone = deepcopy(pd) + assert clone == pd + assert clone is not pd -@pytest.mark.parametrize('x, key, match', power_dict_pop_data) -def test_power_dict_pop(x, key, match): - test_dict = PowerDict(x) - test_dict.pop(key) - assert test_dict == match +def test_powerdict_freeze_and_to_dict_handle_powerdict_inside_lists(): + pd = PowerDict({"items": [{"x": 1}, {"y": 2}]}) + pd.freeze() + with pytest.raises(ValueError, match="PowerDict is frozen"): + items = cast(list[Any], pd["items"]) + cast(Any, items[0]).x = 10 + pd.melt() + items = cast(list[Any], pd["items"]) + cast(Any, items[0]).x = 10 + assert cast(Any, items[0]).x == 10 -test_power_dict_del_raises = [ - ({'a': 1, 'b': 2, 'c': 3}, 'b', ValueError, "PowerDict is frozen. 'b' cannot be del."), - ({'a': 1, 'b': 2, 'c': [1, 2]}, 'b', ValueError, "PowerDict is frozen. 'b' cannot be del."), - ({'a': 1, 'b': 2, 'c': (1, 2)}, 'b', ValueError, "PowerDict is frozen. 
'b' cannot be del."), -] - - -@pytest.mark.parametrize('x, key, error, match', test_power_dict_del_raises) -def test_power_dict_del_raises(x, key, error, match): - test_dict = PowerDict(x) - test_dict.freeze() - with pytest.raises(error, match=match): - del test_dict[key] + out = pd.to_dict() + assert out == {"items": [{"x": 10}, {"y": 2}]} diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py new file mode 100644 index 0000000..71a5047 --- /dev/null +++ b/tests/utils/test_time.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import time as py_time +from datetime import datetime, timezone + +import numpy as np +import pytest + +import capybara.utils.time as time_mod + + +def test_timer_requires_tic_before_toc(): + timer = time_mod.Timer() + with pytest.raises(ValueError, match="has not been started"): + timer.toc() + + +def test_timer_records_and_stats(monkeypatch, capsys): + perf = iter([10.0, 10.5, 20.0, 20.25]) + + monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf)) + timer = time_mod.Timer(precision=2, desc="bench", verbose=True) + + timer.tic() + dt1 = timer.toc(verbose=True) + assert dt1 == 0.5 + + timer.tic() + dt2 = timer.toc() + assert dt2 == 0.25 + + assert timer.mean == np.array([0.5, 0.25]).mean().round(2) + assert timer.max == 0.5 + assert timer.min == 0.25 + assert timer.std == np.array([0.5, 0.25]).std().round(2) + + out = capsys.readouterr().out + assert "bench" in out + assert "Cost:" in out + + timer.clear_record() + assert timer.mean is None + + +def test_timer_as_context_manager(monkeypatch): + perf = iter([1.0, 1.2]) + monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf)) + timer = time_mod.Timer(precision=3) + with timer: + pass + + assert timer.dt == 0.2 + + +def test_timer_context_manager_returns_self(monkeypatch): + perf = iter([1.0, 1.1]) + monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf)) + + with time_mod.Timer() as timer: + assert isinstance(timer, time_mod.Timer) + + +def test_timer_as_decorator(monkeypatch): + perf = iter([5.0, 5.1]) + monkeypatch.setattr(time_mod.time, "perf_counter", lambda: next(perf)) + + timer = time_mod.Timer() + + @timer + def add1(x: int) -> int: + return x + 1 + + assert add1(2) == 3 + assert timer.mean == 0.1 + + +def test_now_supports_timestamp_datetime_time_and_fmt(monkeypatch): + monkeypatch.setattr(time_mod.time, "time", lambda: 123.0) + assert time_mod.now("timestamp") == 123.0 + + assert isinstance(time_mod.now("datetime"), datetime) + assert isinstance(time_mod.now("time"), py_time.struct_time) + + # fmt takes precedence and returns a string. 
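+    # e.g. now("timestamp", fmt="%Y") should also yield the formatted string, +    # since fmt overrides the requested return type.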
+ monkeypatch.setattr(time_mod.time, "localtime", py_time.gmtime) + assert time_mod.now(fmt="%Y-%m-%d") == "1970-01-01" + + with pytest.raises(ValueError, match="Unsupported input"): + time_mod.now("invalid") + + +def test_time_and_datetime_converters_validate_types(monkeypatch): + fixed = py_time.gmtime(0) + + assert time_mod.time2datetime(fixed).year == 1970 + assert isinstance(time_mod.timestamp2datetime(0), datetime) + + with pytest.raises(TypeError): + time_mod.time2datetime("not-a-time") # type: ignore[arg-type] + + monkeypatch.setattr(time_mod.time, "mktime", lambda _: 42.0) + assert time_mod.time2timestamp(fixed) == 42.0 + with pytest.raises(TypeError): + time_mod.time2timestamp("not-a-time") # type: ignore[arg-type] + + assert time_mod.time2str(fixed, "%Y") == "1970" + with pytest.raises(TypeError): + time_mod.time2str("not-a-time", "%Y") # type: ignore[arg-type] + + dt = datetime(2000, 1, 1, tzinfo=timezone.utc) + assert isinstance(time_mod.datetime2time(dt), py_time.struct_time) + assert time_mod.datetime2timestamp(dt) == 946684800.0 + assert time_mod.datetime2str(dt, "%Y") == "2000" + + with pytest.raises(TypeError): + time_mod.datetime2time("not-a-dt") # type: ignore[arg-type] + with pytest.raises(TypeError): + time_mod.datetime2timestamp("not-a-dt") # type: ignore[arg-type] + with pytest.raises(TypeError): + time_mod.datetime2str("not-a-dt", "%Y") # type: ignore[arg-type] + + +def test_str_converters_validate_types(monkeypatch): + monkeypatch.setattr(time_mod.time, "mktime", lambda _: 99.0) + assert time_mod.str2time("1970-01-01", "%Y-%m-%d").tm_year == 1970 + assert time_mod.str2datetime("1970-01-01", "%Y-%m-%d").year == 1970 + assert time_mod.str2timestamp("1970-01-01", "%Y-%m-%d") == 99.0 + + with pytest.raises(TypeError): + time_mod.str2time(123, "%Y") # type: ignore[arg-type] + with pytest.raises(TypeError): + time_mod.str2datetime(123, "%Y") # type: ignore[arg-type] + with pytest.raises(TypeError): + time_mod.str2timestamp(123, "%Y") # type: ignore[arg-type] diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 88039ac..c13783d 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,34 +1,20 @@ from capybara import COLORSTR, FORMATSTR, colorstr, make_batch -def number_generator(start, end): - for i in range(start, end + 1): - yield i +def number_generator(start: int, end: int): + yield from range(start, end + 1) def test_make_batch(): - # Create a number generator from 1 to 10 data_generator = number_generator(1, 10) - - # Batch size of 3 batch_size = 3 - - # Generate batched data batched_data_generator = make_batch(data_generator, batch_size) - - # Check the first batch batch = next(batched_data_generator) assert batch == [1, 2, 3] - - # Check the second batch batch = next(batched_data_generator) assert batch == [4, 5, 6] - - # Check the third batch batch = next(batched_data_generator) assert batch == [7, 8, 9] - - # Check the fourth batch (last batch with remaining data) batch = next(batched_data_generator) assert batch == [10] @@ -36,19 +22,22 @@ def test_make_batch(): def test_colorstr_blue_bold(): obj = "Hello, colorful world!" expected_output = "\033[1;34mHello, colorful world!\033[0m" - assert colorstr(obj, color=COLORSTR.BLUE, fmt=FORMATSTR.BOLD) == expected_output + assert ( + colorstr(obj, color=COLORSTR.BLUE, fmt=FORMATSTR.BOLD) + == expected_output + ) def test_colorstr_red(): obj = "Error: Something went wrong!" 
expected_output = "\033[1;31mError: Something went wrong!\033[0m" - assert colorstr(obj, color='red') == expected_output + assert colorstr(obj, color="red") == expected_output def test_colorstr_underline_green(): obj = "Important message" expected_output = "\033[4;32mImportant message\033[0m" - assert colorstr(obj, color=32, fmt='underline') == expected_output + assert colorstr(obj, color=32, fmt="underline") == expected_output def test_colorstr_default(): diff --git a/tests/vision/ipcam/test_camera.py b/tests/vision/ipcam/test_camera.py new file mode 100644 index 0000000..b6c834e --- /dev/null +++ b/tests/vision/ipcam/test_camera.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import itertools + +import numpy as np + + +def test_ipcam_capture_is_an_iterator_and_yields_frames(monkeypatch): + import capybara.vision.ipcam.camera as cam_mod + + class _FakeCapture: + def get(self, prop): + # 4: height, 3: width (as used by the implementation) + if prop == 4: + return 10 + if prop == 3: + return 20 + return 0 + + def read(self): + return False, None + + class _FakeThread: + def __init__(self, target, daemon=True): + self._target = target + + def start(self): + self._target() + + monkeypatch.setattr( + cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture() + ) + monkeypatch.setattr(cam_mod, "Thread", _FakeThread) + + cap = cam_mod.IpcamCapture(url="fake", color_base="BGR") + assert iter(cap) is cap + + frames = list(itertools.islice(cap, 3)) + assert len(frames) == 3 + assert all(isinstance(frame, np.ndarray) for frame in frames) + + +def test_ipcam_capture_rejects_unsupported_image_size(monkeypatch): + import capybara.vision.ipcam.camera as cam_mod + + class _FakeCapture: + def get(self, prop): + if prop == 4: + return 0 + if prop == 3: + return 0 + return 0 + + def read(self): + return False, None + + class _FakeThread: + def __init__(self, target, daemon=True): + self._target = target + + def start(self): + self._target() + + monkeypatch.setattr( + cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture() + ) + monkeypatch.setattr(cam_mod, "Thread", _FakeThread) + + with np.testing.assert_raises(ValueError): + cam_mod.IpcamCapture(url="fake", color_base="BGR") + + +def test_ipcam_capture_converts_color_and_returns_frame_copy(monkeypatch): + import capybara.vision.ipcam.camera as cam_mod + + calls: list[str] = [] + + def fake_imcvtcolor(frame: np.ndarray, *, cvt_mode: str) -> np.ndarray: + calls.append(cvt_mode) + return frame + 1 + + class _FakeCapture: + def __init__(self): + self._calls = 0 + + def get(self, prop): + if prop == 4: + return 10 + if prop == 3: + return 20 + return 0 + + def read(self): + if self._calls == 0: + self._calls += 1 + return True, np.zeros((10, 20, 3), dtype=np.uint8) + return False, None + + class _FakeThread: + def __init__(self, target, daemon=True): + self._target = target + + def start(self): + self._target() + + monkeypatch.setattr( + cam_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture() + ) + monkeypatch.setattr(cam_mod, "Thread", _FakeThread) + monkeypatch.setattr(cam_mod, "imcvtcolor", fake_imcvtcolor) + + cap = cam_mod.IpcamCapture(url="fake", color_base="RGB") + frame1 = cap.get_frame() + assert calls == ["BGR2RGB"] + assert frame1.sum() > 0 + + frame1[0, 0, 0] = 123 + frame2 = cap.get_frame() + assert frame2[0, 0, 0] != 123 diff --git a/tests/vision/test_functional.py b/tests/vision/test_functional.py index 9077514..c2107c0 100644 --- a/tests/vision/test_functional.py +++ 
b/tests/vision/test_functional.py @@ -1,10 +1,24 @@ +from typing import Any + import cv2 import numpy as np import pytest -from capybara import (BORDER, Box, Boxes, gaussianblur, imbinarize, imcropbox, - imcropboxes, imcvtcolor, imresize_and_pad_if_need, - meanblur, medianblur, pad) +from capybara import ( + BORDER, + Box, + Boxes, + gaussianblur, + imbinarize, + imcropbox, + imcropboxes, + imcvtcolor, + imresize_and_pad_if_need, + meanblur, + medianblur, + pad, +) +from capybara.vision.functionals import centercrop, imadjust def test_meanblur(): @@ -31,8 +45,8 @@ def test_gaussianblur(): # 測試指定ksize和sigmaX ksize = (7, 7) - sigmaX = 1 - blurred_img_custom = gaussianblur(img, ksize=ksize, sigmaX=sigmaX) + sigma_x = 1 + blurred_img_custom = gaussianblur(img, ksize=ksize, sigma_x=sigma_x) assert blurred_img_custom.shape == img.shape @@ -55,16 +69,24 @@ def test_imcvtcolor(): img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) # 測試RGB轉灰階 - gray_img = imcvtcolor(img, 'RGB2GRAY') + gray_img = imcvtcolor(img, "RGB2GRAY") assert gray_img.shape == (100, 100) + # 支援 OpenCV 的 COLOR_ 前綴寫法 + gray_img_prefixed = imcvtcolor(img, "COLOR_RGB2GRAY") + assert gray_img_prefixed.shape == (100, 100) + + # 支援直接傳入 OpenCV 的 conversion code + gray_img2 = imcvtcolor(img, cv2.COLOR_RGB2GRAY) + assert gray_img2.shape == (100, 100) + # 測試RGB轉BGR - bgr_img = imcvtcolor(img, 'RGB2BGR') + bgr_img = imcvtcolor(img, "RGB2BGR") assert bgr_img.shape == img.shape # 測試轉換為不支援的色彩空間 with pytest.raises(ValueError): - imcvtcolor(img, 'RGB2WWW') # XYZ為不支援的色彩空間 + imcvtcolor(img, "RGB2WWW") # XYZ為不支援的色彩空間 def test_pad_constant_gray(): @@ -76,13 +98,22 @@ def test_pad_constant_gray(): pad_value = 128 padded_img = pad(img, pad_size=pad_size, pad_value=pad_value) assert padded_img.shape == ( - img.shape[0] + 2 * pad_size, img.shape[1] + 2 * pad_size) + img.shape[0] + 2 * pad_size, + img.shape[1] + 2 * pad_size, + ) assert np.all(padded_img[:pad_size, :] == pad_value) assert np.all(padded_img[-pad_size:, :] == pad_value) assert np.all(padded_img[:, :pad_size] == pad_value) assert np.all(padded_img[:, -pad_size:] == pad_value) +def test_pad_constant_gray_accepts_singleton_tuple_pad_value(): + img = np.random.randint(0, 256, size=(10, 10), dtype=np.uint8) + padded_img = pad(img, pad_size=1, pad_value=(123,)) + assert padded_img.shape == (12, 12) + assert int(padded_img[0, 0]) == 123 + + def test_pad_constant_color(): # 測試用的彩色圖片 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) @@ -92,13 +123,26 @@ def test_pad_constant_color(): pad_value = (255, 0, 0) # 紅色 padded_img = pad(img, pad_size=pad_size, pad_value=pad_value) assert padded_img.shape == ( - img.shape[0] + 2 * pad_size, img.shape[1] + 2 * pad_size, img.shape[2]) + img.shape[0] + 2 * pad_size, + img.shape[1] + 2 * pad_size, + img.shape[2], + ) assert np.all(padded_img[:pad_size, :, :] == pad_value) assert np.all(padded_img[-pad_size:, :, :] == pad_value) assert np.all(padded_img[:, :pad_size, :] == pad_value) assert np.all(padded_img[:, -pad_size:, :] == pad_value) +def test_pad_rejects_invalid_pad_value_type_and_invalid_image_ndim(): + img = np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8) + with pytest.raises(ValueError, match="pad_value must be"): + pad(img, pad_size=1, pad_value=[1, 2, 3]) # type: ignore[arg-type] + + bad = np.zeros((1, 2, 3, 4), dtype=np.uint8) + with pytest.raises(ValueError, match="img must be a 2D or 3D"): + pad(bad, pad_size=1, pad_value=0) + + def test_pad_replicate(): # 測試用的圖片 img = np.random.randint(0, 256, 
size=(100, 100, 3), dtype=np.uint8) @@ -107,7 +151,10 @@ def test_pad_replicate(): pad_size = (5, 10) padded_img = pad(img, pad_size=pad_size, pad_mode=BORDER.REPLICATE) assert padded_img.shape == ( - img.shape[0] + 2 * pad_size[0], img.shape[1] + 2 * pad_size[1], img.shape[2]) + img.shape[0] + 2 * pad_size[0], + img.shape[1] + 2 * pad_size[1], + img.shape[2], + ) def test_pad_reflect(): @@ -120,7 +167,7 @@ def test_pad_reflect(): assert padded_img.shape == ( img.shape[0] + pad_size[0] + pad_size[1], img.shape[1] + pad_size[2] + pad_size[3], - img.shape[2] + img.shape[2], ) @@ -128,12 +175,13 @@ def test_pad_invalid_input(): # 測試不支援的填充模式 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) with pytest.raises(ValueError): - pad(img, pad_size=5, pad_mode='invalid_mode') + pad(img, pad_size=5, pad_mode="invalid_mode") # 測試不合法的填充大小 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) with pytest.raises(ValueError): - pad(img, pad_size=(10, 20, 30)) + pad_size: Any = (10, 20, 30) + pad(img, pad_size=pad_size) def test_imcropbox_custom_box(): @@ -170,7 +218,8 @@ def test_imcropbox_invalid_input(): # 測試不支援的裁剪區域 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) with pytest.raises(TypeError): - imcropbox(img, "invalid_box") + invalid_box: Any = "invalid_box" + imcropbox(img, invalid_box) # 測試不合法的裁剪區域 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) @@ -240,7 +289,8 @@ def test_imcropboxes_invalid_input(): # 測試不支援的裁剪區域 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) with pytest.raises(TypeError): - imcropboxes(img, "invalid_boxes") + invalid_boxes: Any = "invalid_boxes" + imcropboxes(img, invalid_boxes) # 測試不合法的裁剪區域 img = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) @@ -286,18 +336,84 @@ def test_imbinarize_invalid_input(): def test_imresize_and_pad_if_need(): - img = np.ones((150, 120, 3), dtype='uint8') + img = np.ones((150, 120, 3), dtype="uint8") processed = imresize_and_pad_if_need(img, 150, 150) - np.testing.assert_allclose(processed[:, 120:], np.zeros((150, 30, 3), dtype='uint8')) - - img = np.ones((151, 119, 3), dtype='uint8') + np.testing.assert_allclose( + processed[:, 120:], np.zeros((150, 30, 3), dtype="uint8") + ) + + img = np.ones((151, 119, 3), dtype="uint8") processed = imresize_and_pad_if_need(img, 150, 150) - np.testing.assert_allclose(processed[:, 120:], np.zeros((150, 30, 3), dtype='uint8')) + np.testing.assert_allclose( + processed[:, 120:], np.zeros((150, 30, 3), dtype="uint8") + ) - img = np.ones((200, 100, 3), dtype='uint8') + img = np.ones((200, 100, 3), dtype="uint8") processed = imresize_and_pad_if_need(img, 100, 100) - np.testing.assert_allclose(processed[:, 50:], np.zeros((100, 50, 3), dtype='uint8')) + np.testing.assert_allclose( + processed[:, 50:], np.zeros((100, 50, 3), dtype="uint8") + ) - img = np.ones((20, 20, 3), dtype='uint8') + img = np.ones((20, 20, 3), dtype="uint8") processed = imresize_and_pad_if_need(img, 100, 100) - np.testing.assert_allclose(processed, np.ones((100, 100, 3), dtype='uint8')) + np.testing.assert_allclose(processed, np.ones((100, 100, 3), dtype="uint8")) + + +def test_blur_accepts_numpy_scalar_ksize_and_rejects_invalid_ksize(): + img = np.random.randint(0, 256, size=(20, 20, 3), dtype=np.uint8) + out = meanblur(img, ksize=np.array(3)) + assert out.shape == img.shape + out2 = gaussianblur(img, ksize=np.array(5)) + assert out2.shape == img.shape + + with pytest.raises(TypeError, match="ksize"): + meanblur(img, ksize=(1, 2, 3)) # type: 
ignore[arg-type] + + +def test_pad_accepts_none_pad_value_and_validates_pad_value_types(): + img = np.zeros((3, 3, 3), dtype=np.uint8) + padded = pad(img, pad_size=1, pad_value=None) + assert padded.shape == (5, 5, 3) + assert np.all(padded[0] == 0) + + with pytest.raises(ValueError, match="pad_value"): + invalid_pad_value: Any = (1, 2) + pad(img, pad_size=1, pad_value=invalid_pad_value) + + gray = np.zeros((3, 3), dtype=np.uint8) + with pytest.raises(ValueError, match="must be an int"): + pad(gray, pad_size=1, pad_value=(1, 2, 3)) + + +def test_centercrop_produces_square_crop(): + img = np.zeros((10, 20, 3), dtype=np.uint8) + cropped = centercrop(img) + assert cropped.shape == (10, 10, 3) + + +def test_imadjust_returns_input_when_bounds_degenerate(): + img = np.zeros((20, 20), dtype=np.uint8) + out = imadjust(img) + assert np.array_equal(out, img) + + +def test_imadjust_stretches_grayscale_and_color_images(): + gray = np.tile(np.arange(256, dtype=np.uint8), (4, 1)) + out = imadjust(gray) + assert out.shape == gray.shape + assert out.dtype == np.uint8 + assert out.min() == 0 + assert out.max() == 255 + assert not np.array_equal(out, gray) + + bgr = np.stack([gray] * 3, axis=-1) + out2 = imadjust(bgr) + assert out2.shape == bgr.shape + assert out2.dtype == np.uint8 + + +def test_imresize_and_pad_if_need_can_return_scale(): + img = np.ones((200, 100, 3), dtype=np.uint8) + out, scale = imresize_and_pad_if_need(img, 100, 100, return_scale=True) + assert out.shape == (100, 100, 3) + assert scale == pytest.approx(0.5) diff --git a/tests/vision/test_geometric.py b/tests/vision/test_geometric.py index 9131110..e8fad30 100644 --- a/tests/vision/test_geometric.py +++ b/tests/vision/test_geometric.py @@ -1,9 +1,20 @@ +from typing import Any + import numpy as np import pytest -from capybara import (BORDER, INTER, ROTATE, Polygon, Polygons, imresize, - imrotate, imrotate90, imwarp_quadrangle, - imwarp_quadrangles) +from capybara import ( + BORDER, + INTER, + ROTATE, + Polygon, + Polygons, + imresize, + imrotate, + imrotate90, + imwarp_quadrangle, + imwarp_quadrangles, +) @pytest.fixture @@ -15,14 +26,10 @@ def random_img(): @pytest.fixture def small_gray_img(): """ - 建立 3x3 的灰階小影像(數值小,方便檢查旋轉結果)。 - 為了測試方便,這裡只有一個通道。 + 建立 3x3 的灰階小影像 (數值小, 方便檢查旋轉結果)。 + 為了測試方便, 這裡只有一個通道。 """ - return np.array([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ], dtype=np.uint8) + return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8) def test_imresize_both_dims(random_img): @@ -32,7 +39,7 @@ def test_imresize_both_dims(random_img): def test_imresize_single_dim(random_img): - """只給定單一維度時,檢查是否能等比例縮放。""" + """只給定單一維度時, 檢查是否能等比例縮放。""" orig_h, orig_w = random_img.shape[:2] # 測試只給定寬度 @@ -48,6 +55,22 @@ def test_imresize_single_dim(random_img): assert resized_img.shape[1] == expected_w +def test_imresize_rejects_missing_dimensions(random_img): + with pytest.raises(ValueError, match="at least one dimension"): + imresize(random_img, (None, None), INTER.BILINEAR) + + +def test_imresize_return_scale_when_only_one_dim_provided(random_img): + _orig_h, orig_w = random_img.shape[:2] + resized_img, w_scale, h_scale = imresize( + random_img, (None, 50), INTER.BILINEAR, return_scale=True + ) + + assert resized_img.shape[1] == 50 + assert w_scale == pytest.approx(50 / orig_w) + assert h_scale == pytest.approx(50 / orig_w) + + def test_imresize_return_scale(random_img): """測試回傳縮放比例。""" orig_h, orig_w = random_img.shape[:2] @@ -78,29 +101,17 @@ def test_imresize_different_interpolation(random_img): [ ( ROTATE.ROTATE_90, - np.array([ - [7, 4, 
1], - [8, 5, 2], - [9, 6, 3] - ], dtype=np.uint8) + np.array([[7, 4, 1], [8, 5, 2], [9, 6, 3]], dtype=np.uint8), ), ( ROTATE.ROTATE_180, - np.array([ - [9, 8, 7], - [6, 5, 4], - [3, 2, 1] - ], dtype=np.uint8) + np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]], dtype=np.uint8), ), ( ROTATE.ROTATE_270, - np.array([ - [3, 6, 9], - [2, 5, 8], - [1, 4, 7] - ], dtype=np.uint8) - ) - ] + np.array([[3, 6, 9], [2, 5, 8], [1, 4, 7]], dtype=np.uint8), + ), + ], ) def test_imrotate90(small_gray_img, rotate_code, expected): """測試 90 度基底旋轉。""" @@ -110,14 +121,14 @@ def test_imrotate90(small_gray_img, rotate_code, expected): @pytest.mark.parametrize("angle, expand", [(90, False), (45, False)]) def test_imrotate_no_expand(random_img, angle, expand): - """測試不擴展的旋轉,輸出大小應與原圖相同。""" + """測試不擴展的旋轉, 輸出大小應與原圖相同。""" rotated_img = imrotate(random_img, angle=angle, expand=expand) assert rotated_img.shape == random_img.shape @pytest.mark.parametrize("angle, expand", [(90, True), (45, True)]) def test_imrotate_expand(random_img, angle, expand): - """測試擴展旋轉,輸出大小應大於或等於原圖。""" + """測試擴展旋轉, 輸出大小應大於或等於原圖。""" h, w = random_img.shape[:2] rotated_img = imrotate(random_img, angle=angle, expand=expand) assert rotated_img.shape[0] >= h @@ -125,7 +136,7 @@ def test_imrotate_expand(random_img, angle, expand): def test_imrotate_with_center(random_img): - """指定旋轉中心,檢查旋轉結果維度是否符合預期。""" + """指定旋轉中心, 檢查旋轉結果維度是否符合預期。""" h, w = random_img.shape[:2] center = (w // 4, h // 4) angle = 30 @@ -140,8 +151,13 @@ def test_imrotate_scale_border(random_img): 測試帶有 scale 與 bordervalue 的旋轉。 例如提供 (255, 0, 0) 以便檢查邊界是否填上紅色。 """ - scaled_img = imrotate(random_img, angle=45, scale=1.5, - bordertype=BORDER.REFLECT, bordervalue=(255, 0, 0)) + scaled_img = imrotate( + random_img, + angle=45, + scale=1.5, + bordertype=BORDER.REFLECT, + bordervalue=(255, 0, 0), + ) # 只要確定沒有拋出錯誤並且輸出維度放大即可 assert scaled_img.shape[0] > random_img.shape[0] assert scaled_img.shape[1] > random_img.shape[1] @@ -156,6 +172,31 @@ def test_imrotate_invalid_input(random_img): imrotate(random_img, angle=90, interpolation="invalid_interpolation") +def test_imrotate_validates_image_ndim_and_border_value_tuple_sizes( + random_img, small_gray_img +): + rotated_gray = imrotate( + small_gray_img, + angle=15, + expand=False, + bordertype=BORDER.CONSTANT, + bordervalue=(7,), + ) + assert rotated_gray.shape == small_gray_img.shape + + with pytest.raises(ValueError, match="2D or 3D"): + imrotate(np.zeros((1, 2, 3, 4), dtype=np.uint8), angle=0) + + with pytest.raises(ValueError, match="bordervalue"): + imrotate( + random_img, + angle=10, + expand=False, + bordertype=BORDER.CONSTANT, + bordervalue=(1, 2), + ) + + @pytest.fixture def default_polygon(): """產生含有 4 個點的基本 Polygon。""" @@ -164,7 +205,7 @@ def default_polygon(): def test_imwarp_quadrangle_default(random_img, default_polygon): - """不指定 dst_size 時,自動使用 min_area_rectangle 所得長寬。""" + """不指定 dst_size 時, 自動使用 min_area_rectangle 所得長寬。""" warped = imwarp_quadrangle(random_img, default_polygon) # 大致上應該會有 80x80 以上的圖 assert warped.shape[0] >= 80 @@ -172,18 +213,19 @@ def test_imwarp_quadrangle_default(random_img, default_polygon): def test_imwarp_quadrangle_with_dstsize(random_img, default_polygon): - """指定 dst_size,檢查變換後的輸出是否符合設定大小。""" + """指定 dst_size, 檢查變換後的輸出是否符合設定大小。""" warped = imwarp_quadrangle(random_img, default_polygon, dst_size=(100, 50)) assert warped.shape == (50, 100, 3) def test_imwarp_quadrangle_no_order_points(random_img, default_polygon): """ - 當 do_order_points = False 時,檢查程式能否正常執行。 - 有些情況下,使用者已確保點位順序正確,就可以省略排序。 + 當 do_order_points = False 
時, 檢查程式能否正常執行。 + 有些情況下, 使用者已確保點位順序正確, 就可以省略排序。 """ warped = imwarp_quadrangle( - random_img, default_polygon, do_order_points=False) + random_img, default_polygon, do_order_points=False + ) assert warped.shape[0] >= 80 assert warped.shape[1] >= 80 @@ -192,7 +234,8 @@ def test_imwarp_quadrangle_invalid_polygon(random_img): """檢查多種不合法 polygon 之行為。""" # 傳入不支援的 polygon 類型 with pytest.raises(TypeError): - imwarp_quadrangle(random_img, "invalid_polygon") + invalid_polygon: Any = "invalid_polygon" + imwarp_quadrangle(random_img, invalid_polygon) # 傳入長度不是 4 的 polygon with pytest.raises(ValueError): @@ -200,17 +243,23 @@ def test_imwarp_quadrangle_invalid_polygon(random_img): imwarp_quadrangle(random_img, bad_pts) +def test_imwarp_quadrangle_swaps_width_height_when_needed(random_img): + pts = np.array([[10, 10], [30, 10], [30, 90], [10, 90]], dtype=np.float32) + polygon = Polygon(pts) + warped = imwarp_quadrangle(random_img, polygon) + assert warped.shape[0] < warped.shape[1] + + @pytest.fixture def polygons_list(): - """產生 2 個四邊形,合併成 Polygons。""" + """產生 2 個四邊形, 合併成 Polygons。""" src_pts_1 = np.array( - [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32) + [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32 + ) src_pts_2 = np.array( - [[20, 20], [80, 20], [80, 80], [20, 80]], dtype=np.float32) - return Polygons([ - Polygon(src_pts_1), - Polygon(src_pts_2) - ]) + [[20, 20], [80, 20], [80, 80], [20, 80]], dtype=np.float32 + ) + return Polygons([Polygon(src_pts_1), Polygon(src_pts_2)]) def test_imwarp_quadrangles(random_img, polygons_list): @@ -227,12 +276,16 @@ def test_imwarp_quadrangles(random_img, polygons_list): def test_imwarp_quadrangles_invalid_type(random_img): """檢查不合法的 polygons 輸入。""" with pytest.raises(TypeError): - imwarp_quadrangles(random_img, "invalid_polygons") + invalid_polygons: Any = "invalid_polygons" + imwarp_quadrangles(random_img, invalid_polygons) with pytest.raises(TypeError): - invalid_polygons = [ + invalid_polygons: Any = [ Polygon( - np.array([[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32)), - "invalid_polygon" + np.array( + [[10, 10], [90, 10], [90, 90], [10, 90]], dtype=np.float32 + ) + ), + "invalid_polygon", ] imwarp_quadrangles(random_img, invalid_polygons) diff --git a/tests/vision/test_improc.py b/tests/vision/test_improc.py index 4ed3581..ca98eef 100644 --- a/tests/vision/test_improc.py +++ b/tests/vision/test_improc.py @@ -22,6 +22,11 @@ def test_imread(): assert isinstance(img_gray, np.ndarray) assert len(img_gray.shape) == 2 # 灰階圖片的channel數為1 + # color_base should be case-insensitive + img_gray2 = imread(image_path, color_base="gray") + assert isinstance(img_gray2, np.ndarray) + assert len(img_gray2.shape) == 2 + # 測試heif格式的圖片讀取 img_heif = imread(DIR.parent / "resources" / "lena.heic", color_base="BGR") assert isinstance(img_heif, np.ndarray) @@ -32,14 +37,24 @@ def test_imread(): imread("non_existent_image.jpg") -def test_imwrite(): +def test_imwrite(tmp_path): # 測試用的圖片 img = np.zeros((100, 100, 3), dtype=np.uint8) # 建立一個全黑的BGR圖片 # 測試BGR格式的圖片寫入 - temp_file_path = DIR / "temp_image.jpg" + temp_file_path = tmp_path / "temp_image.jpg" assert imwrite(img, path=temp_file_path, color_base="BGR") assert Path(temp_file_path).exists() - # 測試不指定路徑時的圖片寫入 - assert imwrite(img, color_base="BGR") # 將會寫入一個暫時的檔案 + # 測試不指定路徑時的圖片寫入 (不應污染 repo working directory) + assert imwrite(img, color_base="BGR") + + +def test_imwrite_without_path_does_not_create_tmp_file_in_cwd( + tmp_path, monkeypatch +): + """Regression: historical implementation wrote 
diff --git a/tests/vision/test_improc_extra.py b/tests/vision/test_improc_extra.py
new file mode 100644
index 0000000..d0346d8
--- /dev/null
+++ b/tests/vision/test_improc_extra.py
@@ -0,0 +1,294 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pytest
+
+import capybara.vision.improc as improc
+from capybara import IMGTYP, ROTATE
+
+
+def test_is_numpy_img_accepts_2d_and_3d_and_rejects_other_shapes():
+    assert improc.is_numpy_img(np.zeros((10, 10), dtype=np.uint8))
+    assert improc.is_numpy_img(np.zeros((10, 10, 1), dtype=np.uint8))
+    assert improc.is_numpy_img(np.zeros((10, 10, 3), dtype=np.uint8))
+    assert not improc.is_numpy_img(np.zeros((10, 10, 4), dtype=np.uint8))
+    assert not improc.is_numpy_img("not-an-array")  # type: ignore[arg-type]
+
+
+def test_get_orientation_code_maps_known_values(monkeypatch):
+    def fake_load(_: Any):
+        return {"0th": {improc.piexif.ImageIFD.Orientation: 6}}
+
+    monkeypatch.setattr(improc.piexif, "load", fake_load)
+    assert improc.get_orientation_code(b"") == ROTATE.ROTATE_90
+
+    def fake_load_3(_: Any):
+        return {"0th": {improc.piexif.ImageIFD.Orientation: 3}}
+
+    monkeypatch.setattr(improc.piexif, "load", fake_load_3)
+    assert improc.get_orientation_code(b"") == ROTATE.ROTATE_180
+
+    def fake_load_8(_: Any):
+        return {"0th": {improc.piexif.ImageIFD.Orientation: 8}}
+
+    monkeypatch.setattr(improc.piexif, "load", fake_load_8)
+    assert improc.get_orientation_code(b"") == ROTATE.ROTATE_270
+
+    def fake_load_other(_: Any):
+        return {"0th": {improc.piexif.ImageIFD.Orientation: 1}}
+
+    monkeypatch.setattr(improc.piexif, "load", fake_load_other)
+    assert improc.get_orientation_code(b"") is None
+
+    monkeypatch.setattr(
+        improc.piexif,
+        "load",
+        lambda _: (_ for _ in ()).throw(Exception("boom")),
+    )
+    assert improc.get_orientation_code(b"") is None
+
+
+def test_jpgencode_handles_tuple_return_and_failures(monkeypatch):
+    class FakeJPEG:
+        def __init__(self, *, raise_encode: bool = False) -> None:
+            self.raise_encode = raise_encode
+
+        def encode(self, img: np.ndarray, quality: int = 90):
+            if self.raise_encode:
+                raise RuntimeError("boom")
+            return (b"xx", b"ignored")
+
+    monkeypatch.setattr(improc, "jpeg", FakeJPEG())
+    img = np.zeros((8, 8, 3), dtype=np.uint8)
+    assert improc.jpgencode(img) == b"xx"
+
+    monkeypatch.setattr(improc, "jpeg", FakeJPEG(raise_encode=True))
+    assert improc.jpgencode(img) is None
+
+    assert improc.jpgencode(np.zeros((8, 8, 4), dtype=np.uint8)) is None
+
+
+def test_jpgdecode_rotates_based_on_orientation(monkeypatch):
+    img = np.zeros((5, 5, 3), dtype=np.uint8)
+
+    class FakeJPEG:
+        def decode(self, _: bytes) -> np.ndarray:
+            return img
+
+    monkeypatch.setattr(improc, "jpeg", FakeJPEG())
+    monkeypatch.setattr(
+        improc, "get_orientation_code", lambda _: ROTATE.ROTATE_90
+    )
+    monkeypatch.setattr(improc, "imrotate90", lambda arr, code: arr + 1)
+    out = improc.jpgdecode(b"bytes")
+    assert isinstance(out, np.ndarray)
+    assert out.sum() > 0
+
+    class BadJPEG:
+        def decode(self, _: bytes) -> np.ndarray:
+            raise RuntimeError("bad")
+
+    monkeypatch.setattr(improc, "jpeg", BadJPEG())
+    assert improc.jpgdecode(b"bytes") is None
+
+
+def test_pngencode_pngdecode_roundtrip_and_invalid_inputs():
+    img = np.zeros((10, 10, 3), dtype=np.uint8)
+    img[2:4, 2:4] = 255
+
+    enc = improc.pngencode(img)
+    assert isinstance(enc, bytes)
+    dec = improc.pngdecode(enc)
+    assert isinstance(dec, np.ndarray)
+    assert dec.shape == img.shape
+
+    assert improc.pngencode(np.zeros((10, 10, 4), dtype=np.uint8)) is None
+    assert improc.pngdecode(b"not-a-real-image") is None
+
+
+def test_imencode_selects_format_and_validates_kwargs(monkeypatch):
+    monkeypatch.setattr(improc, "jpgencode", lambda _: b"jpg")
+    monkeypatch.setattr(improc, "pngencode", lambda _: b"png")
+
+    dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+    assert improc.imencode(dummy, IMGTYP.JPEG) == b"jpg"
+    assert improc.imencode(dummy, IMGTYP.PNG) == b"png"
+    assert improc.imencode(dummy, IMGTYP="PNG") == b"png"
+
+    with pytest.raises(TypeError, match="both provided"):
+        improc.imencode(dummy, "png", IMGTYP="jpeg")
+
+    with pytest.raises(TypeError, match="Unexpected keyword"):
+        improc.imencode(dummy, unexpected=1)  # type: ignore[arg-type]
+
+
+def test_imdecode_falls_back_to_png_when_jpeg_decode_fails(monkeypatch):
+    monkeypatch.setattr(improc, "jpgdecode", lambda _: None)
+    monkeypatch.setattr(
+        improc, "pngdecode", lambda _: np.zeros((1, 1, 3), dtype=np.uint8)
+    )
+    out = improc.imdecode(b"blob")
+    assert isinstance(out, np.ndarray)
+
+    monkeypatch.setattr(
+        improc, "jpgdecode", lambda _: (_ for _ in ()).throw(RuntimeError("x"))
+    )
+    assert improc.imdecode(b"blob") is None
+
+
+def test_img_to_b64_and_back(monkeypatch):
+    dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+    monkeypatch.setattr(improc, "jpgencode", lambda _: b"\x00\x01")
+    b64 = improc.img_to_b64(dummy, IMGTYP.JPEG)
+    assert isinstance(b64, bytes)
+
+    b64str = improc.img_to_b64str(dummy, IMGTYP.JPEG)
+    assert isinstance(b64str, str)
+
+    monkeypatch.setattr(
+        improc, "imdecode", lambda _: np.ones((1, 1, 3), dtype=np.uint8)
+    )
+    out = improc.b64_to_img(b64)
+    assert isinstance(out, np.ndarray)
+
+
+def test_img_to_b64_accepts_imgtyp_via_imgtyp_kwarg(monkeypatch):
+    dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+    monkeypatch.setattr(improc, "pngencode", lambda _: b"\x00\x01")
+    b64 = improc.img_to_b64(dummy, IMGTYP="PNG")
+    assert isinstance(b64, bytes)
+
+
+def test_b64str_to_img_validates_and_warns(monkeypatch):
+    monkeypatch.setattr(
+        improc, "b64_to_img", lambda _: np.zeros((1, 1, 3), dtype=np.uint8)
+    )
+    assert improc.b64str_to_img("AA==") is not None
+
+    with pytest.warns(UserWarning, match="b64str is None"):
+        assert improc.b64str_to_img(None) is None
+
+    with pytest.raises(ValueError, match="not a string"):
+        improc.b64str_to_img(123)  # type: ignore[arg-type]
+
+
+def test_npy_b64_roundtrip_and_npyread(tmp_path: Path):
+    arr = np.array([1.0, 2.0], dtype=np.float32)
+    b64 = improc.npy_to_b64(arr)
+    out = improc.b64_to_npy(b64)
+    np.testing.assert_allclose(out, arr)
+
+    b64s = improc.npy_to_b64str(arr)
+    out2 = improc.b64str_to_npy(b64s)
+    np.testing.assert_allclose(out2, arr)
+
+    file = tmp_path / "x.npy"
+    np.save(file, arr)
+    loaded = improc.npyread(file)
+    assert loaded is not None
+    np.testing.assert_allclose(loaded, arr)
+
+    assert improc.npyread(tmp_path / "missing.npy") is None
+
+
+def test_pdf2imgs_handles_bytes_and_paths_and_failures(
+    monkeypatch, tmp_path: Path
+):
+    from PIL import Image
+
+    pil = Image.fromarray(np.zeros((2, 3, 3), dtype=np.uint8))
+
+    monkeypatch.setattr(improc, "convert_from_bytes", lambda _: [pil])
+    monkeypatch.setattr(improc, "convert_from_path", lambda _: [pil])
+    monkeypatch.setattr(improc, "imcvtcolor", lambda arr, cvt_mode: arr)
+
+    out = improc.pdf2imgs(b"%PDF-1.0")
+    assert out is not None
+    assert len(out) == 1
+    assert isinstance(out[0], np.ndarray)
+
+    pdf_path = tmp_path / "a.pdf"
+    pdf_path.write_bytes(b"%PDF-1.0")
+    out2 = improc.pdf2imgs(pdf_path)
+    assert out2 is not None
+    assert len(out2) == 1
+
+    monkeypatch.setattr(
+        improc,
+        "convert_from_path",
+        lambda _: (_ for _ in ()).throw(RuntimeError("boom")),
+    )
+    assert improc.pdf2imgs(pdf_path) is None
+
+
+def test_pngdecode_returns_none_when_cv2_imdecode_raises(monkeypatch):
+    monkeypatch.setattr(
+        improc.cv2,
+        "imdecode",
+        lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("x")),
+    )
+    assert improc.pngdecode(b"bytes") is None
+
+
+def test_img_to_b64_validates_kwargs_and_handles_encode_none_and_exceptions(
+    monkeypatch,
+):
+    dummy = np.zeros((2, 2, 3), dtype=np.uint8)
+
+    with pytest.raises(TypeError, match="both provided"):
+        improc.img_to_b64(dummy, "png", IMGTYP="jpeg")
+
+    with pytest.raises(TypeError, match="Unexpected keyword"):
+        improc.img_to_b64(dummy, unexpected=1)  # type: ignore[arg-type]
+
+    monkeypatch.setattr(improc, "jpgencode", lambda *_args, **_kwargs: None)
+    assert improc.img_to_b64(dummy, IMGTYP.JPEG) is None
+
+    monkeypatch.setattr(improc, "jpgencode", lambda *_args, **_kwargs: b"x")
+    monkeypatch.setattr(
+        improc.pybase64,
+        "b64encode",
+        lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")),
+    )
+    assert improc.img_to_b64(dummy, IMGTYP.JPEG) is None
+
+
+def test_b64_to_img_returns_none_when_b64decode_raises(monkeypatch):
+    monkeypatch.setattr(
+        improc.pybase64,
+        "b64decode",
+        lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("boom")),
+    )
+    assert improc.b64_to_img(b"@@@") is None
+
+
+def test_imread_warns_when_image_is_none(monkeypatch, tmp_path: Path):
+    path = tmp_path / "x.jpg"
+    path.write_bytes(b"not-a-real-jpg")
+
+    monkeypatch.setattr(improc, "jpgread", lambda *_args, **_kwargs: None)
+    monkeypatch.setattr(improc.cv2, "imread", lambda *_args, **_kwargs: None)
+
+    with pytest.warns(UserWarning, match="None type image"):
+        assert improc.imread(path, verbose=True) is None
+
+
+def test_imwrite_converts_color_base(monkeypatch, tmp_path: Path):
+    calls: list[str] = []
+
+    def fake_imcvtcolor(img: np.ndarray, *, cvt_mode: str) -> np.ndarray:
+        calls.append(cvt_mode)
+        return img
+
+    monkeypatch.setattr(improc, "imcvtcolor", fake_imcvtcolor)
+    monkeypatch.setattr(improc.cv2, "imwrite", lambda *_args, **_kwargs: True)
+
+    img = np.zeros((2, 2, 3), dtype=np.uint8)
+    out_path = tmp_path / "out.jpg"
+    assert improc.imwrite(img, path=out_path, color_base="RGB")
+    assert calls == ["RGB2BGR"]
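
Note: the monkeypatched lambdas above lean on a compact idiom for raising from an expression. A standalone illustration of why `(_ for _ in ()).throw(exc)` raises at call time:

    # (_ for _ in ()) builds an empty, not-yet-started generator; .throw(exc)
    # resumes it by raising exc inside the generator, which immediately
    # propagates to the caller. The lambda therefore raises when invoked.
    boom = lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))

    try:
        boom()
    except RuntimeError as err:
        assert str(err) == "boom"
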
diff --git a/tests/vision/test_morphology.py b/tests/vision/test_morphology.py
new file mode 100644
index 0000000..9f17452
--- /dev/null
+++ b/tests/vision/test_morphology.py
@@ -0,0 +1,56 @@
+import numpy as np
+import pytest
+
+from capybara.enums import MORPH
+from capybara.vision import morphology as morph
+
+
+@pytest.mark.parametrize(
+    "fn",
+    [
+        morph.imerode,
+        morph.imdilate,
+        morph.imopen,
+        morph.imclose,
+        morph.imgradient,
+        morph.imtophat,
+        morph.imblackhat,
+    ],
+)
+def test_morphology_ops_preserve_shape_and_dtype(fn):
+    img = np.zeros((20, 20), dtype=np.uint8)
+    img[8:12, 8:12] = 255
+
+    out = fn(img, ksize=3, kstruct=MORPH.RECT)
+    assert out.shape == img.shape
+    assert out.dtype == img.dtype
+
+    out2 = fn(img, ksize=(5, 3), kstruct="CROSS")
+    assert out2.shape == img.shape
+
+    out3 = fn(img, ksize=3, kstruct=MORPH.RECT.value)
+    assert out3.shape == img.shape
+
+
+@pytest.mark.parametrize(
+    "fn",
+    [
+        morph.imerode,
+        morph.imdilate,
+        morph.imopen,
+        morph.imclose,
+        morph.imgradient,
+        morph.imtophat,
+        morph.imblackhat,
+    ],
+)
+def test_morphology_invalid_ksize_raises(fn):
+    img = np.zeros((5, 5), dtype=np.uint8)
+    with pytest.raises(TypeError):
+        fn(img, ksize=(1, 2, 3))  # type: ignore[arg-type]
+
+
+def test_morphology_invalid_kstruct_raises():
+    img = np.zeros((5, 5), dtype=np.uint8)
+    with pytest.raises(ValueError):
+        morph.imerode(img, kstruct="cross")
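
Note: the parametrized tests above treat every morphology wrapper as accepting ksize as either an int or an (h, w) tuple, and kstruct as a MORPH member, its upper-case string name, or its raw value (the last test shows lower-case names are rejected). A usage sketch under those assumptions:

    import numpy as np

    from capybara.enums import MORPH
    from capybara.vision import morphology as morph

    mask = np.zeros((20, 20), dtype=np.uint8)
    mask[8:12, 8:12] = 255

    opened = morph.imopen(mask, ksize=3, kstruct=MORPH.RECT)     # square kernel
    closed = morph.imclose(mask, ksize=(5, 3), kstruct="CROSS")  # rectangular kernel
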
diff --git a/tests/vision/videotools/test_video2frames.py b/tests/vision/videotools/test_video2frames.py
index 6e84414..6ae3651 100644
--- a/tests/vision/videotools/test_video2frames.py
+++ b/tests/vision/videotools/test_video2frames.py
@@ -21,6 +21,18 @@
     assert len(frames) == 18
 
 
+def test_video2frames_with_fps_greater_than_video_fps():
+    # When requested FPS exceeds the video FPS, fall back to extracting all frames.
+    frames_all = video2frames(video_path)
+    frames = video2frames(video_path, frame_per_sec=60)
+    assert len(frames) == len(frames_all)
+
+
+def test_video2frames_rejects_non_positive_fps():
+    with pytest.raises(ValueError, match="frame_per_sec must be > 0"):
+        video2frames(video_path, frame_per_sec=0)
+
+
 def test_video2frames_invalid_input():
     # Test an unsupported video type
     with pytest.raises(TypeError):
@@ -29,3 +41,49 @@
     # Test a non-existent video path
     with pytest.raises(TypeError):
         video2frames("non_existent_video.mp4")
+
+
+def test_video2frames_returns_empty_list_when_capture_cannot_open(tmp_path):
+    # A file with a valid suffix but invalid contents should fail to open.
+    bad = tmp_path / "bad.mp4"
+    bad.write_bytes(b"")
+    assert video2frames(bad) == []
+
+
+def test_video2frames_falls_back_to_all_frames_when_fps_is_invalid(
+    monkeypatch, tmp_path
+):
+    import importlib
+
+    v_mod = importlib.import_module("capybara.vision.videotools.video2frames")
+
+    dummy = tmp_path / "x.mp4"
+    dummy.write_bytes(b"0")
+
+    class _FakeCapture:
+        def __init__(self) -> None:
+            self._idx = 0
+
+        def isOpened(self) -> bool:  # noqa: N802
+            return True
+
+        def get(self, prop):
+            if prop == v_mod.cv2.CAP_PROP_FPS:
+                return 0
+            return 0
+
+        def read(self):
+            if self._idx >= 3:
+                return False, None
+            self._idx += 1
+            return True, np.zeros((2, 2, 3), dtype=np.uint8)
+
+        def release(self) -> None:
+            return None
+
+    monkeypatch.setattr(
+        v_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+    )
+
+    frames = v_mod.video2frames(dummy, frame_per_sec=2)
+    assert len(frames) == 3
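
Note: the fps-related assertions above and below are consistent with a simple sampling rule, inferred from the tests rather than read out of the implementation: with source fps F and requested rate R, keep roughly one frame every F / R, and keep every frame when R >= F or when F cannot be read. A sketch of that assumed rule:

    def sample_indices(n_frames: int, fps: float, fps_req: float) -> list[int]:
        # Assumed behaviour: fall back to all frames on broken metadata or
        # when the requested rate meets or exceeds the source rate.
        if fps <= 0 or fps_req >= fps:
            return list(range(n_frames))
        step = fps / fps_req
        return [round(i * step) for i in range(int(n_frames / step))]

For example, 60 decoded frames at 30 fps sampled at 2 fps gives indices [0, 15, 30, 45], matching the four frames asserted for a two-second window below.
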
diff --git a/tests/vision/videotools/test_video2frames_v2.py b/tests/vision/videotools/test_video2frames_v2.py
index cbbfc58..bc91951 100644
--- a/tests/vision/videotools/test_video2frames_v2.py
+++ b/tests/vision/videotools/test_video2frames_v2.py
@@ -17,7 +17,9 @@
 
 def test_video2frames_v2_with_fps():
     # Test extraction at a given frame rate
-    frames = video2frames_v2(video_path, frame_per_sec=2, start_sec=0, end_sec=2, n_threads=2)
+    frames = video2frames_v2(
+        video_path, frame_per_sec=2, start_sec=0, end_sec=2, n_threads=2
+    )
     assert len(frames) == 4
 
 
@@ -29,3 +31,262 @@
     # Test a non-existent video path
     with pytest.raises(TypeError):
         video2frames_v2("non_existent_video.mp4")
+
+
+def test_video2frames_v2_helpers_and_internal_branches(monkeypatch, tmp_path):
+    import importlib
+
+    v2_mod = importlib.import_module(
+        "capybara.vision.videotools.video2frames_v2"
+    )
+
+    assert v2_mod.is_numpy_img(np.zeros((10, 10), dtype=np.uint8))
+    assert v2_mod.flatten_list([[1], [2, 3]]) == [1, 2, 3]
+    assert v2_mod.flatten_list([[[1], [2]]]) == [1, 2]
+
+    with pytest.raises(ValueError, match="larger than"):
+        v2_mod.get_step_inds(0, 1, 5)
+
+    with pytest.raises(TypeError, match="inappropriate"):
+        v2_mod._extract_frames([0, 1], "missing.mp4")
+
+    dummy_video = tmp_path / "x.mp4"
+    dummy_video.write_bytes(b"")
+
+    resized_calls: list[dict[str, object]] = []
+    cvt_calls: list[str] = []
+
+    def fake_imresize(frame: np.ndarray, dsize, **kwargs):
+        resized_calls.append({"dsize": dsize, **kwargs})
+        h, w = dsize
+        return np.zeros((h, w, frame.shape[-1]), dtype=frame.dtype)
+
+    def fake_imcvtcolor(frame: np.ndarray, *, cvt_mode: str) -> np.ndarray:
+        cvt_calls.append(cvt_mode)
+        return frame
+
+    monkeypatch.setattr(v2_mod, "imresize", fake_imresize)
+    monkeypatch.setattr(v2_mod, "imcvtcolor", fake_imcvtcolor)
+
+    class _FakeCapture:
+        def __init__(self, frames) -> None:
+            self._frames = list(frames)
+            self._idx = 0
+
+        def set(self, *_args, **_kwargs) -> None:
+            return None
+
+        def read(self):
+            if self._idx >= len(self._frames):
+                return False, None
+            frame = self._frames[self._idx]
+            self._idx += 1
+            return True, frame
+
+        def release(self) -> None:
+            return None
+
+    # scale < 1 branch + color conversion + skip None frame
+    monkeypatch.setattr(
+        v2_mod.cv2,
+        "VideoCapture",
+        lambda *_args, **_kwargs: _FakeCapture(
+            [None, np.zeros((20, 20, 3), dtype=np.uint8)]
+        ),
+    )
+    frames, _ = v2_mod._extract_frames(
+        inds=[0, 1],
+        video_path=dummy_video,
+        max_size=10,
+        color_base="RGB",
+    )
+    assert len(frames) == 1
+    assert resized_calls
+    assert cvt_calls == ["BGR2RGB"]
+
+    resized_calls.clear()
+    cvt_calls.clear()
+
+    # scale > 1 branch
+    monkeypatch.setattr(
+        v2_mod.cv2,
+        "VideoCapture",
+        lambda *_args, **_kwargs: _FakeCapture(
+            [np.zeros((5, 5, 3), dtype=np.uint8)]
+        ),
+    )
+    frames2, _ = v2_mod._extract_frames(
+        inds=[0],
+        video_path=dummy_video,
+        max_size=10,
+        color_base="BGR",
+    )
+    assert len(frames2) == 1
+    assert any(call.get("interpolation") is not None for call in resized_calls)
+
+
+def test_video2frames_v2_returns_empty_for_zero_frames_or_fps(
+    monkeypatch, tmp_path
+):
+    import importlib
+
+    v2_mod = importlib.import_module(
+        "capybara.vision.videotools.video2frames_v2"
+    )
+
+    dummy_video = tmp_path / "x.mp4"
+    dummy_video.write_bytes(b"")
+
+    class _FakeCapture:
+        def get(self, *_args, **_kwargs):
+            return 0
+
+        def release(self) -> None:
+            return None
+
+    monkeypatch.setattr(
+        v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+    )
+    assert video2frames_v2(dummy_video) == []
+
+
+def test_video2frames_v2_validates_start_end_and_handles_worker_exceptions(
+    monkeypatch, tmp_path
+):
+    import importlib
+
+    v2_mod = importlib.import_module(
+        "capybara.vision.videotools.video2frames_v2"
+    )
+
+    dummy_video = tmp_path / "x.mp4"
+    dummy_video.write_bytes(b"")
+
+    class _FakeCapture:
+        def __init__(self) -> None:
+            self.calls = 0
+
+        def get(self, prop):
+            if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+                return 10
+            if prop == v2_mod.cv2.CAP_PROP_FPS:
+                return 10
+            return 0
+
+        def release(self) -> None:
+            return None
+
+    monkeypatch.setattr(
+        v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+    )
+
+    with pytest.raises(ValueError, match="start_sec"):
+        video2frames_v2(dummy_video, start_sec=2.0, end_sec=1.0)
+
+    monkeypatch.setattr(
+        v2_mod,
+        "_extract_frames",
+        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
+    )
+    assert (
+        video2frames_v2(dummy_video, start_sec=0.0, end_sec=1.0, n_threads=1)
+        == []
+    )
+
+
+def test_video2frames_v2_skips_empty_worker_chunks(monkeypatch, tmp_path):
+    import importlib
+
+    v2_mod = importlib.import_module(
+        "capybara.vision.videotools.video2frames_v2"
+    )
+
+    dummy_video = tmp_path / "x.mp4"
+    dummy_video.write_bytes(b"")
+
+    class _FakeCapture:
+        def get(self, prop):
+            if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+                return 10
+            if prop == v2_mod.cv2.CAP_PROP_FPS:
+                return 10
+            return 0
+
+        def release(self) -> None:
+            return None
+
+    monkeypatch.setattr(
+        v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+    )
+
+    calls: list[list[int]] = []
+
+    def fake_extract_frames(
+        inds,
+        video_path,
+        max_size=1920,
+        color_base="BGR",
+        global_ind=0,
+    ):
+        assert inds, "Should never schedule empty index chunks."
+        calls.append(list(inds))
+        return [], global_ind
+
+    monkeypatch.setattr(v2_mod, "_extract_frames", fake_extract_frames)
+
+    assert (
+        v2_mod.video2frames_v2(
+            dummy_video,
+            frame_per_sec=10,
+            start_sec=0.0,
+            end_sec=0.1,
+            n_threads=8,
+        )
+        == []
+    )
+    assert calls == [[0]]
+
+
+def test_video2frames_v2_validates_threads_fps_and_zero_duration(
+    monkeypatch, tmp_path
+):
+    import importlib
+
+    v2_mod = importlib.import_module(
+        "capybara.vision.videotools.video2frames_v2"
+    )
+
+    dummy_video = tmp_path / "x.mp4"
+    dummy_video.write_bytes(b"")
+
+    class _FakeCapture:
+        def get(self, prop):
+            if prop == v2_mod.cv2.CAP_PROP_FRAME_COUNT:
+                return 10
+            if prop == v2_mod.cv2.CAP_PROP_FPS:
+                return 10
+            return 0
+
+        def release(self) -> None:
+            return None
+
+    monkeypatch.setattr(
+        v2_mod.cv2, "VideoCapture", lambda *_args, **_kwargs: _FakeCapture()
+    )
+
+    with pytest.raises(ValueError, match="n_threads must be >= 1"):
+        v2_mod.video2frames_v2(dummy_video, n_threads=0)
+
+    with pytest.raises(ValueError, match="frame_per_sec must be > 0"):
+        v2_mod.video2frames_v2(dummy_video, frame_per_sec=0, n_threads=1)
+
+    assert (
+        v2_mod.video2frames_v2(
+            dummy_video,
+            frame_per_sec=10,
+            start_sec=1.0,
+            end_sec=1.0,
+            n_threads=1,
+        )
+        == []
+    )
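
Note: the magic indices in the first test below come straight from the denormalization arithmetic. For an image of shape (60, 80, 3), that is H = 60 and W = 80, and the normalized XYXY box (0.1, 0.2, 0.9, 0.8):

    # x1 = 0.1 * W = 0.1 * 80 = 8
    # y1 = 0.2 * H = 0.2 * 60 = 12
    # so the probe lands at out[12, 8]  (NumPy indexes rows first: [y, x]).
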
diff --git a/tests/vision/visualization/test_draw.py b/tests/vision/visualization/test_draw.py
new file mode 100644
index 0000000..4a0b71d
--- /dev/null
+++ b/tests/vision/visualization/test_draw.py
@@ -0,0 +1,247 @@
+import numpy as np
+import pytest
+
+from capybara import Box, Boxes, Keypoints, KeypointsList, Polygon, Polygons
+from capybara.vision.visualization import draw as draw_mod
+
+
+def test_draw_box_denormalizes_normalized_box():
+    img = np.zeros((60, 80, 3), dtype=np.uint8)
+    box = Box([0.1, 0.2, 0.9, 0.8], box_mode="XYXY", is_normalized=True)
+    out = draw_mod.draw_box(img.copy(), box, color=(0, 255, 0), thickness=1)
+    assert out.shape == img.shape
+    assert out.dtype == img.dtype
+    # (x1, y1) should be denormalized to a non-zero coordinate.
+    assert out[12, 8].sum() > 0
+
+
+def test_draw_boxes_supports_per_box_colors_and_thicknesses():
+    img = np.zeros((40, 40, 3), dtype=np.uint8)
+    boxes = Boxes([[0, 0, 10, 10], [20, 20, 35, 35]], box_mode="XYXY")
+    out = draw_mod.draw_boxes(
+        img.copy(),
+        boxes,
+        colors=[(255, 0, 0), (0, 255, 0)],
+        thicknesses=[1, 2],
+    )
+    assert out.shape == img.shape
+    assert out.sum() > 0
+
+
+def test_draw_polygon_and_polygons_support_fillup_and_normalized_coords():
+    img = np.zeros((50, 50, 3), dtype=np.uint8)
+    poly = Polygon(
+        [(0.1, 0.1), (0.9, 0.1), (0.9, 0.9), (0.1, 0.9)],
+        is_normalized=True,
+    )
+    out_edges = draw_mod.draw_polygon(img.copy(), poly, fillup=False)
+    out_fill = draw_mod.draw_polygon(img.copy(), poly, fillup=True)
+    assert out_edges.shape == img.shape
+    assert out_fill.shape == img.shape
+    assert out_edges.sum() > 0
+    assert out_fill.sum() > 0
+
+    polys = Polygons([poly, poly.shift(0.0, -0.1)], is_normalized=True)
+    out_multi = draw_mod.draw_polygons(
+        img.copy(),
+        polys,
+        colors=[(0, 0, 255), (255, 255, 0)],
+        thicknesses=[1, 1],
+        fillup=False,
+    )
+    assert out_multi.shape == img.shape
+    assert out_multi.sum() > 0
+
+
+def test_draw_text_draws_pixels():
+    img = np.full((60, 200, 3), 255, dtype=np.uint8)
+    out = draw_mod.draw_text(
+        img.copy(),
+        "hello",
+        location=(5, 5),
+        color=(0, 0, 255),
+        text_size=18,
+    )
+    assert out.shape == img.shape
+    assert out.dtype == img.dtype
+    assert not np.array_equal(out, img)
+
+
+def test_draw_text_handles_fonts_without_getbbox(monkeypatch):
+    from PIL import ImageFont
+
+    class _NoBBoxFont:
+        def __init__(self):
+            self._font = ImageFont.load_default()
+
+        def getbbox(self, _text):
+            raise RuntimeError("boom")
+
+        def __getattr__(self, name: str):
+            return getattr(self._font, name)
+
+    monkeypatch.setattr(
+        draw_mod, "_load_font", lambda *_args, **_kwargs: _NoBBoxFont()
+    )
+
+    img = np.full((40, 160, 3), 255, dtype=np.uint8)
+    out = draw_mod.draw_text(img.copy(), "hello", location=(5, 5))
+    assert out.shape == img.shape
+
+
+@pytest.mark.parametrize("style", ["dotted", "line"])
+def test_draw_line_supports_styles(style: str):
+    img = np.zeros((40, 40, 3), dtype=np.uint8)
+    out = draw_mod.draw_line(
+        img,
+        pt1=(0, 0),
+        pt2=(39, 39),
+        color=(0, 255, 0),
+        thickness=2,
+        style=style,
+        gap=8,
+        inplace=False,
+    )
+    assert out.shape == img.shape
+    assert not np.array_equal(out, img)
+
+
+def test_draw_line_inplace_and_invalid_style():
+    img = np.zeros((20, 20, 3), dtype=np.uint8)
+    out = draw_mod.draw_line(
+        img,
+        pt1=(0, 0),
+        pt2=(19, 0),
+        color=(255, 0, 0),
+        thickness=1,
+        style="line",
+        gap=5,
+        inplace=True,
+    )
+    assert out is img
+    assert out.sum() > 0
+
+    with pytest.raises(ValueError, match="Unknown style"):
+        draw_mod.draw_line(img, (0, 0), (10, 10), style="invalid")
+
+
+def test_draw_point_and_draw_points_preserve_grayscale_shape():
+    gray = np.zeros((30, 30), dtype=np.uint8)
+    out = draw_mod.draw_point(
+        gray.copy(),
+        (15, 15),
+        scale=1.0,
+        color=(255, 0, 0),
+        thickness=-1,
+    )
+    assert out.shape == gray.shape
+    assert out.ndim == 2
+    assert out.sum() > 0
+
+    out2 = draw_mod.draw_points(
+        gray.copy(),
+        points=[(5, 5), (25, 25)],
+        scales=[1.0, 2.0],
+        colors=[(0, 255, 0), (0, 0, 255)],
+        thicknesses=[-1, -1],
+    )
+    assert out2.shape == (*gray.shape, 3)
+    assert out2.ndim == 3
+    assert out2.sum() > 0
+
+
+def test_draw_keypoints_and_list_support_normalized_inputs():
+    img = np.zeros((60, 60, 3), dtype=np.uint8)
+    kpts = Keypoints([(0.25, 0.25), (0.75, 0.75)], is_normalized=True)
+    out = draw_mod.draw_keypoints(img.copy(), kpts, scale=1.0, thickness=-1)
+    assert out.shape == img.shape
+    assert out.sum() > 0
+
+    kpts_list = KeypointsList([kpts, kpts.shift(0.0, -0.1)], is_normalized=True)
+    out2 = draw_mod.draw_keypoints_list(
+        img.copy(), kpts_list, scales=[1.0, 1.5], thicknesses=[-1, -1]
+    )
+    assert out2.shape == img.shape
+    assert out2.sum() > 0
+
+
+def test_generate_colors_and_distinct_color():
+    np.random.seed(0)
+    tri = draw_mod.generate_colors(3, scheme="triadic")
+    assert len(tri) == 3
+    assert all(isinstance(c, tuple) and len(c) == 3 for c in tri)
+
+    hsv = draw_mod.generate_colors(3, scheme="hsv")
+    assert len(hsv) == 3
+
+    analogous = draw_mod.generate_colors(3, scheme="analogous")
+    assert len(analogous) == 3
+
+    square = draw_mod.generate_colors(3, scheme="square")
+    assert len(square) == 3
+
+    unknown = draw_mod.generate_colors(2, scheme="not-a-scheme")
+    assert unknown == []
+
+    assert 0 <= draw_mod._vdc(1, base=2) < 1
+    c0 = draw_mod.distinct_color(0)
+    c1 = draw_mod.distinct_color(1)
+    assert isinstance(c0, tuple) and len(c0) == 3
+    assert c0 != c1
+
+    assert draw_mod._label_to_index("123") == 123
+    assert draw_mod._label_to_index("cat") == draw_mod._label_to_index("cat")
+
+
+def test_draw_mask_normalization_and_shape_checks():
+    img = np.zeros((20, 30, 3), dtype=np.uint8)
+    mask = np.arange(20 * 30, dtype=np.uint8).reshape(20, 30)
+    out = draw_mod.draw_mask(img, mask, min_max_normalize=True)
+    assert out.shape == img.shape
+
+    mask_bgr = np.stack([mask] * 3, axis=-1)
+    out2 = draw_mod.draw_mask(img, mask_bgr, min_max_normalize=False)
+    assert out2.shape == img.shape
+
+    bad_mask = np.zeros((20, 30, 2), dtype=np.uint8)
+    with pytest.raises(ValueError, match="Mask must be either 2D"):
+        draw_mod.draw_mask(img, bad_mask)
+
+
+def test_draw_detection_and_draw_detections_end_to_end():
+    img = np.zeros((80, 120, 3), dtype=np.uint8)
+    box = Box([0.05, 0.0, 0.5, 0.2], box_mode="XYXY", is_normalized=True)
+
+    out = draw_mod.draw_detection(
+        img.copy(),
+        box,
+        label="cat",
+        score=0.9,
+        color=None,
+        thickness=None,
+        box_alpha=0.5,
+        text_bg_alpha=0.5,
+    )
+    assert out.shape == img.shape
+    assert out.sum() > 0
+
+    boxes = Boxes([[0, 0, 20, 20], [30, 30, 60, 60]], box_mode="XYXY")
+    with pytest.raises(ValueError, match="Number of boxes must match"):
+        draw_mod.draw_detections(img, boxes, labels=["only-one"])
+    with pytest.raises(ValueError, match="Number of scores must match"):
+        draw_mod.draw_detections(img, boxes, labels=["a", "b"], scores=[0.1])
+
+    out2 = draw_mod.draw_detections(
+        img.copy(),
+        boxes,
+        labels=["a", "b"],
+        scores=[0.1, 0.2],
+        colors=[(0, 255, 0), (255, 0, 0)],
+        thicknesses=[1, 2],
+        text_colors=[(255, 255, 255), (0, 0, 0)],
+        text_sizes=[12, 13],
+        box_alpha=1.0,
+        text_bg_alpha=0.5,
+    )
+    assert out2.shape == img.shape
+    assert out2.sum() > 0
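
Note: the _vdc helper poked above looks like a van der Corput radical inverse, the usual low-discrepancy trick behind pick-a-distinct-hue helpers. Assuming that is what it is, the first few base-2 values would be:

    # _vdc(n, base=2) mirrors n's base-2 digits around the radix point:
    #   n=1 -> 0.1  (binary) = 0.5
    #   n=2 -> 0.01 (binary) = 0.25
    #   n=3 -> 0.11 (binary) = 0.75
    # which is consistent with the loose assertions 0 <= _vdc(1, base=2) < 1
    # and distinct_color(0) != distinct_color(1).
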
diff --git a/tests/vision/visualization/test_vis_utils.py b/tests/vision/visualization/test_vis_utils.py
index 81b9af4..8524737 100644
--- a/tests/vision/visualization/test_vis_utils.py
+++ b/tests/vision/visualization/test_vis_utils.py
@@ -1,23 +1,47 @@
+from typing import Any
+
 import numpy as np
 import pytest
 
-from capybara.structures import (Box, Boxes, Keypoints, KeypointsList, Polygon,
-                                 Polygons)
-from capybara.vision.visualization.utils import *
+from capybara.structures import (
+    Box,
+    Boxes,
+    Keypoints,
+    KeypointsList,
+    Polygon,
+    Polygons,
+)
+from capybara.vision.visualization.utils import (
+    is_numpy_img,
+    prepare_box,
+    prepare_boxes,
+    prepare_color,
+    prepare_colors,
+    prepare_img,
+    prepare_keypoints,
+    prepare_keypoints_list,
+    prepare_point,
+    prepare_polygon,
+    prepare_polygons,
+    prepare_scale,
+    prepare_scales,
+    prepare_thickness,
+    prepare_thicknesses,
+)
 
 
 def test_is_numpy_img():
     img = np.random.random((100, 100, 3))
-    assert is_numpy_img(img) == True
+    assert is_numpy_img(img) is True
 
     img = np.random.random((100, 100))
-    assert is_numpy_img(img) == True
+    assert is_numpy_img(img) is True
 
-    img = np.random.random((100, ))
-    assert is_numpy_img(img) == False
+    img = np.random.random((100,))
+    assert is_numpy_img(img) is False
 
     img = "not an image"
-    assert is_numpy_img(img) == False
+    assert is_numpy_img(img) is False
 
 
 def test_prepare_color():
@@ -33,12 +57,14 @@
     color = 0
     assert prepare_color(color) == (0, 0, 0)
 
-    color = 'black'
+    color: Any = "black"
     with pytest.raises(TypeError):
         prepare_color(color)
 
-    color = (0.1, 0.1, 0.1)
-    with pytest.raises(TypeError, match=r'[0-9a-zA-Z=,. ]+colors\[2\][0-9a-zA-Z=,. ()[\]]+'):
+    color: Any = (0.1, 0.1, 0.1)
+    with pytest.raises(
+        TypeError, match=r"[0-9a-zA-Z=,. ]+colors\[2\][0-9a-zA-Z=,. ()[\]]+"
+    ):
         prepare_color(color, 2)
 
 
@@ -59,29 +85,38 @@ def test_prepare_colors():
     length = 3
     assert prepare_colors(colors, length) == [(0, 0, 0), (0, 0, 0), (0, 0, 0)]
 
-    colors = 'black'
+    colors: Any = "black"
     length = 3
     with pytest.raises(TypeError):
         prepare_colors(colors, length)
 
     colors = [(0, 0, 0), (1, 1, 1), (2, 2, 2)]
     length = 2
-    with pytest.raises(ValueError, match=r'The length of colors = 3 is not equal to the length = 2.'):
+    with pytest.raises(
+        ValueError,
+        match=r"The length of colors = 3 is not equal to the length = 2.",
+    ):
        prepare_colors(colors, length)
 
     colors = [(0, 0, 0), (1.1, 1.1, 1.1), (2, 2, 2)]
     length = 3
-    with pytest.raises(TypeError, match=r'[0-9a-zA-Z = , . ]+colors\[1\][0-9a-zA-Z = , . ()[\]]+'):
+    with pytest.raises(
+        TypeError,
+        match=r"[0-9a-zA-Z = , . ]+colors\[1\][0-9a-zA-Z = , . ()[\]]+",
+    ):
         prepare_colors(colors, length)
 
-    colors = 0.1
+    colors: Any = 0.1
     length = 3
-    with pytest.raises(TypeError, match=r'[0-9a-zA-Z = , . ]+colors\[0\][0-9a-zA-Z = , . ()[\]]+'):
+    with pytest.raises(
+        TypeError,
+        match=r"[0-9a-zA-Z = , . ]+colors\[0\][0-9a-zA-Z = , . ()[\]]+",
+    ):
         prepare_colors(colors, length)
 
 
 def test_prepare_img():
-    tgt_img1 = np.random.randint(0, 255, (100, 100, 3), dtype='uint8')
+    tgt_img1 = np.random.randint(0, 255, (100, 100, 3), dtype="uint8")
     img = tgt_img1.copy()
     np.testing.assert_allclose(prepare_img(img), tgt_img1)
 
@@ -93,7 +128,12 @@ def test_prepare_img():
     with pytest.raises(ValueError):
         prepare_img(img)
 
-    img = "not an image"
+    img = np.random.randint(0, 255, (10, 20, 1), dtype="uint8")
+    out = prepare_img(img)
+    assert out.shape == (10, 20, 3)
+    np.testing.assert_allclose(out[..., 0], img[..., 0])
+
+    img: Any = "not an image"
     with pytest.raises(ValueError):
         prepare_img(img)
 
@@ -111,16 +151,23 @@ def test_prepare_box():
     box = np.array([0, 0, 100, 100])
     assert prepare_box(box) == tgt_box
 
-    box = 0
-    with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+0[0-9a-zA-Z=,. ()[\]]+"):
+    box: Any = 0
+    with pytest.raises(
+        ValueError, match=r"[0-9a-zA-Z=,. ]+0[0-9a-zA-Z=,. ()[\]]+"
+    ):
         prepare_box(box)
 
     box = (0, 0, 100)
-    with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100\)[0-9a-zA-Z=,. ()[\]']+"):
+    with pytest.raises(
+        ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100\)[0-9a-zA-Z=,. ()[\]']+"
+    ):
         prepare_box(box)
 
     box = (0, 0, 100, 100, 100)
-    with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100, 100, 100\)[0-9a-zA-Z=,. ()[\]']+"):
+    with pytest.raises(
+        ValueError,
+        match=r"[0-9a-zA-Z=,. ]+\(0, 0, 100, 100, 100\)[0-9a-zA-Z=,. ()[\]']+",
+    ):
         prepare_box(box)
@@ -134,8 +181,13 @@ def test_prepare_boxes():
     np_boxes = np.array(boxes_list)
     assert prepare_boxes(np_boxes) == boxes
 
-    boxes = [(0, 1), ]
-    with pytest.raises(ValueError, match=r"[0-9a-zA-Z=,. ]+boxes\[0\][0-9a-zA-Z=,. ]+\(0, 1\)[0-9a-zA-Z=,. ()[\]']+"):
+    boxes = [
+        (0, 1),
+    ]
+    with pytest.raises(
+        ValueError,
+        match=r"[0-9a-zA-Z=,. ]+boxes\[0\][0-9a-zA-Z=,. ]+\(0, 1\)[0-9a-zA-Z=,. ()[\]']+",
+    ):
         prepare_boxes(boxes)
 
 
@@ -150,8 +202,11 @@ def test_prepare_keypoints():
 
     assert prepare_keypoints(tgt_keypoints) == tgt_keypoints
 
-    keypoints = [[0, 1], [2, 1], [3, 1]]
-    with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[2, 1\], \[3, 1\]\][0-9a-zA-Z=,. ()[\]]+"):
+    keypoints: Any = [[0, 1], [2, 1], [3, 1]]
+    with pytest.raises(
+        TypeError,
+        match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[2, 1\], \[3, 1\]\][0-9a-zA-Z=,. ()[\]]+",
+    ):
         prepare_keypoints(keypoints)
 
 
@@ -177,7 +232,10 @@ def test_prepare_keypoints_list():
         [[0, 1], [1, 2], [3, 4]],
         [(0, 1), (1, 2), (3, 3)],
     ]
-    with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+keypoints_list\[0\][0-9a-zA-Z=,. ()[\]]+"):
+    with pytest.raises(
+        TypeError,
+        match=r"[0-9a-zA-Z=,. ]+keypoints_list\[0\][0-9a-zA-Z=,. ()[\]]+",
+    ):
         prepare_keypoints_list(keypoints_list)
 
 
@@ -191,7 +249,10 @@ def test_prepare_polygon():
     assert prepare_polygon(np_polygon) == tgt_polygon
 
     polygon = [[0, 1], [1, 2, 3]]
-    with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[1, 2, 3\]\][0-9a-zA-Z=,. ()[\]]+"):
+    with pytest.raises(
+        TypeError,
+        match=r"[0-9a-zA-Z=,. ]+\[\[0, 1\], \[1, 2, 3\]\][0-9a-zA-Z=,. ()[\]]+",
+    ):
         prepare_polygon(polygon)
 
 
@@ -199,7 +260,14 @@ def test_prepare_polygons():
     tgt_polygons = Polygons(
         [
             [[0, 1], [1, 2], [3, 4]],
-            [[1, 2,], [3, 3], [5, 5]],
+            [
+                [
+                    1,
+                    2,
+                ],
+                [3, 3],
+                [5, 5],
+            ],
         ]
     )
     polygons = [
@@ -215,7 +283,9 @@ def test_prepare_polygons():
         [[0, 1], [1, 2], [3, 4]],
         [[1, 2], [3, 3], [5, 5, 3]],
     ]
-    with pytest.raises(TypeError, match=r"[0-9a-zA-Z=,. ]+polygons\[1\][0-9a-zA-Z=,. ()[\]]+"):
+    with pytest.raises(
+        TypeError, match=r"[0-9a-zA-Z=,. ]+polygons\[1\][0-9a-zA-Z=,. ()[\]]+"
+    ):
         prepare_polygons(polygons)
 
 
@@ -263,3 +333,15 @@ def test_prepare_scales():
 
     scales = 1.2
     assert prepare_scales(scales, 2) == [1.2, 1.2]
+
+
+def test_prepare_point_and_numeric_validations_cover_error_paths():
+    point: Any = (1,)
+    with pytest.raises(TypeError, match=r"points\[0\]"):
+        prepare_point(point, ind=0)
+
+    with pytest.raises(ValueError, match=r"thickness\[s\[1\]\]"):
+        prepare_thickness(-2, ind=1)
+
+    with pytest.raises(ValueError, match=r"scale\[s\[1\]\]"):
+        prepare_scale(-2, ind=1)