104 changes: 104 additions & 0 deletions examples/extras/nli_e2e_example.py
@@ -0,0 +1,104 @@
import sys
import threading
import time

import requests
import uvicorn

from memos.extras.nli_model.client import NLIClient
from memos.extras.nli_model.server.serve import app


# Config
PORT = 32534


def run_server():
    print(f"Starting server on port {PORT}...")
    # Using a separate thread for the server
    uvicorn.run(app, host="127.0.0.1", port=PORT, log_level="info")


def main():
    print("Initializing E2E Test...")

    # Start server thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Wait for server to be up
    print("Waiting for server to initialize (this may take time if downloading model)...")
    client = NLIClient(base_url=f"http://127.0.0.1:{PORT}")

    # Poll until server is ready
    start_time = time.time()
    ready = False

    # Wait up to 5 minutes for model download and initialization
    timeout = 300

    while time.time() - start_time < timeout:
        try:
            # Check if docs endpoint is accessible
            resp = requests.get(f"http://127.0.0.1:{PORT}/docs", timeout=1)
            if resp.status_code == 200:
                ready = True
                break
        except requests.ConnectionError:
            pass
        except Exception:
            # Ignore other errors during startup
            pass

        time.sleep(2)
        print(".", end="", flush=True)

    print("\n")
    if not ready:
        print("Server failed to start in time.")
        sys.exit(1)

    print("Server is up! Sending request...")

    # Test Data
    source = "I like apples"
    targets = ["I like apples", "I hate apples", "Paris is a city"]

    try:
        results = client.compare_one_to_many(source, targets)
        print("-" * 30)
        print(f"Source: {source}")
        print("Targets & Results:")
        for t, r in zip(targets, results, strict=False):
            print(f"  - '{t}': {r.value}")
        print("-" * 30)

        # Basic Validation
        passed = True
        if results[0].value != "Duplicate":
            print(f"FAILURE: Expected Duplicate for '{targets[0]}', got {results[0].value}")
            passed = False

        if results[1].value != "Contradiction":
            print(f"FAILURE: Expected Contradiction for '{targets[1]}', got {results[1].value}")
            passed = False

        if results[2].value != "Unrelated":
            print(f"FAILURE: Expected Unrelated for '{targets[2]}', got {results[2].value}")
            passed = False

        if passed:
            print("\nSUCCESS: Logic verification passed!")
        else:
            print("\nFAILURE: Unexpected results!")

    except Exception as e:
        print(f"Error during request: {e}")
        sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nTest interrupted.")
7 changes: 7 additions & 0 deletions src/memos/api/config.py
@@ -538,6 +538,13 @@ def get_internet_config() -> dict[str, Any]:
            },
        }

    @staticmethod
    def get_nli_config() -> dict[str, Any]:
        """Get NLI model configuration."""
        return {
            "base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532"),
        }

    @staticmethod
    def get_neo4j_community_config(user_id: str | None = None) -> dict[str, Any]:
        """Get Neo4j community configuration."""
5 changes: 5 additions & 0 deletions src/memos/api/handlers/component_init.py
@@ -18,6 +18,7 @@
    build_internet_retriever_config,
    build_llm_config,
    build_mem_reader_config,
    build_nli_client_config,
    build_pref_adder_config,
    build_pref_extractor_config,
    build_pref_retriever_config,
@@ -48,6 +49,7 @@

if TYPE_CHECKING:
    from memos.memories.textual.tree import TreeTextMemory
    from memos.extras.nli_model.client import NLIClient
    from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
    from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
        InternetRetrieverFactory,
@@ -161,6 +163,7 @@ def init_server() -> dict[str, Any]:
    llm_config = build_llm_config()
    chat_llm_config = build_chat_llm_config()
    embedder_config = build_embedder_config()
    nli_client_config = build_nli_client_config()
    mem_reader_config = build_mem_reader_config()
    reranker_config = build_reranker_config()
    feedback_reranker_config = build_feedback_reranker_config()
@@ -186,6 +189,7 @@
        else None
    )
    embedder = EmbedderFactory.from_config(embedder_config)
    nli_client = NLIClient(base_url=nli_client_config["base_url"])
    # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
    reranker = RerankerFactory.from_config(reranker_config)
@@ -388,4 +392,5 @@
        "feedback_server": feedback_server,
        "redis_client": redis_client,
        "deepsearch_agent": deepsearch_agent,
        "nli_client": nli_client,
    }
10 changes: 10 additions & 0 deletions src/memos/api/handlers/config_builders.py
@@ -188,3 +188,13 @@ def build_pref_retriever_config() -> dict[str, Any]:
        Validated retriever configuration dictionary
    """
    return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}})


def build_nli_client_config() -> dict[str, Any]:
    """
    Build NLI client configuration.

    Returns:
        NLI client configuration dictionary
    """
    return APIConfig.get_nli_config()
Empty file added src/memos/extras/__init__.py
61 changes: 61 additions & 0 deletions src/memos/extras/nli_model/client.py
@@ -0,0 +1,61 @@
import logging

import requests

from memos.extras.nli_model.types import NLIResult


logger = logging.getLogger(__name__)


class NLIClient:
    """
    Client for interacting with the deployed NLI model service.
    """

    def __init__(self, base_url: str = "http://localhost:32532"):
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]:
        """
        Compare one source text against multiple target memories using the NLI service.

        Args:
            source: The new memory content.
            targets: List of existing memory contents to compare against.

        Returns:
            List of NLIResult corresponding to each target.
        """
        if not targets:
            return []

        url = f"{self.base_url}/compare_one_to_many"
        # Match schemas.CompareRequest
        payload = {"source": source, "targets": targets}

        try:
            response = self.session.post(url, json=payload, timeout=30)
            response.raise_for_status()
            data = response.json()

            # Match schemas.CompareResponse
            results_str = data.get("results", [])

            results = []
            for res_str in results_str:
                try:
                    results.append(NLIResult(res_str))
                except ValueError:
                    logger.warning(
                        f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED"
                    )
                    results.append(NLIResult.UNRELATED)

            return results

        except requests.RequestException as e:
            logger.error(f"[NLIClient] Request failed: {e}")
            # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow.
            return [NLIResult.UNRELATED] * len(targets)
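
For context, a minimal usage sketch of the client above (assumes an NLI server is already running at the default address; see the e2e example in this PR for a full flow):

```python
from memos.extras.nli_model.client import NLIClient

# Minimal sketch: classify a new memory against existing ones.
client = NLIClient(base_url="http://localhost:32532")
targets = ["I like apples", "I hate apples", "Paris is a city"]
results = client.compare_one_to_many("I like apples", targets)
for target, result in zip(targets, results, strict=False):
    print(f"{target!r} -> {result.value}")  # e.g. Duplicate / Contradiction / Unrelated
```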
69 changes: 69 additions & 0 deletions src/memos/extras/nli_model/server/README.md
@@ -0,0 +1,69 @@
# NLI Model Server

This directory contains the standalone server for the Natural Language Inference (NLI) model used by MemOS.

## Prerequisites

- Python 3.10+
- CUDA-capable GPU (recommended for performance)
- `torch` and `transformers` libraries (required for the server; see the install sketch below)
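
If the libraries are not already installed, a typical setup is (versions unpinned here; pick builds matching your CUDA toolkit):

```bash
pip install torch transformers
```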

## Running the Server

Run the server using module syntax from the project root so that package imports resolve correctly.

### 1. Basic Start
```bash
python -m memos.extras.nli_model.server.serve
```

### 2. Configuration
You can configure the server by editing `config.py`:

- `NLI_MODEL_HOST`: The host to bind to (default: `0.0.0.0`)
- `NLI_MODEL_PORT`: The port to bind to (default: `32532`)
- `NLI_DEVICE`: The device to run the model on.
  - `cuda` (default; uses `cuda:0` if available, otherwise falls back to `mps`/`cpu`)
  - `cuda:0` (a specific GPU)
  - `mps` (Apple Silicon)
  - `cpu` (CPU)

The MemOS API locates this server through the `NLI_MODEL_BASE_URL` environment variable, as sketched below.
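
A minimal sketch (the address is illustrative; `NLI_MODEL_BASE_URL` is the variable read by `APIConfig.get_nli_config()` in `src/memos/api/config.py`):

```bash
# Illustrative address; defaults to http://localhost:32532 if unset.
export NLI_MODEL_BASE_URL="http://10.0.0.5:32532"
```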

## API Usage

### Compare One to Many
**POST** `/compare_one_to_many`

**Request Body:**
```json
{
  "source": "I just ate an apple.",
  "targets": [
    "I ate a fruit.",
    "I hate apples.",
    "The sky is blue."
  ]
}
```

**Response:**
```json
{
  "results": [
    "Duplicate",      // Entailment
    "Contradiction",  // Contradiction
    "Unrelated"       // Neutral
  ]
}
```

## Testing

An end-to-end example script is provided to verify the server's functionality. It starts the server locally and runs a client request to check the NLI logic.

### End-to-End Test

Run the example script from the project root:

```bash
python examples/extras/nli_e2e_example.py
```
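
Once the server is up, you can also exercise the endpoint directly with `curl`; a quick sketch using the default host and port from `config.py`:

```bash
curl -X POST "http://localhost:32532/compare_one_to_many" \
  -H "Content-Type: application/json" \
  -d '{"source": "I just ate an apple.", "targets": ["I ate a fruit.", "I hate apples.", "The sky is blue."]}'
```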
23 changes: 23 additions & 0 deletions src/memos/extras/nli_model/server/config.py
@@ -0,0 +1,23 @@
import logging


NLI_MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

# Configuration
# You can set the device directly here.
# Examples:
# - "cuda" : Use default GPU (cuda:0) if available, else auto-fallback
# - "cuda:0" : Use specific GPU
# - "mps" : Use Apple Silicon GPU (if available)
# - "cpu" : Use CPU
NLI_DEVICE = "cuda"
NLI_MODEL_HOST = "0.0.0.0"
NLI_MODEL_PORT = 32532

# Configure logging for NLI Server
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(), logging.FileHandler("nli_server.log")],
)
logger = logging.getLogger("nli_server")
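
For reference, a device-resolution sketch consistent with the fallback behavior described in the comment above (illustrative only; not the actual logic in `serve.py`):

```python
import torch


def resolve_device(requested: str = "cuda") -> str:
    """Illustrative fallback chain: requested device -> cuda:0 -> mps -> cpu."""
    if requested == "cpu":
        return "cpu"
    if requested.startswith("cuda") and torch.cuda.is_available():
        # Bare "cuda" maps to the default GPU, cuda:0.
        return requested if ":" in requested else "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```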