Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/_test-integrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
MINDEE_V2_API_KEY: ${{ secrets.MINDEE_V2_SE_TESTS_API_KEY }}
MINDEE_V2_FINDOC_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_FINDOC_MODEL_ID }}
MINDEE_V2_SE_TESTS_BLANK_PDF_URL: ${{ secrets.MINDEE_V2_SE_TESTS_BLANK_PDF_URL }}
MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID: ${{ secrets.MINDEE_V2_SE_TESTS_SPLIT_MODEL_ID }}
run: |
pytest --cov mindee -m integration

Expand Down
2 changes: 1 addition & 1 deletion mindee/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def enqueue_and_parse( # pylint: disable=too-many-locals
if poll_results.job.status == "failed":
raise MindeeError("Parsing failed for job {poll_results.job.id}")
logger.debug(
"Polling server for parsing result with job id: %s", queue_result.job.id
"Polling server for product result with job id: %s", queue_result.job.id
)
retry_counter += 1
sleep(delay_sec)
Expand Down
117 changes: 101 additions & 16 deletions mindee/client_v2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
from time import sleep
from typing import Optional, Union
from typing import Optional, Union, Type, TypeVar

from mindee.client_mixin import ClientMixin
from mindee.error.mindee_error import MindeeError
from mindee.error.mindee_http_error_v2 import handle_error_v2
from mindee.input import UrlInputSource
from mindee.input import UrlInputSource, BaseParameters
from mindee.input.inference_parameters import InferenceParameters
from mindee.input.polling_options import PollingOptions
from mindee.input.sources.local_input_source import LocalInputSource
Expand All @@ -15,9 +16,12 @@
is_valid_post_response,
)
from mindee.parsing.v2.common_response import CommonStatus
from mindee.v2.parsing.inference.base_response import BaseResponse
from mindee.parsing.v2.inference_response import InferenceResponse
from mindee.parsing.v2.job_response import JobResponse

TypeBaseInferenceResponse = TypeVar("TypeBaseInferenceResponse", bound=BaseResponse)


class ClientV2(ClientMixin):
"""
Expand All @@ -41,20 +45,35 @@ def __init__(self, api_key: Optional[str] = None) -> None:
def enqueue_inference(
self,
input_source: Union[LocalInputSource, UrlInputSource],
params: InferenceParameters,
params: BaseParameters,
disable_redundant_warnings: bool = False,
) -> JobResponse:
"""[Deprecated] Use `enqueue` instead."""
if not disable_redundant_warnings:
warnings.warn(
"enqueue_inference is deprecated; use enqueue instead",
DeprecationWarning,
stacklevel=2,
)
return self.enqueue(input_source, params)

def enqueue(
self,
input_source: Union[LocalInputSource, UrlInputSource],
params: BaseParameters,
) -> JobResponse:
"""
Enqueues a document to a given model.

:param input_source: The document/source file to use. Can be local or remote.

:param params: Parameters to set when sending a file.
:param slug: Slug for the endpoint.

:return: A valid inference response.
"""
logger.debug("Enqueuing inference using model: %s", params.model_id)

response = self.mindee_api.req_post_inference_enqueue(
input_source=input_source, params=params
input_source=input_source, params=params, slug=params.get_enqueue_slug()
)
dict_response = response.json()

Expand All @@ -79,34 +98,57 @@ def get_job(self, job_id: str) -> JobResponse:
dict_response = response.json()
return JobResponse(dict_response)

def get_inference(self, inference_id: str) -> InferenceResponse:
def get_inference(
self,
inference_id: str,
response_type: Type[BaseResponse] = InferenceResponse,
disable_redundant_warnings: bool = False,
) -> BaseResponse:
"""[Deprecated] Use `get_result` instead."""
if not disable_redundant_warnings:
warnings.warn(
"get_inference is deprecated; use get_result instead",
DeprecationWarning,
stacklevel=2,
)
return self.get_result(inference_id, response_type)

def get_result(
self,
inference_id: str,
response_type: Type[BaseResponse] = InferenceResponse,
) -> BaseResponse:
"""
Get the result of an inference that was previously enqueued.

The inference will only be available after it has finished processing.

:param inference_id: UUID of the inference to retrieve.
:param response_type: Class of the product to instantiate.
:return: An inference response.
"""
logger.debug("Fetching inference: %s", inference_id)

response = self.mindee_api.req_get_inference(inference_id)
response = self.mindee_api.req_get_inference(
inference_id, response_type.get_result_slug()
)
if not is_valid_get_response(response):
handle_error_v2(response.json())
dict_response = response.json()
return InferenceResponse(dict_response)
return response_type(dict_response)

def enqueue_and_get_inference(
def _enqueue_and_get(
self,
input_source: Union[LocalInputSource, UrlInputSource],
params: InferenceParameters,
) -> InferenceResponse:
params: BaseParameters,
response_type: Optional[Type[BaseResponse]] = InferenceResponse,
) -> BaseResponse:
"""
Enqueues to an asynchronous endpoint and automatically polls for a response.

