From a327519a2aed07ab918f10419479806751d3634a Mon Sep 17 00:00:00 2001 From: bittergreen Date: Mon, 26 Jan 2026 21:09:28 +0800 Subject: [PATCH 1/4] feat: Add NLI model server/client code for fast detection of duplication/conflict. --- src/memos/extras/__init__.py | 0 src/memos/extras/nli_model/__init__.py | 0 src/memos/extras/nli_model/client.py | 61 +++++++ src/memos/extras/nli_model/server/README.md | 68 +++++++ src/memos/extras/nli_model/server/__init__.py | 0 src/memos/extras/nli_model/server/config.py | 23 +++ src/memos/extras/nli_model/server/handler.py | 168 ++++++++++++++++++ src/memos/extras/nli_model/server/serve.py | 44 +++++ src/memos/extras/nli_model/types.py | 18 ++ tests/extras/__init__.py | 0 tests/extras/nli_model/__init__.py | 0 11 files changed, 382 insertions(+) create mode 100644 src/memos/extras/__init__.py create mode 100644 src/memos/extras/nli_model/__init__.py create mode 100644 src/memos/extras/nli_model/client.py create mode 100644 src/memos/extras/nli_model/server/README.md create mode 100644 src/memos/extras/nli_model/server/__init__.py create mode 100644 src/memos/extras/nli_model/server/config.py create mode 100644 src/memos/extras/nli_model/server/handler.py create mode 100644 src/memos/extras/nli_model/server/serve.py create mode 100644 src/memos/extras/nli_model/types.py create mode 100644 tests/extras/__init__.py create mode 100644 tests/extras/nli_model/__init__.py diff --git a/src/memos/extras/__init__.py b/src/memos/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/extras/nli_model/__init__.py b/src/memos/extras/nli_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py new file mode 100644 index 000000000..a02dae9f6 --- /dev/null +++ b/src/memos/extras/nli_model/client.py @@ -0,0 +1,61 @@ +import logging + +import requests + +from memos.extras.nli_model.types import NLIResult + + +logger = logging.getLogger(__name__) + + +class NLIClient: + """ + Client for interacting with the deployed NLI model service. + """ + + def __init__(self, base_url: str = "http://localhost:32532"): + self.base_url = base_url.rstrip("/") + self.session = requests.Session() + + def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: + """ + Compare one source text against multiple target memories using the NLI service. + + Args: + source: The new memory content. + targets: List of existing memory contents to compare against. + + Returns: + List of NLIResult corresponding to each target. + """ + if not targets: + return [] + + url = f"{self.base_url}/compare_one_to_many" + # Match schemas.CompareRequest + payload = {"source": source, "targets": targets} + + try: + response = self.session.post(url, json=payload, timeout=30) + response.raise_for_status() + data = response.json() + + # Match schemas.CompareResponse + results_str = data.get("results", []) + + results = [] + for res_str in results_str: + try: + results.append(NLIResult(res_str)) + except ValueError: + logger.warning( + f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + ) + results.append(NLIResult.UNRELATED) + + return results + + except requests.RequestException as e: + logger.error(f"[NLIClient] Request failed: {e}") + # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. 
+ return [NLIResult.UNRELATED] * len(targets) diff --git a/src/memos/extras/nli_model/server/README.md b/src/memos/extras/nli_model/server/README.md new file mode 100644 index 000000000..1dbe6142d --- /dev/null +++ b/src/memos/extras/nli_model/server/README.md @@ -0,0 +1,68 @@ +# NLI Model Server + +This directory contains the standalone server for the Natural Language Inference (NLI) model used by MemOS. + +## Prerequisites + +- Python 3.10+ +- CUDA-capable GPU (Recommended for performance) + +## Running the Server + +You can run the server using the module syntax from the project root to ensure imports work correctly. + +### 1. Basic Start +```bash +python -m memos.extras.nli_model.server.serve +``` + +### 2. Configuration +You can configure the server by editing config.py: + +- `HOST`: The host to bind to (default: `0.0.0.0`) +- `PORT`: The port to bind to (default: `32532`) +- `NLI_DEVICE`: The device to run the model on. + - `cuda` (Default, uses cuda:0 if available, else fallback to mps/cpu) + - `cuda:0` (Specific GPU) + - `mps` (Apple Silicon) + - `cpu` (CPU) + +## API Usage + +### Compare One to Many +**POST** `/compare_one_to_many` + +**Request Body:** +```json +{ + "source": "I just ate an apple.", + "targets": [ + "I ate a fruit.", + "I hate apples.", + "The sky is blue." + ] +} +``` + +## Testing + +An end-to-end example script is provided to verify the server's functionality. This script starts the server locally and runs a client request to verify the NLI logic. + +### End-to-End Test + +Run the example script from the project root: + +```bash +python examples/extras/nli_e2e_example.py +``` + +**Response:** +```json +{ + "results": [ + "Duplicate", // Entailment + "Contradiction", // Contradiction + "Unrelated" // Neutral + ] +} +``` diff --git a/src/memos/extras/nli_model/server/__init__.py b/src/memos/extras/nli_model/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py new file mode 100644 index 000000000..d2e12175d --- /dev/null +++ b/src/memos/extras/nli_model/server/config.py @@ -0,0 +1,23 @@ +import logging + + +NLI_MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" + +# Configuration +# You can set the device directly here. 
+# Examples: +# - "cuda" : Use default GPU (cuda:0) if available, else auto-fallback +# - "cuda:0" : Use specific GPU +# - "mps" : Use Apple Silicon GPU (if available) +# - "cpu" : Use CPU +NLI_DEVICE = "cuda" +NLI_MODEL_HOST = "0.0.0.0" +NLI_MODEL_PORT = 32532 + +# Configure logging for NLI Server +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("nli_server.log")], +) +logger = logging.getLogger("nli_server") diff --git a/src/memos/extras/nli_model/server/handler.py b/src/memos/extras/nli_model/server/handler.py new file mode 100644 index 000000000..eb82fa57b --- /dev/null +++ b/src/memos/extras/nli_model/server/handler.py @@ -0,0 +1,168 @@ +import re + +import torch + +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from memos.extras.nli_model.server.config import NLI_MODEL_NAME, logger +from memos.extras.nli_model.types import NLIResult + + +def _map_label_to_result(raw: str) -> NLIResult: + t = raw.lower() + if "entail" in t: + return NLIResult.DUPLICATE + if "contrad" in t or "refut" in t: + return NLIResult.CONTRADICTION + # Neutral or unknown + return NLIResult.UNRELATED + + +def _clean_temporal_markers(s: str) -> str: + # Remove temporal/aspect markers that might cause contradiction + # Chinese markers (simple replace is usually okay as they are characters) + zh_markers = ["刚刚", "曾经", "正在", "目前", "现在"] + for m in zh_markers: + s = s.replace(m, "") + + # English markers (need word boundaries to avoid "snow" -> "s") + en_markers = ["just", "once", "currently", "now"] + pattern = r"\b(" + "|".join(en_markers) + r")\b" + s = re.sub(pattern, "", s, flags=re.IGNORECASE) + + # Cleanup extra spaces + s = re.sub(r"\s+", " ", s).strip() + return s + + +class NLIHandler: + def __init__(self, device: str = "cpu", use_fp16: bool = True, use_compile: bool = True): + self.device = self._resolve_device(device) + logger.info(f"Final resolved device: {self.device}") + + # Set defaults based on device if not explicitly provided + is_cuda = "cuda" in self.device + if not is_cuda: + use_fp16 = False + use_compile = False + + self.tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME) + + model_kwargs = {} + if use_fp16 and is_cuda: + model_kwargs["torch_dtype"] = torch.float16 + + self.model = AutoModelForSequenceClassification.from_pretrained( + NLI_MODEL_NAME, **model_kwargs + ).to(self.device) + self.model.eval() + + self.id2label = {int(k): v for k, v in self.model.config.id2label.items()} + self.softmax = torch.nn.Softmax(dim=-1).to(self.device) + + if use_compile and hasattr(torch, "compile"): + logger.info("Compiling model with torch.compile...") + self.model = torch.compile(self.model) + + def _resolve_device(self, device: str) -> str: + d = device.strip().lower() + + has_cuda = torch.cuda.is_available() + has_mps = torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False + + if d == "cpu": + return "cpu" + + if d.startswith("cuda"): + if has_cuda: + if d == "cuda": + return "cuda:0" + return d + + # Fallback if CUDA not available + if has_mps: + logger.warning( + f"Device '{device}' requested but CUDA not available. Fallback to MPS." + ) + return "mps" + + logger.warning( + f"Device '{device}' requested but CUDA/MPS not available. Fallback to CPU." + ) + return "cpu" + + if d == "mps": + if has_mps: + return "mps" + + logger.warning(f"Device '{device}' requested but MPS not available. 
Fallback to CPU.") + return "cpu" + + # Fallback / Auto-detect for other cases (e.g. "gpu" or unknown) + if has_cuda: + return "cuda:0" + if has_mps: + return "mps" + + return "cpu" + + def predict_batch(self, premises: list[str], hypotheses: list[str]) -> list[NLIResult]: + # Clean inputs + premises = [_clean_temporal_markers(p) for p in premises] + hypotheses = [_clean_temporal_markers(h) for h in hypotheses] + + # Batch tokenize with padding + inputs = self.tokenizer( + premises, hypotheses, return_tensors="pt", truncation=True, max_length=512, padding=True + ).to(self.device) + with torch.no_grad(): + out = self.model(**inputs) + probs = self.softmax(out.logits) + + results = [] + for p in probs: + idx = int(torch.argmax(p).item()) + res = self.id2label.get(idx, str(idx)) + results.append(_map_label_to_result(res)) + return results + + def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: + """ + Compare one source text against multiple target memories efficiently using batch processing. + Performs bidirectional checks (Source <-> Target) for each pair. + """ + if not targets: + return [] + + n = len(targets) + # Construct batch: + # First n pairs: Source -> Target_i + # Next n pairs: Target_i -> Source + premises = [source] * n + targets + hypotheses = targets + [source] * n + + # Run single large batch inference + raw_results = self.predict_batch(premises, hypotheses) + + # Split results back + results_ab = raw_results[:n] + results_ba = raw_results[n:] + + final_results = [] + for i in range(n): + res_ab = results_ab[i] + res_ba = results_ba[i] + + # 1. Any Contradiction -> Contradiction (Sensitive detection, filtered by LLM later) + if res_ab == NLIResult.CONTRADICTION or res_ba == NLIResult.CONTRADICTION: + final_results.append(NLIResult.CONTRADICTION) + + # 2. Any Entailment -> Duplicate (as per user requirement) + elif res_ab == NLIResult.DUPLICATE or res_ba == NLIResult.DUPLICATE: + final_results.append(NLIResult.DUPLICATE) + + # 3. 
Otherwise (Both Neutral) -> Unrelated + else: + final_results.append(NLIResult.UNRELATED) + + return final_results diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py new file mode 100644 index 000000000..0ed9eae65 --- /dev/null +++ b/src/memos/extras/nli_model/server/serve.py @@ -0,0 +1,44 @@ +from contextlib import asynccontextmanager + +import uvicorn + +from fastapi import FastAPI, HTTPException + +from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT +from memos.extras.nli_model.server.handler import NLIHandler +from memos.extras.nli_model.types import CompareRequest, CompareResponse + + +# Global handler instance +nli_handler: NLIHandler | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global nli_handler + nli_handler = NLIHandler(device=NLI_DEVICE) + yield + # Clean up if needed + nli_handler = None + + +app = FastAPI(lifespan=lifespan) + + +@app.post("/compare_one_to_many", response_model=CompareResponse) +async def compare_one_to_many(request: CompareRequest): + if nli_handler is None: + raise HTTPException(status_code=503, detail="Model not loaded") + try: + results = nli_handler.compare_one_to_many(request.source, request.targets) + return CompareResponse(results=results) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +def start_server(host: str = "0.0.0.0", port: int = 32532): + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + start_server(host=NLI_MODEL_HOST, port=NLI_MODEL_PORT) diff --git a/src/memos/extras/nli_model/types.py b/src/memos/extras/nli_model/types.py new file mode 100644 index 000000000..619f8f508 --- /dev/null +++ b/src/memos/extras/nli_model/types.py @@ -0,0 +1,18 @@ +from enum import Enum + +from pydantic import BaseModel + + +class NLIResult(Enum): + DUPLICATE = "Duplicate" + CONTRADICTION = "Contradiction" + UNRELATED = "Unrelated" + + +class CompareRequest(BaseModel): + source: str + targets: list[str] + + +class CompareResponse(BaseModel): + results: list[NLIResult] diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/extras/nli_model/__init__.py b/tests/extras/nli_model/__init__.py new file mode 100644 index 000000000..e69de29bb From 2bcb3443e8020680c517a5853bd9046ae757629c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 28 Jan 2026 17:10:15 +0800 Subject: [PATCH 2/4] test: Add test and example for NLI model. 
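
Adds an end-to-end example (examples/extras/nli_e2e_example.py) that boots the FastAPI app in a background thread and checks the Duplicate / Contradiction / Unrelated outcomes, plus an integration test that patches NLIHandler with a mock so the heavy model is never downloaded or loaded. For reference, a minimal client-side sketch of the flow the example exercises (assuming the server from PATCH 1/4 is already running on its default port):

```python
# Hypothetical usage sketch; the real checks live in examples/extras/nli_e2e_example.py.
from memos.extras.nli_model.client import NLIClient

client = NLIClient(base_url="http://localhost:32532")
results = client.compare_one_to_many(
    source="I like apples",
    targets=["I like apples", "I hate apples", "Paris is a city"],
)
# With the real model loaded this is expected to print:
# ['Duplicate', 'Contradiction', 'Unrelated']
print([r.value for r in results])
```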
--- examples/extras/nli_e2e_example.py | 104 ++++++++++++++ .../nli_model/test_client_integration.py | 129 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 examples/extras/nli_e2e_example.py create mode 100644 tests/extras/nli_model/test_client_integration.py diff --git a/examples/extras/nli_e2e_example.py b/examples/extras/nli_e2e_example.py new file mode 100644 index 000000000..087cceec7 --- /dev/null +++ b/examples/extras/nli_e2e_example.py @@ -0,0 +1,104 @@ +import sys +import threading +import time + +import requests +import uvicorn + +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.server.serve import app + + +# Config +PORT = 32534 + + +def run_server(): + print(f"Starting server on port {PORT}...") + # Using a separate thread for the server + uvicorn.run(app, host="127.0.0.1", port=PORT, log_level="info") + + +def main(): + print("Initializing E2E Test...") + + # Start server thread + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Wait for server to be up + print("Waiting for server to initialize (this may take time if downloading model)...") + client = NLIClient(base_url=f"http://127.0.0.1:{PORT}") + + # Poll until server is ready + start_time = time.time() + ready = False + + # Wait up to 5 minutes for model download and initialization + timeout = 300 + + while time.time() - start_time < timeout: + try: + # Check if docs endpoint is accessible + resp = requests.get(f"http://127.0.0.1:{PORT}/docs", timeout=1) + if resp.status_code == 200: + ready = True + break + except requests.ConnectionError: + pass + except Exception: + # Ignore other errors during startup + pass + + time.sleep(2) + print(".", end="", flush=True) + + print("\n") + if not ready: + print("Server failed to start in time.") + sys.exit(1) + + print("Server is up! 
Sending request...") + + # Test Data + source = "I like apples" + targets = ["I like apples", "I hate apples", "Paris is a city"] + + try: + results = client.compare_one_to_many(source, targets) + print("-" * 30) + print(f"Source: {source}") + print("Targets & Results:") + for t, r in zip(targets, results, strict=False): + print(f" - '{t}': {r.value}") + print("-" * 30) + + # Basic Validation + passed = True + if results[0].value != "Duplicate": + print(f"FAILURE: Expected Duplicate for '{targets[0]}', got {results[0].value}") + passed = False + + if results[1].value != "Contradiction": + print(f"FAILURE: Expected Contradiction for '{targets[1]}', got {results[1].value}") + passed = False + + if results[2].value != "Unrelated": + print(f"FAILURE: Expected Unrelated for '{targets[2]}', got {results[2].value}") + passed = False + + if passed: + print("\nSUCCESS: Logic verification passed!") + else: + print("\nFAILURE: Unexpected results!") + + except Exception as e: + print(f"Error during request: {e}") + sys.exit(1) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nTest interrupted.") diff --git a/tests/extras/nli_model/test_client_integration.py b/tests/extras/nli_model/test_client_integration.py new file mode 100644 index 000000000..5beff14a0 --- /dev/null +++ b/tests/extras/nli_model/test_client_integration.py @@ -0,0 +1,129 @@ +import threading +import time +import unittest + +from unittest.mock import MagicMock, patch + +import requests +import uvicorn + +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.server.serve import app +from memos.extras.nli_model.types import NLIResult + + +# We need to mock the NLIHandler to avoid loading the heavy model +# but we want to run the real FastAPI server. +class TestNLIClientIntegration(unittest.TestCase): + server_thread = None + stop_server = False + port = 32533 # Use a different port for testing + + @classmethod + def setUpClass(cls): + # Patch the lifespan to inject a mock handler instead of real NLIHandler + cls.mock_handler = MagicMock() + cls.mock_handler.compare_one_to_many.return_value = [ + NLIResult.DUPLICATE, + NLIResult.CONTRADICTION, + ] + + # We need to patch the module where lifespan is defined/used or modify the global variable + # Since 'app' is already imported, we can patch the global nli_handler in serve.py + # But lifespan sets it on startup. 
+ + # Let's patch NLIHandler class in serve.py so when lifespan instantiates it, it gets our mock + cls.handler_patcher = patch("memos.extras.nli_model.server.serve.NLIHandler") + cls.MockHandlerClass = cls.handler_patcher.start() + cls.MockHandlerClass.return_value = cls.mock_handler + + # Start server in a thread + def run_server(): + # Disable logs for uvicorn to keep test output clean + config = uvicorn.Config(app, host="127.0.0.1", port=cls.port, log_level="error") + cls.server = uvicorn.Server(config) + cls.server.run() + + cls.server_thread = threading.Thread(target=run_server, daemon=True) + cls.server_thread.start() + + # Wait for server to be ready + cls._wait_for_server() + + @classmethod + def tearDownClass(cls): + # Stop the server + if hasattr(cls, "server"): + cls.server.should_exit = True + if cls.server_thread: + cls.server_thread.join(timeout=5) + + cls.handler_patcher.stop() + + @classmethod + def _wait_for_server(cls): + url = f"http://127.0.0.1:{cls.port}/docs" + retries = 20 + for _ in range(retries): + try: + response = requests.get(url) + if response.status_code == 200: + return + except requests.ConnectionError: + pass + time.sleep(0.1) + raise RuntimeError("Server failed to start") + + def setUp(self): + self.client = NLIClient(base_url=f"http://127.0.0.1:{self.port}") + # Reset mock calls before each test + self.mock_handler.reset_mock() + # Ensure default behavior + self.mock_handler.compare_one_to_many.return_value = [ + NLIResult.DUPLICATE, + NLIResult.CONTRADICTION, + ] + + def test_real_server_compare_one_to_many(self): + source = "I like apples." + targets = ["I love fruit.", "I hate apples."] + + results = self.client.compare_one_to_many(source, targets) + + # Verify result + self.assertEqual(len(results), 2) + self.assertEqual(results[0], NLIResult.DUPLICATE) + self.assertEqual(results[1], NLIResult.CONTRADICTION) + + # Verify server received the request + self.mock_handler.compare_one_to_many.assert_called_once() + args, _ = self.mock_handler.compare_one_to_many.call_args + self.assertEqual(args[0], source) + self.assertEqual(args[1], targets) + + def test_real_server_empty_targets(self): + source = "I like apples." + targets = [] + + results = self.client.compare_one_to_many(source, targets) + + self.assertEqual(results, []) + # Should not call handler because client handles empty list + self.mock_handler.compare_one_to_many.assert_not_called() + + def test_real_server_handler_error(self): + # Simulate handler error + self.mock_handler.compare_one_to_many.side_effect = ValueError("Something went wrong") + + source = "I like apples." + targets = ["I love fruit."] + + # Client should catch 500 and return UNRELATED + results = self.client.compare_one_to_many(source, targets) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0], NLIResult.UNRELATED) + + +if __name__ == "__main__": + unittest.main() From 1d06424902110a8131fc7ca8422677d0a7f4f69c Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 28 Jan 2026 17:21:32 +0800 Subject: [PATCH 3/4] feat: Init the NLI client when the server initializes. 
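
The API server now constructs an NLIClient inside init_server() and exposes it under the "nli_client" key next to the other components. The endpoint comes from the NLI_MODEL_BASE_URL environment variable and defaults to http://localhost:32532; point it at the deployed NLI server when the model is not running locally. A condensed sketch of the wiring added below:

```python
# Condensed sketch of the new wiring; mirrors build_nli_client_config() and init_server().
import os

from memos.extras.nli_model.client import NLIClient

nli_client_config = {"base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532")}
nli_client = NLIClient(base_url=nli_client_config["base_url"])
```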
--- src/memos/api/config.py | 7 +++++++ src/memos/api/handlers/component_init.py | 5 +++++ src/memos/api/handlers/config_builders.py | 10 ++++++++++ 3 files changed, 22 insertions(+) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index fb6e5e35e..7da17d23f 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -538,6 +538,13 @@ def get_internet_config() -> dict[str, Any]: }, } + @staticmethod + def get_nli_config() -> dict[str, Any]: + """Get NLI model configuration.""" + return { + "base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532"), + } + @staticmethod def get_neo4j_community_config(user_id: str | None = None) -> dict[str, Any]: """Get Neo4j community configuration.""" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 008957bad..13dd92189 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -18,6 +18,7 @@ build_internet_retriever_config, build_llm_config, build_mem_reader_config, + build_nli_client_config, build_pref_adder_config, build_pref_extractor_config, build_pref_retriever_config, @@ -48,6 +49,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.extras.nli_model.client import NLIClient from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -161,6 +163,7 @@ def init_server() -> dict[str, Any]: llm_config = build_llm_config() chat_llm_config = build_chat_llm_config() embedder_config = build_embedder_config() + nli_client_config = build_nli_client_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -186,6 +189,7 @@ def init_server() -> dict[str, Any]: else None ) embedder = EmbedderFactory.from_config(embedder_config) + nli_client = NLIClient(base_url=nli_client_config["base_url"]) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) reranker = RerankerFactory.from_config(reranker_config) @@ -388,4 +392,5 @@ def init_server() -> dict[str, Any]: "feedback_server": feedback_server, "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, + "nli_client": nli_client, } diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index fce789e2a..ed673977a 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -188,3 +188,13 @@ def build_pref_retriever_config() -> dict[str, Any]: Validated retriever configuration dictionary """ return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_nli_client_config() -> dict[str, Any]: + """ + Build NLI client configuration. + + Returns: + NLI client configuration dictionary + """ + return APIConfig.get_nli_config() From 1b49395c8d3fc586981e9515f86cf5d631f2e43b Mon Sep 17 00:00:00 2001 From: bittergreen Date: Wed, 28 Jan 2026 17:38:14 +0800 Subject: [PATCH 4/4] fix: Avoid direct import of torch and transformers for the external NLI server. 
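
handler.py no longer imports torch and transformers at module import time; they are resolved lazily inside NLIHandler.__init__, which raises a clear ImportError when they are missing. This keeps the client (and anything that merely imports the memos.extras.nli_model package) usable in environments without the ML dependencies; only the standalone server process needs torch and transformers installed. A rough illustration of the intended split (not part of this patch):

```python
# Importing the client requires only `requests`; no torch/transformers needed.
from memos.extras.nli_model.client import NLIClient

# Importing the handler module is now also safe, but instantiating it still
# needs torch/transformers and fails with a descriptive ImportError otherwise.
from memos.extras.nli_model.server.handler import NLIHandler

try:
    handler = NLIHandler(device="cpu")
except ImportError as e:
    print(f"Server-side dependencies missing: {e}")
```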
--- src/memos/extras/nli_model/server/README.md | 1 + src/memos/extras/nli_model/server/handler.py | 26 +++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/memos/extras/nli_model/server/README.md b/src/memos/extras/nli_model/server/README.md index 1dbe6142d..f6886e0e4 100644 --- a/src/memos/extras/nli_model/server/README.md +++ b/src/memos/extras/nli_model/server/README.md @@ -6,6 +6,7 @@ This directory contains the standalone server for the Natural Language Inference - Python 3.10+ - CUDA-capable GPU (Recommended for performance) +- `torch` and `transformers` libraries (required for the server) ## Running the Server diff --git a/src/memos/extras/nli_model/server/handler.py b/src/memos/extras/nli_model/server/handler.py index eb82fa57b..3e98ddeb0 100644 --- a/src/memos/extras/nli_model/server/handler.py +++ b/src/memos/extras/nli_model/server/handler.py @@ -1,13 +1,15 @@ import re -import torch - -from transformers import AutoModelForSequenceClassification, AutoTokenizer - from memos.extras.nli_model.server.config import NLI_MODEL_NAME, logger from memos.extras.nli_model.types import NLIResult +# Placeholder for lazy imports +torch = None +AutoModelForSequenceClassification = None +AutoTokenizer = None + + def _map_label_to_result(raw: str) -> NLIResult: t = raw.lower() if "entail" in t: @@ -36,7 +38,23 @@ def _clean_temporal_markers(s: str) -> str: class NLIHandler: + """ + NLI Model Handler for inference. + Requires `torch` and `transformers` to be installed. + """ + def __init__(self, device: str = "cpu", use_fp16: bool = True, use_compile: bool = True): + global torch, AutoModelForSequenceClassification, AutoTokenizer + try: + import torch + + from transformers import AutoModelForSequenceClassification, AutoTokenizer + except ImportError as e: + raise ImportError( + "NLIHandler requires 'torch' and 'transformers'. " + "Please install them via 'pip install torch transformers' or use the requirements.txt." + ) from e + self.device = self._resolve_device(device) logger.info(f"Final resolved device: {self.device}")