diff --git a/halligan/halligan/stages/stage1.py b/halligan/halligan/stages/stage1.py index 5b3415c..cc91c40 100644 --- a/halligan/halligan/stages/stage1.py +++ b/halligan/halligan/stages/stage1.py @@ -1,5 +1,8 @@ import re import inspect +import time +import logging +from typing import Callable import halligan.prompts as Prompts from halligan.agents import Agent @@ -7,6 +10,56 @@ from halligan.utils.constants import Stage from halligan.utils.logger import Trace +logger = logging.getLogger(__name__) + + +# Stage-specific exception types +class StageError(Exception): + pass + +class NetworkError(StageError): + pass + +class ToolInvokeError(StageError): + pass + +class ScriptSyntaxError(StageError): + pass + +class FatalStageError(StageError): + pass + + +def _is_network_error(exc: Exception) -> bool: + # heuristic: ConnectionError, TimeoutError or message + return isinstance(exc, (ConnectionError, TimeoutError)) or "network" in str(exc).lower() + + +def _run_with_retry(func: Callable, retries: int = 3, backoff: float = 0.5, retry_on=None): + retry_on = retry_on or (Exception,) + last_exc = None + for attempt in range(1, retries + 1): + try: + return func() + except Exception as e: + last_exc = e + # classify + if _is_network_error(e): + logger.warning(f"Network error on attempt {attempt}/{retries}: {e}") + else: + logger.warning(f"Error on attempt {attempt}/{retries}: {e}") + + if attempt == retries: + break + time.sleep(backoff * attempt) + + # determine proper exception type + if last_exc and _is_network_error(last_exc): + raise NetworkError(str(last_exc)) from last_exc + if isinstance(last_exc, SyntaxError): + raise ScriptSyntaxError(str(last_exc)) from last_exc + raise ToolInvokeError(str(last_exc)) from last_exc + stage = Stage.OBJECTIVE_IDENTIFICATION @@ -55,14 +108,44 @@ def objective(description: str): ) print(prompt) - # Request script from agent + # Request script from agent (with retries on network/tool errors) images = [frame.image for frame in frames] image_captions = [f"Frame {i}" for i in range(len(frames))] - response, _ = agent(prompt, images, image_captions) - script = get_script(response) + + def request_and_extract(): + resp = agent(prompt, images, image_captions) + # agent may return (response, metadata) + if isinstance(resp, tuple) or isinstance(resp, list): + resp_text = resp[0] + else: + resp_text = resp + script_text = get_script(resp_text) + if not script_text: + # ask agent to reformat or raise to trigger retry + raise ToolInvokeError("Agent returned no python script block") + return script_text + + script = _run_with_retry(request_and_extract, retries=3, backoff=0.7) print(script) - # Execute response script - exec(script, tools, {}) + # Execute response script (parse first to give clear syntax errors) + def exec_script(): + try: + # validate syntax first + import ast + ast.parse(script) + except SyntaxError as e: + # don't retry syntax errors + raise ScriptSyntaxError(str(e)) from e + + try: + exec_globals = dict(tools) + exec(script, exec_globals, {}) + except Exception as e: + # tool invocation error inside the executed script + raise ToolInvokeError(str(e)) from e + return True + + _run_with_retry(exec_script, retries=2, backoff=0.5) agent.reset() return task_objective \ No newline at end of file diff --git a/halligan/halligan/stages/stage2.py b/halligan/halligan/stages/stage2.py index 29b3aec..a7c8246 100644 --- a/halligan/halligan/stages/stage2.py +++ b/halligan/halligan/stages/stage2.py @@ -1,6 +1,8 @@ import re import ast -from typing import List +import time +import logging +from typing import List, Callable from textwrap import indent import halligan.prompts as Prompts @@ -10,6 +12,47 @@ from halligan.utils.layout import Frame, Element, get_observation from halligan.utils.logger import Trace +logger = logging.getLogger(__name__) + + +# Stage-specific exceptions (kept local for stage files) +class StageError(Exception): + pass + +class NetworkError(StageError): + pass + +class ToolInvokeError(StageError): + pass + +class ScriptSyntaxError(StageError): + pass + +def _is_network_error(exc: Exception) -> bool: + return isinstance(exc, (ConnectionError, TimeoutError)) or "network" in str(exc).lower() + + +def _run_with_retry(func: Callable, retries: int = 3, backoff: float = 0.5): + last_exc = None + for attempt in range(1, retries + 1): + try: + return func() + except Exception as e: + last_exc = e + if _is_network_error(e): + logger.warning(f"Network error attempt {attempt}/{retries}: {e}") + else: + logger.warning(f"Error attempt {attempt}/{retries}: {e}") + if attempt == retries: + break + time.sleep(backoff * attempt) + + if last_exc and _is_network_error(last_exc): + raise NetworkError(str(last_exc)) from last_exc + if isinstance(last_exc, SyntaxError): + raise ScriptSyntaxError(str(last_exc)) from last_exc + raise ToolInvokeError(str(last_exc)) from last_exc + stage = Stage.STRUCTURE_ABSTRACTION @@ -66,13 +109,39 @@ def get_script(response: str) -> list[str]: ) print(prompt) - # Request script from agent - response, _ = agent(prompt, images, image_captions) - script = get_script(response) + # Request script from agent with retries + def request_script(): + resp = agent(prompt, images, image_captions) + if isinstance(resp, (tuple, list)): + resp_text = resp[0] + else: + resp_text = resp + code = get_script(resp_text) + if not code: + raise ToolInvokeError("Agent returned no python code block for structure_abstraction") + return code + + script = _run_with_retry(request_script, retries=3, backoff=0.6) print(script) - # Execute response script - env = {} - exec(script, toolkit.dependencies, env) - env["process"](frames) + # Execute response script safely + def exec_script(): + try: + ast.parse(script) + except SyntaxError as e: + raise ScriptSyntaxError(str(e)) from e + + env = {} + try: + exec(script, toolkit.dependencies, env) + except Exception as e: + raise ToolInvokeError(str(e)) from e + + if "process" not in env or not callable(env["process"]): + raise ToolInvokeError("Agent script did not define a callable 'process(frames)'") + + env["process"](frames) + return True + + _run_with_retry(exec_script, retries=2, backoff=0.4) agent.reset() \ No newline at end of file diff --git a/halligan/halligan/stages/stage3.py b/halligan/halligan/stages/stage3.py index 441badc..04749bc 100644 --- a/halligan/halligan/stages/stage3.py +++ b/halligan/halligan/stages/stage3.py @@ -1,6 +1,9 @@ import re import ast +import time +import logging from textwrap import indent +from typing import Callable import halligan.prompts as Prompts import halligan.utils.examples as Examples @@ -12,6 +15,48 @@ from halligan.utils.vision_tools import vision_toolkits from halligan.utils.layout import Frame, get_observation +logger = logging.getLogger(__name__) + + +# local stage exceptions +class StageError(Exception): + pass + +class NetworkError(StageError): + pass + +class ToolInvokeError(StageError): + pass + +class ScriptSyntaxError(StageError): + pass + + +def _is_network_error(exc: Exception) -> bool: + return isinstance(exc, (ConnectionError, TimeoutError)) or "network" in str(exc).lower() + + +def _run_with_retry(func: Callable, retries: int = 3, backoff: float = 0.5): + last_exc = None + for attempt in range(1, retries + 1): + try: + return func() + except Exception as e: + last_exc = e + if _is_network_error(e): + logger.warning(f"Network error attempt {attempt}/{retries}: {e}") + else: + logger.warning(f"Error attempt {attempt}/{retries}: {e}") + if attempt == retries: + break + time.sleep(backoff * attempt) + + if last_exc and _is_network_error(last_exc): + raise NetworkError(str(last_exc)) from last_exc + if isinstance(last_exc, SyntaxError): + raise ScriptSyntaxError(str(last_exc)) from last_exc + raise ToolInvokeError(str(last_exc)) from last_exc + stage = Stage.SOLUTION_COMPOSITION @@ -78,26 +123,67 @@ def execute_script(script: str, dependencies: dict): ) print(prompt) - # Request script from agent + # Request script from agent (retries on network errors) + def request_script(): + resp = agent(prompt, images, image_captions) + if isinstance(resp, (tuple, list)): + resp_text = resp[0] + else: + resp_text = resp + script = get_script(resp_text) + if not script: + raise ToolInvokeError("Agent returned no python script block") + return script + try: - response, _ = agent(prompt, images, image_captions) - script = get_script(response) + script = _run_with_retry(request_script, retries=3, backoff=0.6) print(script) - execute_script(script, dependencies) - - except Exception as e: - feedback = e - - for _ in range(3): - try: - print(feedback) - response, _ = agent(f"Your code has errors, please fix it.\n{feedback}") - script = get_script(response) - print(script) - execute_script(script, dependencies) - break - except Exception as e: - feedback = e + # try executing; if execution fails due to code errors, ask agent to fix + try: + # validate syntax + ast.parse(script) + except SyntaxError as e: + raise ScriptSyntaxError(str(e)) from e + + try: + execute_script(script, dependencies) + except Exception as e: + feedback = e + # ask agent to fix up to 3 times + for attempt in range(1, 4): + logger.info(f"Asking agent to fix code (attempt {attempt}/3): {feedback}") + def request_fix(): + resp = agent(f"Your code has errors, please fix it.\n{feedback}", images, image_captions) + if isinstance(resp, (tuple, list)): + return resp[0] + return resp + + fixed_resp = _run_with_retry(request_fix, retries=2, backoff=0.5) + fixed_script = get_script(fixed_resp) + if not fixed_script: + feedback = ToolInvokeError("Agent did not return a python script when asked to fix code") + continue + + try: + ast.parse(fixed_script) + except SyntaxError as e: + feedback = ScriptSyntaxError(str(e)) + continue + + try: + execute_script(fixed_script, dependencies) + script = fixed_script + break + except Exception as e: + feedback = e + continue + else: + # attempted fixes exhausted + raise ToolInvokeError(f"Agent failed to produce a working script: {feedback}") from feedback - agent.reset() \ No newline at end of file + finally: + try: + agent.reset() + except Exception: + logger.debug("Failed to reset agent after solution_composition") \ No newline at end of file diff --git a/halligan/halligan/utils/action_tools.py b/halligan/halligan/utils/action_tools.py index 7df1b87..93e395b 100644 --- a/halligan/halligan/utils/action_tools.py +++ b/halligan/halligan/utils/action_tools.py @@ -19,56 +19,156 @@ load_dotenv() -page: Page | None = None + +# Replace raw global `page` with a proxy singleton that forwards to an +# underlying Playwright `Page` instance. This preserves the module-level +# name `page` so existing imports/calls continue to work (e.g. `page.mouse`). +class _ActionPageState: + def __init__(self) -> None: + self._page: Page | None = None + + +class _PageProxy: + """A transparent proxy to the underlying Playwright Page. + + Behaves like the real `Page` when initialized. If not initialized, + attribute access raises a clear RuntimeError. Implements __bool__ so + checks like `if page:` still work. + """ + def __getattr__(self, name: str): + page = _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("Playwright Page has not been initialized. Call set_page(page) first.") + return getattr(page, name) + + def __bool__(self) -> bool: + return _ACTION_PAGE_STATE._page is not None + + +_ACTION_PAGE_STATE = _ActionPageState() +page = _PageProxy() def set_page(p: Page): - global page - page = p + """Set the underlying Playwright `Page` instance used by the actions. + + Keeps the original function name and signature so callers do not need + to change their code. + """ + _ACTION_PAGE_STATE._page = p + + +def _to_int(value, default=0) -> int: + try: + return int(value) + except Exception: + return int(default) + +def _clamp_coord(x: int) -> int: + try: + return max(0, int(x)) + except Exception: + return 0 -def screenshot(region: list[float] = None) -> PIL.Image.Image: - if region: - region = { - "x": region[0], "y": region[1], - "width": region[2], "height": region[3] - } - image_bytes = page.screenshot(clip=region) - return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB") + +def _normalize_region(region: list[float] | None) -> dict | None: + """Normalize a region list [x,y,w,h] into a Playwright clip dict. + + Recoverable cases: + - If region is None -> return None (full page screenshot) + - If width/height <= 0 -> return None (fallback to full page) + - Negative x/y -> clamp to 0 + - Non-numeric values -> coerce to int where possible or fallback + """ + if not region: + return None + if not isinstance(region, (list, tuple)) or len(region) < 4: + return None + + x = max(0, _to_int(region[0], 0)) + y = max(0, _to_int(region[1], 0)) + w = _to_int(region[2], 0) + h = _to_int(region[3], 0) + + if w <= 0 or h <= 0: + return None + + return {"x": x, "y": y, "width": w, "height": h} + + +def screenshot(region: list[float] = None, page: Page | None = None) -> PIL.Image.Image: + """Take a screenshot via Playwright page with defensive region handling. + + Args: + region: [x, y, w, h] or None + page: explicit Playwright Page, fallback to module page proxy + + Returns: + PIL.Image.Image + """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available for screenshot. Call set_page(page) or pass page param.") + + clip = _normalize_region(region) + try: + if clip is None: + image_bytes = page.screenshot() + else: + image_bytes = page.screenshot(clip=clip) + except Exception as e: + # Provide a clear error for unrecoverable Playwright errors + raise RuntimeError(f"Failed to take screenshot: {e}") from e + + try: + return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB") + except Exception as e: + raise RuntimeError(f"Failed to open screenshot image: {e}") from e class Choice: - def __init__(self, image: PIL.Image.Image) -> None: + def __init__(self, image: PIL.Image.Image, page: Page | None = None) -> None: self._image = image + # prefer explicit page, fallback to module proxy + self._page = page or _ACTION_PAGE_STATE._page @property def image(self) -> PIL.Image.Image: return self._image - def release(self) -> None: + def release(self, page: Page | None = None) -> None: """ (For click_and_hold) Release from holding. """ + page = page or self._page + if page is None: + raise RuntimeError("No Playwright Page available to release mouse. Call set_page(page) or pass page param.") page.mouse.up() class SelectChoice: - def __init__(self, index: int, image: PIL.Image.Image, next: Element) -> None: + def __init__(self, index: int, image: PIL.Image.Image, next: Element, page: Page | None = None) -> None: self._image = image self._index = index self._next = next + self._page = page or _ACTION_PAGE_STATE._page @property def image(self) -> PIL.Image.Image: return self._image - def select(self) -> None: + def select(self, page: Page | None = None) -> None: """ (For get_all_choices) Select this choice. """ - x = self._next.x + self._next.w // 2 - y = self._next.y + self._next.h // 2 - for _ in range(self._index): + page = page or self._page + if page is None: + raise RuntimeError("No Playwright Page available to select choice. Call set_page(page) or pass page param.") + + x = int(self._next.x + self._next.w // 2) + y = int(self._next.y + self._next.h // 2) + for _ in range(max(0, int(self._index))): page.mouse.click(x, y) @@ -80,7 +180,8 @@ def __init__( current_x: int, current_y: int, observe: Frame, - track_bounds: tuple[int, int] + track_bounds: tuple[int, int], + page: Page | None = None ) -> None: self._axis = axis self._image = image @@ -88,6 +189,7 @@ def __init__( self._current_y = current_y self._observe = observe self._track_bounds = track_bounds + self._page = page or _ACTION_PAGE_STATE._page @property def image(self) -> PIL.Image.Image: @@ -110,13 +212,19 @@ def refine(self) -> list[SlideChoice]: choices = [] for pos in range(min_bound, max_bound, step): if self._axis == "x": + page = self._page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available for sliding. Pass page when creating SlideChoice.") page.mouse.move(x=pos, y=self._current_y) - image = screenshot(self._observe.region) - choice = SlideChoice(self._axis, image, pos, self._current_y, self._observe, (min_bound, max_bound)) + image = screenshot(self._observe.region, page=page) + choice = SlideChoice(self._axis, image, pos, self._current_y, self._observe, (min_bound, max_bound), page=page) else: + page = self._page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available for sliding. Pass page when creating SlideChoice.") page.mouse.move(x=self._current_x, y=pos) - image = screenshot(self._observe.region) - choice = SlideChoice(self._axis, image, self._current_x, pos, self._observe, (min_bound, max_bound)) + image = screenshot(self._observe.region, page=page) + choice = SlideChoice(self._axis, image, self._current_x, pos, self._observe, (min_bound, max_bound), page=page) choices.append(choice) @@ -126,6 +234,9 @@ def release(self) -> None: """ Confirm this as the final choice and release slider. """ + page = self._page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to release slider. Pass page when creating SlideChoice.") page.mouse.move(self._current_x, self._current_y) page.mouse.up() @@ -142,6 +253,7 @@ def __init__( self._image = image self._start = start self._end = end + self._page = None @property def preview(self) -> PIL.Image.Image: @@ -163,25 +275,31 @@ def swap(self) -> None: """ Executes the swap previewed in this choice. """ - x1, y1 = self._start - x2, y2 = self._end + page = self._page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to execute swap. Assign ._page or pass page via set_page.") - # Attempt 1: click start and end - page.mouse.click(x1, y1) - page.mouse.click(x2, y2) + x1, y1 = int(self._start[0]), int(self._start[1]) + x2, y2 = int(self._end[0]), int(self._end[1]) - # Attempt 2: drag start to end - page.mouse.move(x1, y1) - page.mouse.down() - page.mouse.move(x2, y2) - page.mouse.up() + # Attempt 1: click start and end + try: + page.mouse.click(x1, y1) + page.mouse.click(x2, y2) + except Exception: + # Fallback to drag + page.mouse.move(x1, y1) + page.mouse.down() + page.mouse.move(x2, y2) + page.mouse.up() class DragChoice: - def __init__(self, image: PIL.Image.Image, start: tuple, end: tuple) -> None: + def __init__(self, image: PIL.Image.Image, start: tuple, end: tuple, page: Page | None = None) -> None: self._image = image self._start = start self._end = end + self._page = page or _ACTION_PAGE_STATE._page @property def preview(self) -> PIL.Image.Image: @@ -194,23 +312,31 @@ def drop(self) -> None: """ Confirm this as the final choice and drop here. """ - x1, y1 = self._start - x2, y2 = self._end + page = self._page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to drop. Pass page or call set_page().") + + x1, y1 = int(self._start[0]), int(self._start[1]) + x2, y2 = int(self._end[0]), int(self._end[1]) page.mouse.move(x1, y1) page.mouse.down() page.mouse.move(x2, y2) page.mouse.up() -def click(target: Union[Frame, Element]) -> None: +def click(target: Union[Frame, Element], page: Page | None = None) -> None: """ Click a UI button. """ - x, y = target.center - page.mouse.click(x, y) + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to click. Call set_page(page) or pass page param.") + + x, y = target.center + page.mouse.click(_clamp_coord(x), _clamp_coord(y)) -def click_and_hold(target: Union[Frame, Element], observe: Frame): +def click_and_hold(target: Union[Frame, Element], observe: Frame, page: Page | None = None): """ Hold until release, returns observed state while holding. This action happens in real-time, do not batch process. @@ -220,27 +346,31 @@ def click_and_hold(target: Union[Frame, Element], observe: Frame): for choice in click_and_hold(...): if ask([choice.image], "ready to release?"): break """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to click and hold. Call set_page(page) or pass page param.") + x, y = target.center region = observe.region - page.mouse.down(x, y) + page.mouse.down(_clamp_coord(x), _clamp_coord(y)) start_time = time.time() timeout = 10 while True: elapsed_time = time.time() - start_time if elapsed_time > timeout: break + image = screenshot(region, page=page) + yield Choice(image, page=page) - image = screenshot(region) - yield Choice(image) - -def get_all_choices(prev_arrow: Element, next_arrow: Element, observe: Frame) -> list[SelectChoice]: +def get_all_choices(prev_arrow: Element, next_arrow: Element, observe: Frame, page: Page | None = None) -> list[SelectChoice]: """ Cycle through all choices by clicking arrow buttons. Returns all cycled choices from frame. """ def same_as(diff: PIL.Image.Image) -> bool: - if not diff_with_first.getbbox(): return True + if not diff or not getattr(diff, "getbbox", lambda: None)(): + return True diff = diff.convert("L") total_diff = sum(diff.getdata()) max_diff = diff.size[0] * diff.size[1] * 255 @@ -249,32 +379,37 @@ def same_as(diff: PIL.Image.Image) -> bool: index = 0 region = observe.region - choices = [SelectChoice(index, screenshot(region), next_arrow)] + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to get choices. Call set_page(page) or pass page param.") + + choices = [SelectChoice(index, screenshot(region, page=page), next_arrow, page=page)] next_x, next_y = next_arrow.center prev_x, prev_y = prev_arrow.center while True: - page.mouse.click(next_x, next_y) - image = screenshot(region) + page.mouse.click(_clamp_coord(next_x), _clamp_coord(next_y)) + image = screenshot(region, page=page) diff_with_first = ImageChops.difference(image, choices[0].image) diff_with_prev = ImageChops.difference(image, choices[-1].image) # Same as first means we have gone through a full cycle. - if same_as(diff_with_first): break + if same_as(diff_with_first): + break # Same as prev means it has reached the end but can't cycle back, manually do so. if same_as(diff_with_prev): - for _ in range(len(choices) - 1): - page.mouse.click(prev_x, prev_y) + for _ in range(len(choices) - 1): + page.mouse.click(_clamp_coord(prev_x), _clamp_coord(prev_y)) break index += 1 - choices.append(SelectChoice(index, image, next_arrow)) + choices.append(SelectChoice(index, image, next_arrow, page=page)) return choices -def drag(start: Element, end: Point) -> list[DragChoice]: +def drag(start: Element, end: Point, page: Page | None = None) -> list[DragChoice]: """ Drag element from start to end point. @@ -291,6 +426,10 @@ def get_mask(image: PIL.Image.Image) -> np.ndarray: mask = PIL.Image.fromarray(mask).convert('L') return mask + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available for drag helper. Call set_page(page) or pass page param.") + x2, y2 = end.center choices = [] step = 10 @@ -307,63 +446,83 @@ def get_mask(image: PIL.Image.Image) -> np.ndarray: start.h + margin * 2 ] mask = get_mask(start.image) - image = screenshot(region) + image = screenshot(region, page=page) image.paste(start.image, box=(margin, margin), mask=mask) - choices.append(DragChoice(image, start.center, end=(cx, cy))) + choices.append(DragChoice(image, start.center, end=(cx, cy), page=page)) return choices -def draw(path: list[Point]) -> None: +def draw(path: list[Point], page: Page | None = None) -> None: """ Draw a path following a list of points. """ if not path: return + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to draw. Call set_page(page) or pass page param.") + x, y = path[0] - page.mouse.move(x, y) + page.mouse.move(int(x), int(y)) page.mouse.down() for point in path: - page.mouse.move(point.x, point.y) + page.mouse.move(int(point.x), int(point.y)) x, y = path[-1] - page.mouse.move(x, y) + page.mouse.move(int(x), int(y)) page.mouse.up() -def enter(field: Union[Frame, Element], text: str) -> None: +def enter(field: Union[Frame, Element], text: str, page: Page | None = None) -> None: """ Click on an input field and enter text. """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to enter text. Call set_page(page) or pass page param.") + x, y = field.center - page.mouse.click(x, y) - page.keyboard.type(text) + page.mouse.click(int(x), int(y)) + page.keyboard.type(str(text)) -def point(to: Point) -> None: +def point(to: Point, page: Page | None = None) -> None: """ Click on a point on a frame. """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to point. Call set_page(page) or pass page param.") + x, y = to.center - page.mouse.click(x, y) + page.mouse.click(int(x), int(y)) -def select(choice: Union[Frame, Element]) -> None: +def select(choice: Union[Frame, Element], page: Page | None = None) -> None: """ Select a choice. """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to select. Call set_page(page) or pass page param.") + x, y = choice.center - page.mouse.click(x, y) + page.mouse.click(int(x), int(y)) -def slide_x(handle: Element, direction: Literal['left', 'right'], observe_frame: Frame) -> list[SlideChoice]: +def slide_x(handle: Element, direction: Literal['left', 'right'], observe_frame: Frame, page: Page | None = None) -> list[SlideChoice]: """ Drag and move slider handle left/right while observing changes in a frame. Returns: observation (list[Choice]): observation over frame while sliding. """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to slide. Call set_page(page) or pass page param.") + track_bounds = (handle.parent.x, handle.parent.x + handle.parent.w) step_size = handle.w // 2 step = -step_size if direction == "left" else step_size @@ -371,27 +530,31 @@ def slide_x(handle: Element, direction: Literal['left', 'right'], observe_frame: choices = [] current_x = handle.x + handle.w // 2 current_y = handle.y + handle.h // 2 - page.mouse.move(current_x, current_y) + page.mouse.move(int(current_x), int(current_y)) page.mouse.down() while track_bounds[0] < current_x + handle.w < track_bounds[1]: current_x += step - page.mouse.move(current_x, current_y) - image = screenshot(observe_frame.region) + page.mouse.move(int(current_x), int(current_y)) + image = screenshot(observe_frame.region, page=page) refine_range = [track_bounds[0] + step, track_bounds[1] - step] - choice = SlideChoice("x", image, current_x, current_y, observe_frame, refine_range) + choice = SlideChoice("x", image, current_x, current_y, observe_frame, refine_range, page=page) choices.append(choice) return choices -def slide_y(handle: Element, direction: Literal['up', 'down'], observe_frame: Frame) -> list[SlideChoice]: +def slide_y(handle: Element, direction: Literal['up', 'down'], observe_frame: Frame, page: Page | None = None) -> list[SlideChoice]: """ Drag and move slider handle up/down while observing changes in a frame. Returns: observation (list[Choice]): observation over frame while sliding. """ + page = page or _ACTION_PAGE_STATE._page + if page is None: + raise RuntimeError("No Playwright Page available to slide. Call set_page(page) or pass page param.") + track_bounds = (handle.parent.y, handle.parent.y + handle.parent.h) step_size = handle.h // 2 step = -step_size if direction.lower() == "down" else step_size @@ -399,13 +562,13 @@ def slide_y(handle: Element, direction: Literal['up', 'down'], observe_frame: Fr choices = [] current_x = handle.x + handle.w // 2 current_y = handle.y + handle.h // 2 - page.mouse.move(current_x, current_y) + page.mouse.move(int(current_x), int(current_y)) page.mouse.down() while track_bounds[0] < current_y + step < track_bounds[1]: - page.mouse.move(current_x, current_y) - image = screenshot(observe_frame.region) - choice = SlideChoice("y", image, current_x, current_y, observe_frame, track_bounds) + page.mouse.move(int(current_x), int(current_y)) + image = screenshot(observe_frame.region, page=page) + choice = SlideChoice("y", image, current_x, current_y, observe_frame, track_bounds, page=page) choices.append(choice) current_y += step @@ -460,7 +623,8 @@ def explore(grid: Frame) -> list[SwapChoice]: return [choice for choices in choices_by_distance.values() for choice in choices] -dependencies = {**globals(), "__builtins__": __builtins__, "List": List} +_local_globals = {k: v for k, v in globals().items() if k not in ("_ACTION_PAGE_STATE", "page")} +dependencies = {**_local_globals, "__builtins__": __builtins__, "List": List} action_toolkits: dict[str, Toolkit] = { "DRAGGABLE": [drag, DragChoice.preview, DragChoice.drop], @@ -476,4 +640,19 @@ def explore(grid: Frame) -> list[SwapChoice]: } for action, tools in action_toolkits.items(): - action_toolkits[action] = Toolkit(tools=tools, dependencies=dependencies) \ No newline at end of file + action_toolkits[action] = Toolkit(tools=tools, dependencies=dependencies) + + +class ActionToolkitRegistry: + """Singleton registry exporting initialized action toolkits.""" + _instance: dict[str, Toolkit] | None = None + + @classmethod + def get(cls) -> dict[str, Toolkit]: + if cls._instance is None: + cls._instance = action_toolkits + return cls._instance + + +# Export default singleton instance for external callers +DEFAULT_ACTION_TOOLKITS = ActionToolkitRegistry.get() \ No newline at end of file diff --git a/halligan/halligan/utils/vision_tools.py b/halligan/halligan/utils/vision_tools.py index b0c7295..f077068 100644 --- a/halligan/halligan/utils/vision_tools.py +++ b/halligan/halligan/utils/vision_tools.py @@ -22,7 +22,39 @@ load_dotenv() -agent = GPTAgent(api_key=os.getenv("OPENAI_API_KEY")) +_AGENT_API_KEY = os.getenv("OPENAI_API_KEY") + + +# Replace raw global `agent` with a proxy singleton so callers that use +# `agent(prompt, images, captions)` or call `agent.reset()` continue to +# work without changing call sites. Create the GPTAgent instance unconditionally +# (matching original behavior where an agent object existed regardless of API key). +class _AgentState: + def __init__(self) -> None: + # Initialize GPTAgent even if API key is None to preserve prior behavior + try: + self._agent = GPTAgent(api_key=_AGENT_API_KEY) + except Exception: + # If GPTAgent construction fails, keep None and let callers handle + self._agent = None + + +class _AgentProxy: + def __call__(self, *args, **kwargs): + agent = _AGENT_STATE._agent + if agent is None: + raise RuntimeError("GPTAgent not initialized. Ensure OPENAI_API_KEY is set or initialize agent before use.") + return agent(*args, **kwargs) + + def __getattr__(self, name: str): + agent = _AGENT_STATE._agent + if agent is None: + raise RuntimeError("GPTAgent not initialized. Ensure OPENAI_API_KEY is set or initialize agent before use.") + return getattr(agent, name) + + +_AGENT_STATE = _AgentState() +agent = _AgentProxy() def mark(images: list[PIL.Image.Image], object: str) -> list[PIL.Image.Image]: @@ -60,7 +92,7 @@ def focus(image: PIL.Image.Image, description: str) -> list[PIL.Image.Image]: return zoomed_regions -def ask(images: list[PIL.Image.Image], question: str, answer_type: str) -> list[Any]: +def ask(images: list[PIL.Image.Image], question: str, answer_type: str, agent_instance=None) -> list[Any]: """ Ask a question about the visual state of a batch of images. `answer_type` can be `bool`, `int`, `str`. @@ -102,9 +134,19 @@ def ask(images: list[PIL.Image.Image], question: str, answer_type: str) -> list[ f"You should follow the format `{answers_format}` to answer the question.\n" f"{hint}" ) + if not isinstance(images, list) or not all(hasattr(img, "size") for img in images): + raise ValueError("`images` must be a list of PIL.Image.Image instances") + image_captions = [f"Image {i}" for i in range(len(images))] - response, _ = agent(prompt, images, image_captions) - agent.reset() + # prefer explicit agent instance, fallback to module agent if present + _agent = agent_instance or _AGENT_STATE._agent + if _agent is None: + raise RuntimeError("GPTAgent not initialized. Provide agent_instance or set OPENAI_API_KEY.") + + response, _ = _agent(prompt, images, image_captions) + # reset if available + if hasattr(_agent, "reset"): + _agent.reset() match = re.search(answer_pattern, response) if match: matches = eval(match.group(2)) @@ -116,7 +158,7 @@ def ask(images: list[PIL.Image.Image], question: str, answer_type: str) -> list[ return matches -def rank(images: list[PIL.Image.Image], task_objective: str) -> list[str]: +def rank(images: list[PIL.Image.Image], task_objective: str, agent_instance=None) -> list[str]: """ Ranks each image in the `images` list based on the specified criteria in `task_objective`. Returns image_ids (list[int]), a list of image IDs ordered by descending rank. @@ -142,9 +184,12 @@ def traverse(node: Node): return result def get_top_rank(prompt, batch_image, batch_captions): + _agent = agent_instance or _AGENT_STATE._agent + if _agent is None: + raise RuntimeError("GPTAgent not initialized. Provide agent_instance or set OPENAI_API_KEY.") + # Get ranking - response, _ = agent(prompt, batch_image, batch_captions) - #agent.reset() + response, _ = _agent(prompt, batch_image, batch_captions) match = re.search(r'rank\((ids=)?(\[[\d, ]+\])\)', response) print(response) @@ -156,11 +201,15 @@ def get_top_rank(prompt, batch_image, batch_captions): best_node = Node(best_id) best_node.children = [batch[i] for i in ranking] - agent.reset() + if hasattr(_agent, "reset"): + _agent.reset() return best_node # To prevent agent from being overwhelmed, batch the input images + if not isinstance(images, list) or not images: + raise ValueError("`images` must be a non-empty list of PIL.Image.Image instances") + print("all images", len(images)) batch_size = 10 @@ -217,7 +266,7 @@ def get_top_rank(prompt, batch_image, batch_captions): return preorder(root) -def compare(images: list[PIL.Image.Image], task_objective: str, reference: PIL.Image.Image = None) -> list[bool]: +def compare(images: list[PIL.Image.Image], task_objective: str, reference: PIL.Image.Image = None, agent_instance=None) -> list[bool]: """ Compare each image with the `reference` image and check if it satisfies `task_objective`. Returns comparison (list[bool]), a list of True/False for each image in `images`. @@ -264,11 +313,21 @@ def compare(images: list[PIL.Image.Image], task_objective: str, reference: PIL.I f"{hint}" ) + if reference is not None and not hasattr(reference, "size"): + raise ValueError("`reference` must be a PIL.Image.Image or None") + + if not isinstance(images, list): + raise ValueError("`images` must be a list of PIL.Image.Image instances") + images = [reference] + images image_captions = ["Reference"] + [f"Item {i}" for i in range(len(images))] - response, _ = agent(prompt, images, image_captions) + _agent = agent_instance or _AGENT_STATE._agent + if _agent is None: + raise RuntimeError("GPTAgent not initialized. Provide agent_instance or set OPENAI_API_KEY.") - agent.reset() + response, _ = _agent(prompt, images, image_captions) + if hasattr(_agent, "reset"): + _agent.reset() match = re.search(answer_pattern, response) matches = eval(match.group(2)) if match else [False] * (len(images) - 1) return matches @@ -349,7 +408,18 @@ def _color_dist(c1, c2): return _moment_match() and _color_match() -dependencies = {**globals(), "__builtins__": __builtins__, "List": List} +dependencies = { + "ask": ask, + "rank": rank, + "compare": compare, + "match": match, + "Frame": Frame, + "Element": Element, + "Point": Point, + "PIL": PIL, + "__builtins__": __builtins__, + "List": List, +} vision_toolkits: dict[str, Toolkit] = { "DRAGGABLE": [ask, rank, Frame.show_keypoints, Frame.get_keypoint, Point.show_neighbours, Point.get_neighbour], @@ -363,4 +433,19 @@ def _color_dist(c1, c2): } for action, tools in vision_toolkits.items(): - vision_toolkits[action] = Toolkit(tools=tools, dependencies=dependencies) \ No newline at end of file + vision_toolkits[action] = Toolkit(tools=tools, dependencies=dependencies) + + +class VisionToolkitRegistry: + """Singleton registry exporting the initialized vision toolkits.""" + _instance: dict[str, Toolkit] | None = None + + @classmethod + def get(cls) -> dict[str, Toolkit]: + if cls._instance is None: + cls._instance = vision_toolkits + return cls._instance + + +# Export default singleton instance for external callers +DEFAULT_VISION_TOOLKITS = VisionToolkitRegistry.get() \ No newline at end of file diff --git a/halligan/tests/test_action_tools.py b/halligan/tests/test_action_tools.py new file mode 100644 index 0000000..080d03d --- /dev/null +++ b/halligan/tests/test_action_tools.py @@ -0,0 +1,246 @@ +import sys +import io +import types +import pytest +from PIL import Image + +# Prepare fake external modules before importing action_tools +# 1) dotenv +dotenv = types.ModuleType("dotenv") +dotenv.load_dotenv = lambda: None +sys.modules['dotenv'] = dotenv + +# 2) playwright.sync_api with a Page class +class FakeMouse: + def __init__(self, recorder): + self._rec = recorder + def click(self, x, y): + self._rec.append(('click', x, y)) + def move(self, x, y): + self._rec.append(('move', x, y)) + def down(self): + self._rec.append(('down',)) + def up(self): + self._rec.append(('up',)) + +class FakePage: + def __init__(self): + self._rec = [] + self.mouse = FakeMouse(self._rec) + self.keyboard = types.SimpleNamespace(type=lambda t: self._rec.append(('type', t))) + self._screenshot_sequence = None + self._screenshot_calls = 0 + + def screenshot(self, clip=None): + # Return PNG bytes of a small image; if sequence provided, return next + if self._screenshot_sequence is not None: + idx = min(self._screenshot_calls, len(self._screenshot_sequence)-1) + data = self._screenshot_sequence[idx] + self._screenshot_calls += 1 + return data + buf = io.BytesIO() + Image.new('RGB', (8,8), 'white').save(buf, format='PNG') + return buf.getvalue() + +FakeSyncApi = types.ModuleType('playwright.sync_api') +FakeSyncApi.Page = FakePage +sys.modules['playwright'] = types.ModuleType('playwright') +sys.modules['playwright.sync_api'] = FakeSyncApi + +# 3) halligan.utils.layout stub +layout_mod = types.ModuleType('halligan.utils.layout') + +class FrameStub: + def __init__(self, x=0, y=0, image=None): + self.x = x + self.y = y + self._image = image or Image.new('RGB', (100, 100), 'white') + self.w = self._image.size[0] + self.h = self._image.size[1] + self.interactables = [] + self.subframes = [] + self.interactable = None + self.description = '' + + @property + def image(self): + return self._image + + @property + def region(self): + return [self.x, self.y, self.w, self.h] + + @property + def center(self): + return (self.x + self.w // 2, self.y + self.h // 2) + + def get_interactable(self, id): + return self.interactables[id] + +class ElementStub(FrameStub): + def __init__(self, x=0, y=0, image=None, parent=None): + super().__init__(x, y, image) + self.parent = parent or FrameStub() + self.retrieved = False + +class PointStub(FrameStub): + pass + +layout_mod.Frame = FrameStub +layout_mod.Element = ElementStub +layout_mod.Point = PointStub +sys.modules['halligan.utils.layout'] = layout_mod + +# 4) halligan.utils.toolkit stub +toolkit_mod = types.ModuleType('halligan.utils.toolkit') +class ToolkitStub: + def __init__(self, tools, dependencies): + self.tools = tools + self.dependencies = dependencies +toolkit_mod.Toolkit = ToolkitStub +sys.modules['halligan.utils.toolkit'] = toolkit_mod + +# 5) halligan.utils.vision_tools minimal stub for match +vision_tools_stub = types.ModuleType('halligan.utils.vision_tools') +vision_tools_stub.match = lambda e1,e2: False +sys.modules['halligan.utils.vision_tools'] = vision_tools_stub + +# Now import the module under test +from halligan.utils import action_tools as at + +# Helper to create PNG bytes +def png_bytes(color='white', size=(8,8)): + buf = io.BytesIO() + Image.new('RGB', size, color).save(buf, format='PNG') + return buf.getvalue() + +# Tests + +def test_screenshot_fullpage_calls_page_screenshot(): + page = FakePage() + at.set_page(page) + img = at.screenshot(page=page) + assert isinstance(img, Image.Image) + + +def test_screenshot_with_valid_region_calls_with_clip(): + page = FakePage() + at.set_page(page) + img = at.screenshot(region=[10,10,5,5], page=page) + assert isinstance(img, Image.Image) + + +def test_screenshot_with_invalid_region_fallbacks_fullpage(): + page = FakePage() + at.set_page(page) + # width negative -> fallback + img = at.screenshot(region=[0,0,-5,10], page=page) + assert isinstance(img, Image.Image) + + +def test_screenshot_no_page_raises(): + # ensure no page set + at.set_page(None) + with pytest.raises(RuntimeError): + at.screenshot() + + +def test_click_uses_clamped_coords(): + page = FakePage() + at.set_page(page) + el = ElementStub(x=-20, y=5, image=Image.new('RGB',(10,10),'white')) + # override center to negative + el.center = (-5, 10) + # monkeypatch center attribute by setting property-like + el.center = (-5, 10) + # call click + at.click(el) + assert page._rec[0][0] == 'click' + assert page._rec[0][1] >= 0 and page._rec[0][2] >= 0 + + +def test_click_and_hold_yields_choice_images(): + page = FakePage() + # set sequence images to simulate frames + page._screenshot_sequence = [png_bytes('white'), png_bytes('white')] + at.set_page(page) + target = ElementStub(x=10,y=10,image=Image.new('RGB',(10,10),'white')) + obs = FrameStub(x=0,y=0,image=Image.new('RGB',(20,20),'white')) + gen = at.click_and_hold(target, obs, page=page) + choice = next(gen) + assert hasattr(choice, 'image') and isinstance(choice.image, Image.Image) + + +def test_get_all_choices_cycles_until_same(): + page = FakePage() + # create two different images then repeat the first + a = png_bytes('white') + b = png_bytes('black') + page._screenshot_sequence = [a, b, a] + at.set_page(page) + prev_arrow = ElementStub(x=0,y=0,image=Image.new('RGB',(10,10),'white')) + next_arrow = ElementStub(x=50,y=0,image=Image.new('RGB',(10,10),'white')) + obs = FrameStub(x=0,y=0,image=Image.new('RGB',(20,20),'white')) + choices = at.get_all_choices(prev_arrow, next_arrow, obs, page=page) + assert isinstance(choices, list) + assert len(choices) >= 1 + + +def test_select_choice_select_calls_mouse_click_times_index(): + page = FakePage() + at.set_page(page) + next_arrow = ElementStub(x=10,y=10,image=Image.new('RGB',(10,10),'white')) + sc = at.SelectChoice(index=3, image=Image.new('RGB',(5,5),'white'), next=next_arrow, page=page) + sc.select() + # 3 clicks recorded + clicks = [r for r in page._rec if r[0]=='click'] + assert len(clicks) == 3 + + +def test_drag_returns_choices_count9(): + page = FakePage() + at.set_page(page) + start = ElementStub(x=0,y=0,image=Image.new('RGB',(10,10),'white')) + end = PointStub(x=50,y=50,image=Image.new('RGB',(10,10),'white')) + choices = at.drag(start, end, page=page) + assert len(choices) == 9 + + +def test_drag_drop_calls_mouse_methods(): + page = FakePage() + at.set_page(page) + dc = at.DragChoice(image=Image.new('RGB',(10,10),'white'), start=(0,0), end=(20,20), page=page) + dc.drop() + # expect sequence move, down, move, up + ops = [r[0] for r in page._rec] + assert 'move' in ops and 'down' in ops and 'up' in ops + + +def test_slide_choice_refine_creates_choices(): + page = FakePage() + at.set_page(page) + obs = FrameStub(x=0,y=0,image=Image.new('RGB',(100,20),'white')) + sc = at.SlideChoice('x', Image.new('RGB',(10,10),'white'), 30, 10, obs, (0,100), page=page) + choices = sc.refine() + assert isinstance(choices, list) + assert len(choices) > 0 + + +def test_swap_choice_swap_fallbacks_to_drag_on_click_exception(): + class BadPage(FakePage): + def __init__(self): + super().__init__() + def _bad_click(self, x, y): + raise Exception('click failed') + bad = FakePage() + # monkeypatch click to raise + def raising_click(x,y): + raise Exception('click failed') + bad.mouse.click = raising_click + at.set_page(bad) + grid = [[ElementStub(x=0,y=0,image=Image.new('RGB',(10,10),'white'))]] + sw = at.SwapChoice(grid, Image.new('RGB',(10,10),'white'), (5,5), (10,10)) + sw._page = bad + # should not raise + sw.swap() + diff --git a/halligan/tests/test_vision_tools.py b/halligan/tests/test_vision_tools.py new file mode 100644 index 0000000..a32a23b --- /dev/null +++ b/halligan/tests/test_vision_tools.py @@ -0,0 +1,150 @@ +import sys +import types +import io +from PIL import Image +import pytest + +# Prepare fake dotenv +dotenv = types.ModuleType('dotenv') +dotenv.load_dotenv = lambda: None +sys.modules['dotenv'] = dotenv + +# Minimal fake halligan.agents with GPTAgent +agents_mod = types.ModuleType('halligan.agents') +class FakeAgent: + def __init__(self, api_key=None): + self._calls = [] + def __call__(self, prompt, images, captions): + self._calls.append((prompt, len(images), captions)) + return ("response", {"meta": 1}) + def reset(self): + self._calls.append(('reset',)) +agents_mod.GPTAgent = FakeAgent +sys.modules['halligan.agents'] = agents_mod + +# Minimal halligan.models.Detector +models_mod = types.ModuleType('halligan.models') +class FakeDetector: + @staticmethod + def detect(images, obj=None): + # return empty boxes for each image + return [[] for _ in images] +models_mod.Detector = FakeDetector +sys.modules['halligan.models'] = models_mod + +# Create minimal cv2 and skimage modules to allow import +cv2 = types.ModuleType('cv2') +cv2.cvtColor = lambda img, code: img +cv2.threshold = lambda gray, a, b, c: (None, gray) +cv2.findContours = lambda thresh, mode, method: ([], None) +cv2.RETR_CCOMP = 0 +cv2.CHAIN_APPROX_SIMPLE = 0 +sys.modules['cv2'] = cv2 + +skimage_color = types.ModuleType('skimage.color') +skimage_color.rgb2lab = lambda arr: arr +skimage_color.deltaE_cie76 = lambda a,b: 0 +sys.modules['skimage.color'] = skimage_color + +# halligan.utils.toolkit and layout minimal stubs +toolkit_mod = types.ModuleType('halligan.utils.toolkit') +class ToolkitStub: + def __init__(self, tools, dependencies): + self.tools = tools + self.dependencies = dependencies +toolkit_mod.Toolkit = ToolkitStub +sys.modules['halligan.utils.toolkit'] = toolkit_mod + +layout_mod = types.ModuleType('halligan.utils.layout') +class FrameStub: + def __init__(self, x=0,y=0,image=None): + self.x = x; self.y = y; self._image = image or Image.new('RGB',(10,10),'white') + self.w = self._image.size[0]; self.h = self._image.size[1] + self.interactables = [] + self.keypoints = [] + self.subframes = [] + self.interactable = None + self.description = '' + @property + def image(self): + return self._image + def get_interactable(self, id): + return None + def show_keypoints(self): + return self._image +class ElementStub(FrameStub): + pass +class PointStub(FrameStub): + pass +layout_mod.Frame = FrameStub +layout_mod.Element = ElementStub +layout_mod.Point = PointStub +sys.modules['halligan.utils.layout'] = layout_mod + +# Now import the module under test +from halligan.utils import vision_tools as vt + +# Tests + +def test_ask_with_bad_images_raises(): + with pytest.raises(ValueError): + vt.ask("notalist", "q", "bool") + + +def test_ask_with_missing_agent_raises(monkeypatch): + # Force internal agent to None + vt._AGENT_STATE._agent = None + with pytest.raises(RuntimeError): + vt.ask([Image.new('RGB',(5,5))], "q", "bool") + + +def test_ask_with_agent_calls_agent_and_reset(): + fake = agents_mod.GPTAgent() + res = vt.ask([Image.new('RGB',(5,5))], "q", "bool", agent_instance=fake) + # default behavior returns a list + assert isinstance(res, list) + + +def test_rank_with_empty_images_raises(): + with pytest.raises(ValueError): + vt.rank([], "obj") + + +def test_rank_uses_agent_and_batches(): + fake = agents_mod.GPTAgent() + # provide 3 small images + images = [Image.new('RGB',(5,5)) for _ in range(3)] + order = vt.rank(images, "rank this", agent_instance=fake) + assert isinstance(order, list) + + +def test_compare_with_invalid_reference_raises(): + with pytest.raises(ValueError): + vt.compare([Image.new('RGB',(5,5))], "obj", reference='notimage') + + +def test_compare_with_missing_agent_raises(monkeypatch): + vt._AGENT_STATE._agent = None + with pytest.raises(RuntimeError): + vt.compare([Image.new('RGB',(5,5))], "obj", reference=Image.new('RGB',(5,5))) + + +def test_compare_calls_agent_and_reset(): + fake = agents_mod.GPTAgent() + res = vt.compare([Image.new('RGB',(5,5))], "obj", reference=Image.new('RGB',(5,5)), agent_instance=fake) + assert isinstance(res, list) + + +def test_match_non_element_returns_false(): + assert vt.match(1, 2) is False + + +def test_compare_default_on_agent_response_mismatch(monkeypatch): + class DumbAgent: + def __call__(self, prompt, images, captions): + return ("no_match", {}) + def reset(self): + pass + res = vt.compare([Image.new('RGB',(5,5))], "obj", reference=Image.new('RGB',(5,5)), agent_instance=DumbAgent()) + assert isinstance(res, list) +