:param input_source: The document/source file to use. Can be local or remote.

:param params: Parameters to set when sending a file.
:param response_type: The product class to use for the response object.

:return: A valid inference response.
"""
Expand All @@ -117,9 +159,9 @@ def enqueue_and_get_inference(
params.polling_options.delay_sec,
params.polling_options.max_retries,
)
enqueue_response = self.enqueue_inference(input_source, params)
enqueue_response = self.enqueue_inference(input_source, params, True)
logger.debug(
"Successfully enqueued inference with job id: %s", enqueue_response.job.id
"Successfully enqueued document with job id: %s", enqueue_response.job.id
)
sleep(params.polling_options.initial_delay_sec)
try_counter = 0
Expand All @@ -134,8 +176,51 @@ def enqueue_and_get_inference(
f"Parsing failed for job {job_response.job.id}: {detail}"
)
if job_response.job.status == CommonStatus.PROCESSED.value:
return self.get_inference(job_response.job.id)
result = self.get_inference(
job_response.job.id, response_type or InferenceResponse, True
)
return result
try_counter += 1
sleep(params.polling_options.delay_sec)

raise MindeeError(f"Couldn't retrieve document after {try_counter + 1} tries.")

def enqueue_and_get_inference(
    self,
    input_source: Union[LocalInputSource, UrlInputSource],
    params: InferenceParameters,
) -> InferenceResponse:
    """
    [Deprecated] Use `enqueue_and_get_result` instead.

    Enqueues the document and polls until an inference result is available.

    :param input_source: The document/source file to use. Can be local or remote.
    :param params: Parameters to set when sending a file.
    :return: A valid inference response.
    """
    warnings.warn(
        "enqueue_and_get_inference is deprecated; use enqueue_and_get_result",
        DeprecationWarning,
        stacklevel=2,
    )
    result = self._enqueue_and_get(input_source, params)
    # Default response type is InferenceResponse; anything else is an internal error.
    assert isinstance(result, InferenceResponse), (
        f'Invalid response type "{type(result)}"'
    )
    return result

def enqueue_and_get_result(
    self,
    response_type: Type[TypeBaseInferenceResponse],
    input_source: Union[LocalInputSource, UrlInputSource],
    params: BaseParameters,
) -> TypeBaseInferenceResponse:
    """
    Enqueues to an asynchronous endpoint and automatically polls for a response.

    :param response_type: The product class to use for the response object.
    :param input_source: The document/source file to use. Can be local or remote.
    :param params: Parameters to set when sending a file.
    :return: A valid inference response of the requested type.
    """
    result = self._enqueue_and_get(input_source, params, response_type)
    # The polling helper builds the object from `response_type`, so this
    # assertion only guards against internal inconsistencies.
    assert isinstance(result, response_type), (
        f'Invalid response type "{type(result)}"'
    )
    return result
4 changes: 2 additions & 2 deletions mindee/error/mindee_http_error_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import List, Optional

from mindee.parsing.common.string_dict import StringDict
from mindee.parsing.v2 import ErrorItem, ErrorResponse
Expand All @@ -18,7 +18,7 @@ def __init__(self, response: ErrorResponse) -> None:
self.title = response.title
self.code = response.code
self.detail = response.detail
self.errors: list[ErrorItem] = response.errors
self.errors: List[ErrorItem] = response.errors
super().__init__(
f"HTTP {self.status} - {self.title} :: {self.code} - {self.detail}"
)
Expand Down
20 changes: 13 additions & 7 deletions mindee/input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from mindee.input.local_response import LocalResponse
from mindee.input.base_parameters import BaseParameters
from mindee.input.inference_parameters import InferenceParameters
from mindee.v2.product.split.split_parameters import SplitParameters
from mindee.input.page_options import PageOptions
from mindee.input.polling_options import PollingOptions
from mindee.input.sources.base_64_input import Base64Input
Expand All @@ -11,15 +14,18 @@
from mindee.input.workflow_options import WorkflowOptions

__all__ = [
"Base64Input",
"BaseParameters",
"BytesInput",
"FileInput",
"InputType",
"InferenceParameters",
"LocalInputSource",
"UrlInputSource",
"LocalResponse",
"PageOptions",
"PathInput",
"FileInput",
"Base64Input",
"BytesInput",
"WorkflowOptions",
"PollingOptions",
"PageOptions",
"LocalResponse",
"UrlInputSource",
"SplitParameters",
"WorkflowOptions",
]
44 changes: 44 additions & 0 deletions mindee/input/base_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Union

from mindee.input.polling_options import PollingOptions


@dataclass
class BaseParameters(ABC):
    """
    Base class for parameters accepted by all V2 endpoints.

    Concrete subclasses must declare a ``_slug`` default identifying the
    endpoint they target (e.g. ``"inferences"``); calling
    :meth:`get_enqueue_slug` on this base class will raise ``AttributeError``.
    """

    _slug: str = field(init=False)
    """Slug of the endpoint; set by each concrete subclass, never by callers."""

    model_id: str
    """ID of the model, required."""
    alias: Optional[str] = None
    """Use an alias to link the file to your own DB. If empty, no alias will be used."""
    webhook_ids: Optional[List[str]] = None
    """IDs of webhooks to propagate the API response to."""
    polling_options: Optional[PollingOptions] = None
    """Options for polling. Set only if having timeout issues."""
    close_file: bool = True
    """Whether to close the file after parsing."""

    def get_config(self) -> Dict[str, Union[str, List[str]]]:
        """
        Return the parameters as a config dictionary.

        Only keys that are actually set are included, so the server never
        receives explicit nulls or empty lists.

        :return: A dict of parameters.
        """
        data: Dict[str, Union[str, List[str]]] = {
            "model_id": self.model_id,
        }
        if self.alias is not None:
            data["alias"] = self.alias
        # Truthiness already excludes both None and an empty list.
        if self.webhook_ids:
            data["webhook_ids"] = self.webhook_ids
        return data

    @classmethod
    def get_enqueue_slug(cls) -> str:
        """Getter for the enqueue slug declared by the concrete subclass."""
        return cls._slug
44 changes: 29 additions & 15 deletions mindee/input/inference_parameters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from dataclasses import dataclass, asdict
from typing import List, Optional, Union
from dataclasses import dataclass, asdict, field
from typing import Dict, List, Optional, Union

from mindee.input.polling_options import PollingOptions
from mindee.input.base_parameters import BaseParameters


@dataclass
Expand Down Expand Up @@ -44,7 +44,7 @@ class DataSchemaField(StringDataClass):
guidelines: Optional[str] = None
"""Optional extraction guidelines."""
nested_fields: Optional[dict] = None
"""Subfields when type is `nested_object`. Leave empty for other types"""
"""Subfields when type is `nested_object`. Leave empty for other types."""


@dataclass
Expand Down Expand Up @@ -78,11 +78,12 @@ def __post_init__(self) -> None:


@dataclass
class InferenceParameters:
class InferenceParameters(BaseParameters):
"""Inference parameters to set when sending a file."""

model_id: str
"""ID of the model, required."""
_slug: str = field(init=False, default="inferences")
"""Slug of the endpoint."""

rag: Optional[bool] = None
"""Enhance extraction accuracy with Retrieval-Augmented Generation."""
raw_text: Optional[bool] = None
Expand All @@ -94,14 +95,6 @@ class InferenceParameters:
Boost the precision and accuracy of all extractions.
Calculate confidence scores for all fields, and fill their ``confidence`` attribute.
"""
alias: Optional[str] = None
"""Use an alias to link the file to your own DB. If empty, no alias will be used."""
webhook_ids: Optional[List[str]] = None
"""IDs of webhooks to propagate the API response to."""
polling_options: Optional[PollingOptions] = None
"""Options for polling. Set only if having timeout issues."""
close_file: bool = True
"""Whether to close the file after parsing."""
text_context: Optional[str] = None
"""
Additional text context used by the model during inference.
Expand All @@ -118,3 +111,24 @@ def __post_init__(self):
self.data_schema = DataSchema(**json.loads(self.data_schema))
elif isinstance(self.data_schema, dict):
self.data_schema = DataSchema(**self.data_schema)

def get_config(self) -> Dict[str, Union[str, List[str]]]:
    """
    Return the parameters as a config dictionary.

    Extends the base configuration with inference-specific options.
    Boolean flags are serialized as lowercase strings ("true"/"false"),
    matching the form-data representation the API expects.

    :return: A dict of parameters.
    """
    data = super().get_config()
    if self.data_schema is not None:
        data["data_schema"] = str(self.data_schema)
    if self.rag is not None:
        data["rag"] = str(self.rag).lower()
    if self.raw_text is not None:
        data["raw_text"] = str(self.raw_text).lower()
    if self.polygon is not None:
        data["polygon"] = str(self.polygon).lower()
    if self.confidence is not None:
        data["confidence"] = str(self.confidence).lower()
    if self.text_context is not None:
        data["text_context"] = self.text_context
    return data
Loading