diff --git a/SYNC_UPSTREAM.md b/SYNC_UPSTREAM.md new file mode 100644 index 000000000..abe5cd886 --- /dev/null +++ b/SYNC_UPSTREAM.md @@ -0,0 +1,160 @@ +# Синхронизация с Upstream MemOS + +## Архитектура + +``` +┌─────────────────────────────────────────────────────────────┐ +│ MemTensor/MemOS (upstream) │ +│ Оригинал │ +└─────────────────────────┬───────────────────────────────────┘ + │ git fetch upstream + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ anatolykoptev/MemOS (fork) │ +│ ┌────────────────────┐ ┌─────────────────────────────┐ │ +│ │ src/memos/ │ │ overlays/krolik/ │ │ +│ │ (base MemOS) │ │ (auth, rate-limit, admin) │ │ +│ │ │ │ │ │ +│ │ ← syncs with │ │ ← НАШИ кастомизации │ │ +│ │ upstream │ │ (никогда не конфликтуют) │ │ +│ └────────────────────┘ └─────────────────────────────┘ │ +└─────────────────────────┬───────────────────────────────────┘ + │ Dockerfile.krolik + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ krolik-server (production) │ +│ src/memos/ + overlays merged at build │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Регулярная синхронизация (еженедельно) + +```bash +cd ~/CascadeProjects/piternow_project/MemOS + +# 1. Получить изменения upstream +git fetch upstream + +# 2. Посмотреть что нового +git log --oneline upstream/main..main # Наши коммиты +git log --oneline main..upstream/main # Новое в upstream + +# 3. Merge upstream (overlays/ не затрагивается) +git checkout main +git merge upstream/main + +# 4. Если конфликты (редко, только в src/): +# - Разрешить конфликты +# - git add . +# - git commit + +# 5. Push в наш fork +git push origin main +``` + +## Обновление production (krolik-server) + +После синхронизации форка: + +```bash +cd ~/krolik-server + +# Пересобрать с новым MemOS +docker compose build --no-cache memos-api + +# Перезапустить +docker compose up -d memos-api + +# Проверить логи +docker logs -f memos-api +``` + +## Добавление новых фич в overlay + +```bash +# 1. Создать файл в overlays/krolik/ +vim overlays/krolik/api/middleware/new_feature.py + +# 2. Импортировать в server_api_ext.py +vim overlays/krolik/api/server_api_ext.py + +# 3. 
Commit в наш fork +git add overlays/ +git commit -m "feat(krolik): add new_feature middleware" +git push origin main +``` + +## Важные правила + +### ✅ Делать: +- Все кастомизации в `overlays/krolik/` +- Багфиксы в `src/` которые полезны upstream — создавать PR +- Регулярно синхронизировать с upstream + +### ❌ НЕ делать: +- Модифицировать файлы в `src/memos/` напрямую +- Форкать API в overlay вместо расширения +- Игнорировать обновления upstream > 2 недель + +## Структура overlays + +``` +overlays/ +└── krolik/ + └── api/ + ├── middleware/ + │ ├── __init__.py + │ ├── auth.py # API Key auth (PostgreSQL) + │ └── rate_limit.py # Redis sliding window + ├── routers/ + │ ├── __init__.py + │ └── admin_router.py # /admin/keys CRUD + ├── utils/ + │ ├── __init__.py + │ └── api_keys.py # Key generation + └── server_api_ext.py # Entry point +``` + +## Environment Variables (Krolik) + +```bash +# Authentication +AUTH_ENABLED=true +MASTER_KEY_HASH= +INTERNAL_SERVICE_SECRET= + +# Rate Limiting +RATE_LIMIT_ENABLED=true +RATE_LIMIT=100 +RATE_WINDOW_SEC=60 +REDIS_URL=redis://redis:6379 + +# PostgreSQL (for API keys) +POSTGRES_HOST=postgres +POSTGRES_PORT=5432 +POSTGRES_USER=memos +POSTGRES_PASSWORD= +POSTGRES_DB=memos + +# CORS +CORS_ORIGINS=https://krolik.hully.one,https://memos.hully.one +``` + +## Миграция из текущего krolik-server + +Текущий `krolik-server/services/memos-core/` содержит смешанный код. +После перехода на overlay pattern: + +1. **krolik-server** будет использовать `Dockerfile.krolik` из форка +2. **Локальные изменения** удаляются из krolik-server +3. **Все кастомизации** живут в `MemOS/overlays/krolik/` + +```yaml +# docker-compose.yml (krolik-server) +services: + memos-api: + build: + context: ../MemOS # Используем форк напрямую + dockerfile: docker/Dockerfile.krolik + # ... остальная конфигурация +``` diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik new file mode 100644 index 000000000..c475a6d30 --- /dev/null +++ b/docker/Dockerfile.krolik @@ -0,0 +1,65 @@ +# MemOS with Krolik Security Extensions +# +# This Dockerfile builds MemOS with authentication, rate limiting, and admin API. +# It uses the overlay pattern to keep customizations separate from base code. 
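#
# Build/run sketch (the image tag and env-file name are illustrative, not defined in this repo):
#
#   docker build -f docker/Dockerfile.krolik -t memos-krolik:local .
#   docker run --rm -p 8000:8000 --env-file .env memos-krolik:local
#
# The overlay files are copied over src/memos/ during the build (see the COPY step below),
# so the resulting image serves memos.api.server_api_ext:app instead of the stock entry point.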
+ +FROM python:3.11-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + build-essential \ + libffi-dev \ + python3-dev \ + curl \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN groupadd -r memos && useradd -r -g memos -u 1000 memos + +WORKDIR /app + +# Use official Hugging Face +ENV HF_ENDPOINT=https://huggingface.co + +# Copy base MemOS source +COPY src/ ./src/ +COPY pyproject.toml ./ + +# Install base dependencies +RUN pip install --upgrade pip && \ + pip install --no-cache-dir poetry && \ + poetry config virtualenvs.create false && \ + poetry install --no-dev --extras "tree-mem mem-scheduler" + +# Install additional dependencies for Krolik +RUN pip install --no-cache-dir \ + sentence-transformers \ + torch \ + transformers \ + psycopg2-binary \ + redis + +# Apply Krolik overlay (AFTER base install to allow easy updates) +COPY overlays/krolik/ ./src/memos/ + +# Create data directory +RUN mkdir -p /data/memos && chown -R memos:memos /data/memos +RUN chown -R memos:memos /app + +# Set Python path +ENV PYTHONPATH=/app/src + +# Switch to non-root user +USER memos + +EXPOSE 8000 + +# Healthcheck +HEALTHCHECK --interval=30s --timeout=10s --retries=3 --start-period=60s \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Use extended entry point with security features +CMD ["gunicorn", "memos.api.server_api_ext:app", "--preload", "-w", "2", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--timeout", "120"] diff --git a/overlays/README.md b/overlays/README.md new file mode 100644 index 000000000..805821018 --- /dev/null +++ b/overlays/README.md @@ -0,0 +1,86 @@ +# MemOS Overlays + +Overlays are deployment-specific customizations that extend the base MemOS without modifying core files. + +## Structure + +``` +overlays/ +└── krolik/ # Deployment name + └── api/ + ├── middleware/ + │ ├── __init__.py + │ ├── auth.py # API Key authentication + │ └── rate_limit.py # Redis rate limiting + ├── routers/ + │ ├── __init__.py + │ └── admin_router.py # API key management + ├── utils/ + │ ├── __init__.py + │ └── api_keys.py # Key generation utilities + └── server_api_ext.py # Extended entry point +``` + +## How It Works + +1. **Base MemOS** provides core functionality (memory operations, embeddings, etc.) +2. **Overlays** add deployment-specific features without modifying base files +3. **Dockerfile** merges overlays on top of base during build + +## Dockerfile Usage + +```dockerfile +# Clone base MemOS +RUN git clone --depth 1 https://github.com/anatolykoptev/MemOS.git /app + +# Install base dependencies +RUN pip install -r /app/requirements.txt + +# Apply overlay (copies files into src/memos/) +RUN cp -r /app/overlays/krolik/* /app/src/memos/ + +# Use extended entry point +CMD ["gunicorn", "memos.api.server_api_ext:app", ...] +``` + +## Syncing with Upstream + +```bash +# 1. Fetch upstream changes +git fetch upstream + +# 2. Merge upstream into main (preserves overlays) +git merge upstream/main + +# 3. Resolve conflicts if any (usually none in overlays/) +git status + +# 4. Push to fork +git push origin main +``` + +## Adding New Overlays + +1. Create directory: `overlays//` +2. Add customizations following the same structure +3. Create `server_api_ext.py` as entry point +4. 
Update Dockerfile to use the new overlay + +## Security Features (krolik overlay) + +### API Key Authentication +- SHA-256 hashed keys stored in PostgreSQL +- Master key for admin operations +- Scoped permissions (read, write, admin) +- Internal service bypass for container-to-container + +### Rate Limiting +- Redis-based sliding window algorithm +- In-memory fallback for development +- Per-key or per-IP limiting +- Configurable via environment variables + +### Admin API +- `/admin/keys` - Create, list, revoke API keys +- `/admin/health` - Auth system status +- Protected by admin scope or master key diff --git a/overlays/krolik/api/middleware/__init__.py b/overlays/krolik/api/middleware/__init__.py new file mode 100644 index 000000000..64cbc5c60 --- /dev/null +++ b/overlays/krolik/api/middleware/__init__.py @@ -0,0 +1,13 @@ +"""Krolik middleware extensions for MemOS.""" + +from .auth import verify_api_key, require_scope, require_admin, require_read, require_write +from .rate_limit import RateLimitMiddleware + +__all__ = [ + "verify_api_key", + "require_scope", + "require_admin", + "require_read", + "require_write", + "RateLimitMiddleware", +] diff --git a/overlays/krolik/api/middleware/auth.py b/overlays/krolik/api/middleware/auth.py new file mode 100644 index 000000000..30349c9c4 --- /dev/null +++ b/overlays/krolik/api/middleware/auth.py @@ -0,0 +1,268 @@ +""" +API Key Authentication Middleware for MemOS. + +Validates API keys and extracts user context for downstream handlers. +Keys are validated against SHA-256 hashes stored in PostgreSQL. +""" + +import hashlib +import os +import time +from typing import Any + +from fastapi import Depends, HTTPException, Request, Security +from fastapi.security import APIKeyHeader + +import memos.log + +logger = memos.log.get_logger(__name__) + +# API key header configuration +API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False) + +# Environment configuration +AUTH_ENABLED = os.getenv("AUTH_ENABLED", "false").lower() == "true" +MASTER_KEY_HASH = os.getenv("MASTER_KEY_HASH") # SHA-256 hash of master key +INTERNAL_SERVICE_IPS = {"127.0.0.1", "::1", "memos-mcp", "moltbot", "clawdbot"} + +# Connection pool for auth queries (lazy init) +_auth_pool = None + + +def _get_auth_pool(): + """Get or create auth database connection pool.""" + global _auth_pool + if _auth_pool is not None: + return _auth_pool + + try: + import psycopg2.pool + + _auth_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=os.getenv("POSTGRES_HOST", "postgres"), + port=int(os.getenv("POSTGRES_PORT", "5432")), + user=os.getenv("POSTGRES_USER", "memos"), + password=os.getenv("POSTGRES_PASSWORD", ""), + dbname=os.getenv("POSTGRES_DB", "memos"), + connect_timeout=10, + ) + logger.info("Auth database pool initialized") + return _auth_pool + except Exception as e: + logger.error(f"Failed to initialize auth pool: {e}") + return None + + +def hash_api_key(key: str) -> str: + """Hash an API key using SHA-256.""" + return hashlib.sha256(key.encode()).hexdigest() + + +def validate_key_format(key: str) -> bool: + """Validate API key format: krlk_<64-hex>.""" + if not key or not key.startswith("krlk_"): + return False + hex_part = key[5:] # Remove 'krlk_' prefix + if len(hex_part) != 64: + return False + try: + int(hex_part, 16) + return True + except ValueError: + return False + + +def get_key_prefix(key: str) -> str: + """Extract prefix for key identification (first 12 chars).""" + return key[:12] if len(key) >= 12 else key + + +async def 
lookup_api_key(key_hash: str) -> dict[str, Any] | None: + """ + Look up API key in database. + + Returns dict with user_name, scopes, etc. or None if not found. + """ + pool = _get_auth_pool() + if not pool: + logger.warning("Auth pool not available, cannot validate key") + return None + + conn = None + try: + conn = pool.getconn() + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, user_name, scopes, expires_at, is_active + FROM api_keys + WHERE key_hash = %s + """, + (key_hash,), + ) + row = cur.fetchone() + + if not row: + return None + + key_id, user_name, scopes, expires_at, is_active = row + + # Check if key is active + if not is_active: + logger.warning(f"Inactive API key used: {key_hash[:16]}...") + return None + + # Check expiration + if expires_at and expires_at < time.time(): + logger.warning(f"Expired API key used: {key_hash[:16]}...") + return None + + # Update last_used_at + cur.execute( + "UPDATE api_keys SET last_used_at = NOW() WHERE id = %s", + (key_id,), + ) + conn.commit() + + return { + "id": str(key_id), + "user_name": user_name, + "scopes": scopes or ["read"], + } + except Exception as e: + logger.error(f"Database error during key lookup: {e}") + return None + finally: + if conn and pool: + pool.putconn(conn) + + +def is_internal_request(request: Request) -> bool: + """Check if request is from internal service.""" + client_host = request.client.host if request.client else None + + # Check internal IPs + if client_host in INTERNAL_SERVICE_IPS: + return True + + # Check internal header (for container-to-container) + internal_header = request.headers.get("X-Internal-Service") + if internal_header == os.getenv("INTERNAL_SERVICE_SECRET"): + return True + + return False + + +async def verify_api_key( + request: Request, + api_key: str | None = Security(API_KEY_HEADER), +) -> dict[str, Any]: + """ + Verify API key and return user context. + + This is the main dependency for protected endpoints. 
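    A typical (illustrative) route protected by this dependency:

        from fastapi import APIRouter, Depends

        router = APIRouter()

        @router.get("/memories")
        async def list_memories(auth: dict = Depends(verify_api_key)):
            # "auth" carries the resolved user_name and scopes for the presented key
            return {"user": auth["user_name"], "scopes": auth["scopes"]}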
+ + Returns: + dict with user_name, scopes, and is_master_key flag + + Raises: + HTTPException 401 if authentication fails + """ + # Skip auth if disabled + if not AUTH_ENABLED: + return { + "user_name": request.headers.get("X-User-Name", "default"), + "scopes": ["all"], + "is_master_key": False, + "auth_bypassed": True, + } + + # Allow internal services + if is_internal_request(request): + logger.debug(f"Internal request from {request.client.host}") + return { + "user_name": "internal", + "scopes": ["all"], + "is_master_key": False, + "is_internal": True, + } + + # Require API key + if not api_key: + raise HTTPException( + status_code=401, + detail="Missing API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Handle "Bearer" or "Token" prefix + if api_key.lower().startswith("bearer "): + api_key = api_key[7:] + elif api_key.lower().startswith("token "): + api_key = api_key[6:] + + # Check against master key first (has different format: mk_*) + key_hash = hash_api_key(api_key) + if MASTER_KEY_HASH and key_hash == MASTER_KEY_HASH: + logger.info("Master key authentication") + return { + "user_name": "admin", + "scopes": ["all"], + "is_master_key": True, + } + + # Validate format for regular API keys (krlk_*) + if not validate_key_format(api_key): + raise HTTPException( + status_code=401, + detail="Invalid API key format", + ) + + # Look up in database + key_data = await lookup_api_key(key_hash) + if not key_data: + logger.warning(f"Invalid API key attempt: {get_key_prefix(api_key)}...") + raise HTTPException( + status_code=401, + detail="Invalid or expired API key", + ) + + logger.debug(f"Authenticated user: {key_data['user_name']}") + return { + "user_name": key_data["user_name"], + "scopes": key_data["scopes"], + "is_master_key": False, + "api_key_id": key_data["id"], + } + + +def require_scope(required_scope: str): + """ + Dependency factory to require a specific scope. + + Usage: + @router.post("/admin/keys", dependencies=[Depends(require_scope("admin"))]) + """ + async def scope_checker( + auth: dict[str, Any] = Depends(verify_api_key), + ) -> dict[str, Any]: + scopes = auth.get("scopes", []) + + # "all" scope grants everything + if "all" in scopes or required_scope in scopes: + return auth + + raise HTTPException( + status_code=403, + detail=f"Insufficient permissions. Required scope: {required_scope}", + ) + + return scope_checker + + +# Convenience dependencies +require_read = require_scope("read") +require_write = require_scope("write") +require_admin = require_scope("admin") diff --git a/overlays/krolik/api/middleware/rate_limit.py b/overlays/krolik/api/middleware/rate_limit.py new file mode 100644 index 000000000..12ee84ef4 --- /dev/null +++ b/overlays/krolik/api/middleware/rate_limit.py @@ -0,0 +1,200 @@ +""" +Redis-based Rate Limiting Middleware. + +Implements sliding window rate limiting with Redis. +Falls back to in-memory limiting if Redis is unavailable. 
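Illustrative wiring in a FastAPI app (this mirrors how server_api_ext.py installs it):

    from fastapi import FastAPI
    from memos.api.middleware.rate_limit import RateLimitMiddleware

    app = FastAPI()
    app.add_middleware(RateLimitMiddleware)  # limits are read from RATE_LIMIT / RATE_WINDOW_SEC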
+""" + +import os +import time +from collections import defaultdict +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +import memos.log + +logger = memos.log.get_logger(__name__) + +# Configuration from environment +RATE_LIMIT = int(os.getenv("RATE_LIMIT", "100")) # Requests per window +RATE_WINDOW = int(os.getenv("RATE_WINDOW_SEC", "60")) # Window in seconds +REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379") + +# Redis client (lazy initialization) +_redis_client = None + +# In-memory fallback (per process) +_memory_store: dict[str, list[float]] = defaultdict(list) + + +def _get_redis(): + """Get or create Redis client.""" + global _redis_client + if _redis_client is not None: + return _redis_client + + try: + import redis + + _redis_client = redis.from_url(REDIS_URL, decode_responses=True) + _redis_client.ping() # Test connection + logger.info("Rate limiter connected to Redis") + return _redis_client + except Exception as e: + logger.warning(f"Redis not available for rate limiting: {e}") + return None + + +def _get_client_key(request: Request) -> str: + """ + Generate a unique key for rate limiting. + + Uses API key if available, otherwise falls back to IP. + """ + # Try to get API key from header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("krlk_"): + # Use first 20 chars of key as identifier + return f"ratelimit:key:{auth_header[:20]}" + + # Fall back to IP address + client_ip = request.client.host if request.client else "unknown" + + # Check for forwarded IP (behind proxy) + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + client_ip = forwarded.split(",")[0].strip() + + return f"ratelimit:ip:{client_ip}" + + +def _check_rate_limit_redis(key: str) -> tuple[bool, int, int]: + """ + Check rate limit using Redis sliding window. + + Returns: + (allowed, remaining, reset_time) + """ + redis_client = _get_redis() + if not redis_client: + return _check_rate_limit_memory(key) + + try: + now = time.time() + window_start = now - RATE_WINDOW + + pipe = redis_client.pipeline() + + # Remove old entries + pipe.zremrangebyscore(key, 0, window_start) + + # Count current entries + pipe.zcard(key) + + # Add current request + pipe.zadd(key, {str(now): now}) + + # Set expiry + pipe.expire(key, RATE_WINDOW + 1) + + results = pipe.execute() + current_count = results[1] + + remaining = max(0, RATE_LIMIT - current_count - 1) + reset_time = int(now + RATE_WINDOW) + + if current_count >= RATE_LIMIT: + return False, 0, reset_time + + return True, remaining, reset_time + + except Exception as e: + logger.warning(f"Redis rate limit error: {e}") + return _check_rate_limit_memory(key) + + +def _check_rate_limit_memory(key: str) -> tuple[bool, int, int]: + """ + Fallback in-memory rate limiting. + + Note: This is per-process and not distributed! 
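    Illustrative call and return shape (limits come from RATE_LIMIT / RATE_WINDOW_SEC):

        allowed, remaining, reset_time = _check_rate_limit_memory("ratelimit:ip:127.0.0.1")
        # "allowed" becomes False once RATE_LIMIT timestamps fall inside the current window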
+ """ + now = time.time() + window_start = now - RATE_WINDOW + + # Clean old entries + _memory_store[key] = [t for t in _memory_store[key] if t > window_start] + + current_count = len(_memory_store[key]) + + if current_count >= RATE_LIMIT: + reset_time = int(min(_memory_store[key]) + RATE_WINDOW) if _memory_store[key] else int(now + RATE_WINDOW) + return False, 0, reset_time + + # Add current request + _memory_store[key].append(now) + + remaining = RATE_LIMIT - current_count - 1 + reset_time = int(now + RATE_WINDOW) + + return True, remaining, reset_time + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Rate limiting middleware using sliding window algorithm. + + Adds headers: + - X-RateLimit-Limit: Maximum requests per window + - X-RateLimit-Remaining: Remaining requests + - X-RateLimit-Reset: Unix timestamp when the window resets + + Returns 429 Too Many Requests when limit is exceeded. + """ + + # Paths exempt from rate limiting + EXEMPT_PATHS = {"/health", "/openapi.json", "/docs", "/redoc"} + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # Skip rate limiting for exempt paths + if request.url.path in self.EXEMPT_PATHS: + return await call_next(request) + + # Skip OPTIONS requests (CORS preflight) + if request.method == "OPTIONS": + return await call_next(request) + + # Get rate limit key + key = _get_client_key(request) + + # Check rate limit + allowed, remaining, reset_time = _check_rate_limit_redis(key) + + if not allowed: + logger.warning(f"Rate limit exceeded for {key}") + return JSONResponse( + status_code=429, + content={ + "detail": "Too many requests. Please slow down.", + "retry_after": reset_time - int(time.time()), + }, + headers={ + "X-RateLimit-Limit": str(RATE_LIMIT), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(reset_time), + "Retry-After": str(reset_time - int(time.time())), + }, + ) + + # Process request + response = await call_next(request) + + # Add rate limit headers + response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Reset"] = str(reset_time) + + return response diff --git a/overlays/krolik/api/routers/__init__.py b/overlays/krolik/api/routers/__init__.py new file mode 100644 index 000000000..656114d7a --- /dev/null +++ b/overlays/krolik/api/routers/__init__.py @@ -0,0 +1,5 @@ +"""Krolik router extensions for MemOS.""" + +from .admin_router import router as admin_router + +__all__ = ["admin_router"] diff --git a/overlays/krolik/api/routers/admin_router.py b/overlays/krolik/api/routers/admin_router.py new file mode 100644 index 000000000..939e5101f --- /dev/null +++ b/overlays/krolik/api/routers/admin_router.py @@ -0,0 +1,225 @@ +""" +Admin Router for API Key Management. + +Protected by master key or admin scope. 
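Illustrative client call (the base URL, the placeholder key, and the `requests` dependency are
assumptions for the example; the endpoint and payload match the routes defined below):

    import requests

    resp = requests.post(
        "http://localhost:8000/admin/keys",
        headers={"Authorization": "Bearer mk_<master-key>"},
        json={"user_name": "alice", "scopes": ["read", "write"]},
    )
    print(resp.json()["key"])  # returned only once - store it securely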
+""" + +import os +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +import memos.log +from memos.api.middleware.auth import require_scope, verify_api_key +from memos.api.utils.api_keys import ( + create_api_key_in_db, + generate_master_key, + list_api_keys, + revoke_api_key, +) + +logger = memos.log.get_logger(__name__) + +router = APIRouter(prefix="/admin", tags=["Admin"]) + + +# Request/Response models +class CreateKeyRequest(BaseModel): + user_name: str = Field(..., min_length=1, max_length=255) + scopes: list[str] = Field(default=["read"]) + description: str | None = Field(default=None, max_length=500) + expires_in_days: int | None = Field(default=None, ge=1, le=365) + + +class CreateKeyResponse(BaseModel): + message: str + key: str # Only returned once! + key_prefix: str + user_name: str + scopes: list[str] + + +class KeyListResponse(BaseModel): + message: str + keys: list[dict[str, Any]] + + +class RevokeKeyRequest(BaseModel): + key_id: str + + +class SimpleResponse(BaseModel): + message: str + success: bool = True + + +def _get_db_connection(): + """Get database connection for admin operations.""" + import psycopg2 + + return psycopg2.connect( + host=os.getenv("POSTGRES_HOST", "postgres"), + port=int(os.getenv("POSTGRES_PORT", "5432")), + user=os.getenv("POSTGRES_USER", "memos"), + password=os.getenv("POSTGRES_PASSWORD", ""), + dbname=os.getenv("POSTGRES_DB", "memos"), + ) + + +@router.post( + "/keys", + response_model=CreateKeyResponse, + summary="Create a new API key", + dependencies=[Depends(require_scope("admin"))], +) +def create_key( + request: CreateKeyRequest, + auth: dict = Depends(verify_api_key), +): + """ + Create a new API key for a user. + + Requires admin scope or master key. + + **WARNING**: The API key is only returned once. Store it securely! + """ + try: + conn = _get_db_connection() + try: + api_key = create_api_key_in_db( + conn=conn, + user_name=request.user_name, + scopes=request.scopes, + description=request.description, + expires_in_days=request.expires_in_days, + created_by=auth.get("user_name", "unknown"), + ) + + logger.info( + f"API key created for user '{request.user_name}' by '{auth.get('user_name')}'" + ) + + return CreateKeyResponse( + message="API key created successfully. Store this key securely - it won't be shown again!", + key=api_key.key, + key_prefix=api_key.key_prefix, + user_name=request.user_name, + scopes=request.scopes, + ) + finally: + conn.close() + except Exception as e: + logger.error(f"Failed to create API key: {e}") + raise HTTPException(status_code=500, detail="Failed to create API key") + + +@router.get( + "/keys", + response_model=KeyListResponse, + summary="List API keys", + dependencies=[Depends(require_scope("admin"))], +) +def list_keys( + user_name: str | None = None, + auth: dict = Depends(verify_api_key), +): + """ + List all API keys (admin) or keys for a specific user. + + Note: Actual key values are never returned, only prefixes. 
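    Illustrative request (base URL, placeholder key, and the `requests` dependency are assumptions):

        requests.get(
            "http://localhost:8000/admin/keys",
            params={"user_name": "alice"},
            headers={"Authorization": "Bearer krlk_<admin-scoped-key>"},
        )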
+ """ + try: + conn = _get_db_connection() + try: + keys = list_api_keys(conn, user_name=user_name) + return KeyListResponse( + message=f"Found {len(keys)} key(s)", + keys=keys, + ) + finally: + conn.close() + except Exception as e: + logger.error(f"Failed to list API keys: {e}") + raise HTTPException(status_code=500, detail="Failed to list API keys") + + +@router.delete( + "/keys/{key_id}", + response_model=SimpleResponse, + summary="Revoke an API key", + dependencies=[Depends(require_scope("admin"))], +) +def revoke_key( + key_id: str, + auth: dict = Depends(verify_api_key), +): + """ + Revoke an API key by ID. + + The key will be deactivated but not deleted (for audit purposes). + """ + try: + conn = _get_db_connection() + try: + success = revoke_api_key(conn, key_id) + if success: + logger.info(f"API key {key_id} revoked by '{auth.get('user_name')}'") + return SimpleResponse(message="API key revoked successfully") + else: + raise HTTPException(status_code=404, detail="API key not found or already revoked") + finally: + conn.close() + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to revoke API key: {e}") + raise HTTPException(status_code=500, detail="Failed to revoke API key") + + +@router.post( + "/generate-master-key", + response_model=dict, + summary="Generate a new master key", + dependencies=[Depends(require_scope("admin"))], +) +def generate_new_master_key( + auth: dict = Depends(verify_api_key), +): + """ + Generate a new master key. + + **WARNING**: Store the key securely! Add MASTER_KEY_HASH to your .env file. + """ + if not auth.get("is_master_key"): + raise HTTPException( + status_code=403, + detail="Only master key can generate new master keys", + ) + + key, key_hash = generate_master_key() + + logger.warning("New master key generated - update MASTER_KEY_HASH in .env") + + return { + "message": "Master key generated. Add MASTER_KEY_HASH to your .env file!", + "key": key, + "key_hash": key_hash, + "env_line": f"MASTER_KEY_HASH={key_hash}", + } + + +@router.get( + "/health", + summary="Admin health check", +) +def admin_health(): + """Health check for admin endpoints.""" + auth_enabled = os.getenv("AUTH_ENABLED", "false").lower() == "true" + master_key_configured = bool(os.getenv("MASTER_KEY_HASH")) + + return { + "status": "ok", + "auth_enabled": auth_enabled, + "master_key_configured": master_key_configured, + } diff --git a/overlays/krolik/api/server_api_ext.py b/overlays/krolik/api/server_api_ext.py new file mode 100644 index 000000000..85b9411af --- /dev/null +++ b/overlays/krolik/api/server_api_ext.py @@ -0,0 +1,120 @@ +""" +Extended Server API for Krolik deployment. + +This module extends the base MemOS server_api with: +- API Key Authentication (PostgreSQL-backed) +- Redis Rate Limiting +- Admin API for key management +- Security Headers + +Usage in Dockerfile: + # Copy overlays after base installation + COPY overlays/krolik/ /app/src/memos/ + + # Use this as entrypoint instead of server_api + CMD ["gunicorn", "memos.api.server_api_ext:app", ...] 
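For local development the same app can be served without gunicorn; this mirrors the
__main__ block at the bottom of this module:

    import uvicorn
    uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000)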
+""" + +import logging +import os + +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +# Import base routers from MemOS +from memos.api.routers.server_router import router as server_router + +# Import Krolik extensions +from memos.api.middleware.rate_limit import RateLimitMiddleware +from memos.api.routers.admin_router import router as admin_router + +# Try to import exception handlers (may vary between MemOS versions) +try: + from memos.api.exceptions import APIExceptionHandler + HAS_EXCEPTION_HANDLER = True +except ImportError: + HAS_EXCEPTION_HANDLER = False + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Add security headers to all responses.""" + + async def dispatch(self, request: Request, call_next) -> Response: + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" + return response + + +# Create FastAPI app +app = FastAPI( + title="MemOS Server REST APIs (Krolik Extended)", + description="MemOS API with authentication, rate limiting, and admin endpoints.", + version="2.0.3-krolik", +) + +# CORS configuration +CORS_ORIGINS = os.getenv("CORS_ORIGINS", "").split(",") +CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS if origin.strip()] + +if not CORS_ORIGINS: + CORS_ORIGINS = [ + "https://krolik.hully.one", + "https://memos.hully.one", + "http://localhost:3000", + ] + +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ORIGINS, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "X-API-Key", "X-User-Name"], +) + +# Security headers +app.add_middleware(SecurityHeadersMiddleware) + +# Rate limiting (before auth to protect against brute force) +RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true" +if RATE_LIMIT_ENABLED: + app.add_middleware(RateLimitMiddleware) + logger.info("Rate limiting enabled") + +# Include routers +app.include_router(server_router) +app.include_router(admin_router) + +# Exception handlers +if HAS_EXCEPTION_HANDLER: + from fastapi import HTTPException + app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) + app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) + app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler) + app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "version": "2.0.3-krolik", + "auth_enabled": os.getenv("AUTH_ENABLED", "false").lower() == "true", + "rate_limit_enabled": RATE_LIMIT_ENABLED, + } + + +if __name__ == "__main__": + import uvicorn + uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000, workers=1) diff --git a/overlays/krolik/api/utils/__init__.py b/overlays/krolik/api/utils/__init__.py new file mode 
100644 index 000000000..e69de29bb diff --git a/overlays/krolik/api/utils/api_keys.py b/overlays/krolik/api/utils/api_keys.py new file mode 100644 index 000000000..559ddd355 --- /dev/null +++ b/overlays/krolik/api/utils/api_keys.py @@ -0,0 +1,197 @@ +""" +API Key Management Utilities. + +Provides functions for generating, validating, and managing API keys. +""" + +import hashlib +import os +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta + + +@dataclass +class APIKey: + """Represents a generated API key.""" + + key: str # Full key (only available at creation time) + key_hash: str # SHA-256 hash (stored in database) + key_prefix: str # First 12 chars for identification + + +def generate_api_key() -> APIKey: + """ + Generate a new API key. + + Format: krlk_<64-hex-chars> + + Returns: + APIKey with key, hash, and prefix + """ + # Generate 32 random bytes = 64 hex chars + random_bytes = secrets.token_bytes(32) + hex_part = random_bytes.hex() + + key = f"krlk_{hex_part}" + key_hash = hashlib.sha256(key.encode()).hexdigest() + key_prefix = key[:12] + + return APIKey(key=key, key_hash=key_hash, key_prefix=key_prefix) + + +def hash_key(key: str) -> str: + """Hash an API key using SHA-256.""" + return hashlib.sha256(key.encode()).hexdigest() + + +def validate_key_format(key: str) -> bool: + """ + Validate API key format. + + Valid format: krlk_<64-hex-chars> + """ + if not key or not isinstance(key, str): + return False + + if not key.startswith("krlk_"): + return False + + hex_part = key[5:] + if len(hex_part) != 64: + return False + + try: + int(hex_part, 16) + return True + except ValueError: + return False + + +def generate_master_key() -> tuple[str, str]: + """ + Generate a master key for admin operations. + + Returns: + Tuple of (key, hash) + """ + random_bytes = secrets.token_bytes(32) + key = f"mk_{random_bytes.hex()}" + key_hash = hashlib.sha256(key.encode()).hexdigest() + return key, key_hash + + +def create_api_key_in_db( + conn, + user_name: str, + scopes: list[str] | None = None, + description: str | None = None, + expires_in_days: int | None = None, + created_by: str | None = None, +) -> APIKey: + """ + Create a new API key and store in database. + + Args: + conn: Database connection + user_name: Owner of the key + scopes: List of scopes (default: ["read"]) + description: Human-readable description + expires_in_days: Days until expiration (None = never) + created_by: Who created this key + + Returns: + APIKey with the generated key (only time it's available!) + """ + api_key = generate_api_key() + + expires_at = None + if expires_in_days: + expires_at = datetime.utcnow() + timedelta(days=expires_in_days) + + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO api_keys (key_hash, key_prefix, user_name, scopes, description, expires_at, created_by) + VALUES (%s, %s, %s, %s, %s, %s, %s) + RETURNING id + """, + ( + api_key.key_hash, + api_key.key_prefix, + user_name, + scopes or ["read"], + description, + expires_at, + created_by, + ), + ) + conn.commit() + + return api_key + + +def revoke_api_key(conn, key_id: str) -> bool: + """ + Revoke an API key by ID. + + Returns: + True if key was revoked, False if not found + """ + with conn.cursor() as cur: + cur.execute( + "UPDATE api_keys SET is_active = false WHERE id = %s AND is_active = true", + (key_id,), + ) + conn.commit() + return cur.rowcount > 0 + + +def list_api_keys(conn, user_name: str | None = None) -> list[dict]: + """ + List API keys (without exposing the actual keys). 
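    For context, the helpers defined earlier in this module fit together like this (illustrative):

        api_key = generate_api_key()              # key is 'krlk_' + 64 hex chars
        assert validate_key_format(api_key.key)
        assert hash_key(api_key.key) == api_key.key_hash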
+ + Args: + conn: Database connection + user_name: Filter by user (None = all users) + + Returns: + List of key metadata dicts + """ + with conn.cursor() as cur: + if user_name: + cur.execute( + """ + SELECT id, key_prefix, user_name, scopes, description, + created_at, last_used_at, expires_at, is_active + FROM api_keys + WHERE user_name = %s + ORDER BY created_at DESC + """, + (user_name,), + ) + else: + cur.execute( + """ + SELECT id, key_prefix, user_name, scopes, description, + created_at, last_used_at, expires_at, is_active + FROM api_keys + ORDER BY created_at DESC + """ + ) + + rows = cur.fetchall() + return [ + { + "id": str(row[0]), + "key_prefix": row[1], + "user_name": row[2], + "scopes": row[3], + "description": row[4], + "created_at": row[5].isoformat() if row[5] else None, + "last_used_at": row[6].isoformat() if row[6] else None, + "expires_at": row[7].isoformat() if row[7] else None, + "is_active": row[8], + } + for row in rows + ] diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a3bf25be0..ad017ad78 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -628,6 +628,30 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), } + @staticmethod + def get_postgres_config(user_id: str | None = None) -> dict[str, Any]: + """Get PostgreSQL + pgvector configuration for MemOS graph storage. + + Uses standard PostgreSQL with pgvector extension. + Schema: memos.memories, memos.edges + """ + user_name = os.getenv("MEMOS_USER_NAME", "default") + if user_id: + user_name = f"memos_{user_id.replace('-', '')}" + + return { + "host": os.getenv("POSTGRES_HOST", "postgres"), + "port": int(os.getenv("POSTGRES_PORT", "5432")), + "user": os.getenv("POSTGRES_USER", "n8n"), + "password": os.getenv("POSTGRES_PASSWORD", ""), + "db_name": os.getenv("POSTGRES_DB", "n8n"), + "schema_name": os.getenv("MEMOS_SCHEMA", "memos"), + "user_name": user_name, + "use_multi_db": False, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "384")), + "maxconn": int(os.getenv("POSTGRES_MAX_CONN", "20")), + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" @@ -884,13 +908,16 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene if os.getenv("ENABLE_INTERNET", "false").lower() == "true" else None ) + postgres_config = APIConfig.get_postgres_config(user_id=user_id) graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, "polardb": polardb_config, + "postgres": postgres_config, } - graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() if graph_db_backend in graph_db_backend_map: # Create MemCube config @@ -958,18 +985,21 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": neo4j_config = APIConfig.get_neo4j_config(user_id="default") nebular_config = APIConfig.get_nebular_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") + postgres_config = APIConfig.get_postgres_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, "polardb": polardb_config, + "postgres": postgres_config, } internet_config = ( APIConfig.get_internet_config() if 
os.getenv("ENABLE_INTERNET", "false").lower() == "true" else None ) - graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() if graph_db_backend in graph_db_backend_map: return GeneralMemCubeConfig.model_validate( { diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index fce789e2a..2d82cb3ca 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -41,9 +41,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), + "postgres": APIConfig.get_postgres_config(user_id=user_id), } - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py index ce2e41390..8f8e70311 100644 --- a/src/memos/api/mcp_serve.py +++ b/src/memos/api/mcp_serve.py @@ -122,15 +122,6 @@ def load_default_config(user_id="default_user"): return config, cube -class MOSMCPStdioServer: - def __init__(self): - self.mcp = FastMCP("MOS Memory System") - config, cube = load_default_config() - self.mos_core = MOS(config=config) - self.mos_core.register_mem_cube(cube) - self._setup_tools() - - class MOSMCPServer: """MCP Server that accepts an existing MOS instance.""" @@ -584,7 +575,6 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): raise ValueError(f"Unsupported transport: {transport}") -MOSMCPStdioServer.run = _run_mcp MOSMCPServer.run = _run_mcp @@ -610,5 +600,5 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): args = parser.parse_args() # Create and run MCP server - server = MOSMCPStdioServer() + server = MOSMCPServer() server.run(transport=args.transport, host=args.host, port=args.port) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 3b4bace0e..5ce9faad1 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -211,6 +211,58 @@ def validate_config(self): return self +class PostgresGraphDBConfig(BaseConfig): + """ + PostgreSQL + pgvector configuration for MemOS. + + Uses standard PostgreSQL with pgvector extension for vector search. + Does NOT require Apache AGE or other graph extensions. 
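    Illustrative construction through the existing config factory (connection values are
    placeholders):

        from memos.configs.graph_db import GraphDBConfigFactory

        graph_db = GraphDBConfigFactory.model_validate({
            "backend": "postgres",
            "config": {
                "host": "postgres",
                "port": 5432,
                "user": "memos",
                "password": "secret",
                "db_name": "memos",
                "schema_name": "memos",
                "user_name": "default",
                "embedding_dimension": 384,
            },
        })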
+ + Schema: + - memos_memories: Main table for memory nodes (id, memory, properties JSONB, embedding vector) + - memos_edges: Edge table for relationships (source_id, target_id, type) + + Example: + --- + host = "postgres" + port = 5432 + user = "n8n" + password = "secret" + db_name = "n8n" + schema_name = "memos" + user_name = "default" + """ + + host: str = Field(..., description="Database host") + port: int = Field(default=5432, description="Database port") + user: str = Field(..., description="Database user") + password: str = Field(..., description="Database password") + db_name: str = Field(..., description="Database name") + schema_name: str = Field(default="memos", description="Schema name for MemOS tables") + user_name: str | None = Field( + default=None, + description="Logical user/tenant ID for data isolation", + ) + use_multi_db: bool = Field( + default=False, + description="If False: use single database with logical isolation by user_name", + ) + embedding_dimension: int = Field(default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)") + maxconn: int = Field( + default=20, + description="Maximum number of connections in the connection pool", + ) + + @model_validator(mode="after") + def validate_config(self): + """Validate config.""" + if not self.db_name: + raise ValueError("`db_name` must be provided") + if not self.use_multi_db and not self.user_name: + raise ValueError("In single-database mode, `user_name` must be provided") + return self + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -220,6 +272,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j-community": Neo4jCommunityGraphDBConfig, "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, + "postgres": PostgresGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index ec9cbcda0..c207e3190 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -6,6 +6,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB +from memos.graph_dbs.postgres import PostgresGraphDB class GraphStoreFactory(BaseGraphDB): @@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j-community": Neo4jCommunityGraphDB, "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, + "postgres": PostgresGraphDB, } @classmethod diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py new file mode 100644 index 000000000..f9065d718 --- /dev/null +++ b/src/memos/graph_dbs/postgres.py @@ -0,0 +1,884 @@ +""" +PostgreSQL + pgvector backend for MemOS. + +Simple implementation using standard PostgreSQL with pgvector extension. +No Apache AGE or other graph extensions required. 
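Illustrative direct usage (connection values are placeholders; in the API this backend is normally
selected via GRAPH_DB_BACKEND=postgres and built through the config factory):

    from memos.configs.graph_db import PostgresGraphDBConfig
    from memos.graph_dbs.postgres import PostgresGraphDB

    config = PostgresGraphDBConfig(
        host="postgres", user="memos", password="secret",
        db_name="memos", user_name="default", embedding_dimension=384,
    )
    db = PostgresGraphDB(config)
    db.add_node(
        id="mem-1",
        memory="User prefers dark mode",
        metadata={"memory_type": "LongTermMemory", "embedding": [0.1] * 384},
    )
    hits = db.search_by_embedding([0.1] * 384, top_k=5)  # e.g. [{"id": "mem-1", "score": ...}]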
+ +Tables: +- {schema}.memories: Memory nodes with JSONB properties and vector embeddings +- {schema}.edges: Relationships between memory nodes +""" + +import json +import time +from contextlib import suppress +from datetime import datetime +from typing import Any, Literal + +from memos.configs.graph_db import PostgresGraphDBConfig +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.log import get_logger + +logger = get_logger(__name__) + + +def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """Ensure metadata has proper datetime fields and normalized types.""" + now = datetime.utcnow().isoformat() + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +class PostgresGraphDB(BaseGraphDB): + """PostgreSQL + pgvector implementation of a graph memory store.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PostgresGraphDBConfig): + """Initialize PostgreSQL connection pool.""" + import psycopg2 + import psycopg2.pool + + self.config = config + self.schema = config.schema_name + self.user_name = config.user_name + self._pool_closed = False + + logger.info(f"Connecting to PostgreSQL: {config.host}:{config.port}/{config.db_name}") + + # Create connection pool + self.pool = psycopg2.pool.ThreadedConnectionPool( + minconn=2, + maxconn=config.maxconn, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + dbname=config.db_name, + connect_timeout=30, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + + # Initialize schema and tables + self._init_schema() + + def _get_conn(self): + """Get connection from pool with health check.""" + if self._pool_closed: + raise RuntimeError("Connection pool is closed") + + for attempt in range(3): + conn = None + try: + conn = self.pool.getconn() + if conn.closed != 0: + self.pool.putconn(conn, close=True) + continue + conn.autocommit = True + # Health check + with conn.cursor() as cur: + cur.execute("SELECT 1") + return conn + except Exception as e: + if conn: + with suppress(Exception): + self.pool.putconn(conn, close=True) + if attempt == 2: + raise RuntimeError(f"Failed to get connection: {e}") from e + time.sleep(0.1) + raise RuntimeError("Failed to get healthy connection") + + def _put_conn(self, conn): + """Return connection to pool.""" + if conn and not self._pool_closed: + try: + self.pool.putconn(conn) + except Exception: + with suppress(Exception): + conn.close() + + def _init_schema(self): + """Create schema and tables if they don't exist.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Create schema + cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") + + # Enable pgvector + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + + # Create memories table + dim = self.config.embedding_dimension + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL DEFAULT '', + properties JSONB NOT NULL DEFAULT '{{}}', + embedding vector({dim}), + user_name TEXT, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + ) + """) + + # Create edges 
table + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.edges ( + id SERIAL PRIMARY KEY, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(source_id, target_id, edge_type) + ) + """) + + # Create indexes + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_user + ON {self.schema}.memories(user_name) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_props + ON {self.schema}.memories USING GIN(properties) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_embedding + ON {self.schema}.memories USING ivfflat(embedding vector_cosine_ops) + WITH (lists = 100) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_source + ON {self.schema}.edges(source_id) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_target + ON {self.schema}.edges(target_id) + """) + + logger.info(f"Schema {self.schema} initialized successfully") + except Exception as e: + logger.error(f"Failed to init schema: {e}") + raise + finally: + self._put_conn(conn) + + # ========================================================================= + # Node Management + # ========================================================================= + + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all memories of a given type except the latest `keep_latest` entries. + + Args: + memory_type: Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest: Number of latest entries to keep. + user_name: User to filter by. + """ + user_name = user_name or self.user_name + keep_latest = int(keep_latest) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Find IDs to delete (older than the keep_latest entries) + cur.execute(f""" + WITH ranked AS ( + SELECT id, ROW_NUMBER() OVER (ORDER BY updated_at DESC) as rn + FROM {self.schema}.memories + WHERE user_name = %s + AND properties->>'memory_type' = %s + ) + SELECT id FROM ranked WHERE rn > %s + """, (user_name, memory_type, keep_latest)) + + ids_to_delete = [row[0] for row in cur.fetchall()] + + if ids_to_delete: + # Delete edges first + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids_to_delete, ids_to_delete)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = ANY(%s) + """, (ids_to_delete,)) + + logger.info(f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}") + finally: + self._put_conn(conn) + + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node.""" + user_name = user_name or self.user_name + metadata = _prepare_node_metadata(metadata.copy()) + + # Extract embedding + embedding = metadata.pop("embedding", None) + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Serialize sources if present + if metadata.get("sources"): + metadata["sources"] = [ + json.dumps(s) if not isinstance(s, str) else s + for s in metadata["sources"] + ] + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, embedding, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s::vector, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = 
EXCLUDED.memory, + properties = EXCLUDED.properties, + embedding = EXCLUDED.embedding, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), embedding, user_name, created_at, updated_at)) + else: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = EXCLUDED.memory, + properties = EXCLUDED.properties, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), user_name, created_at, updated_at)) + finally: + self._put_conn(conn) + + def add_nodes_batch( + self, nodes: list[dict[str, Any]], user_name: str | None = None + ) -> None: + """Batch add memory nodes.""" + for node in nodes: + self.add_node( + id=node["id"], + memory=node["memory"], + metadata=node.get("metadata", {}), + user_name=user_name, + ) + + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """Update node fields.""" + user_name = user_name or self.user_name + if not fields: + return + + # Get current node + current = self.get_node(id, user_name=user_name) + if not current: + return + + # Merge properties + props = current.get("metadata", {}).copy() + embedding = fields.pop("embedding", None) + memory = fields.pop("memory", current.get("memory", "")) + props.update(fields) + props["updated_at"] = datetime.utcnow().isoformat() + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, embedding = %s::vector, updated_at = NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), embedding, id, user_name)) + else: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, updated_at = NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), id, user_name)) + finally: + self._put_conn(conn) + + def delete_node(self, id: str, user_name: str | None = None) -> None: + """Delete a node and its edges.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s OR target_id = %s + """, (id, id)) + # Delete node + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + finally: + self._put_conn(conn) + + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: + """Get a single node by ID.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + row = cur.fetchone() + if not row: + return None + return self._parse_row(row, include_embedding) + finally: + self._put_conn(conn) + + def get_nodes( + self, ids: list, include_embedding: bool = False, **kwargs + ) -> list[dict[str, Any]]: + """Get multiple nodes by IDs.""" + if not ids: + return [] + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM 
{self.schema}.memories + WHERE id = ANY(%s) AND user_name = %s + """, (ids, user_name)) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]: + """Parse database row to node dict.""" + props = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}") + props["created_at"] = row[3].isoformat() if row[3] else None + props["updated_at"] = row[4].isoformat() if row[4] else None + result = { + "id": row[0], + "memory": row[1] or "", + "metadata": props, + } + if include_embedding and len(row) > 5: + result["metadata"]["embedding"] = row[5] + return result + + # ========================================================================= + # Edge Management + # ========================================================================= + + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Create an edge between nodes.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + INSERT INTO {self.schema}.edges (source_id, target_id, edge_type) + VALUES (%s, %s, %s) + ON CONFLICT (source_id, target_id, edge_type) DO NOTHING + """, (source_id, target_id, type)) + finally: + self._put_conn(conn) + + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Delete an edge.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """, (source_id, target_id, type)) + finally: + self._put_conn(conn) + + def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: + """Check if edge exists.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT 1 FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + LIMIT 1 + """, (source_id, target_id, type)) + return cur.fetchone() is not None + finally: + self._put_conn(conn) + + # ========================================================================= + # Graph Queries + # ========================================================================= + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get neighboring node IDs.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + if direction == "out": + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges + WHERE source_id = %s AND edge_type = %s + """, (id, type)) + elif direction == "in": + cur.execute(f""" + SELECT source_id FROM {self.schema}.edges + WHERE target_id = %s AND edge_type = %s + """, (id, type)) + else: # both + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s + UNION + SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s + """, (id, type, id, type)) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get path between nodes using recursive CTE.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + WITH RECURSIVE path AS ( + SELECT source_id, target_id, ARRAY[source_id] as nodes, 1 as depth + FROM {self.schema}.edges + WHERE source_id = %s + UNION ALL + SELECT e.source_id, e.target_id, p.nodes || 
+    def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]:
+        """Get subgraph around center node."""
+        conn = self._get_conn()
+        try:
+            with conn.cursor() as cur:
+                cur.execute(f"""
+                    WITH RECURSIVE subgraph AS (
+                        SELECT %s::text as node_id, 0 as level
+                        UNION
+                        SELECT CASE WHEN e.source_id = s.node_id THEN e.target_id ELSE e.source_id END,
+                               s.level + 1
+                        FROM {self.schema}.edges e
+                        JOIN subgraph s ON (e.source_id = s.node_id OR e.target_id = s.node_id)
+                        WHERE s.level < %s
+                    )
+                    SELECT DISTINCT node_id FROM subgraph
+                """, (center_id, depth))
+                return [row[0] for row in cur.fetchall()]
+        finally:
+            self._put_conn(conn)
+
+    def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+        """Get ordered chain following relationship type."""
+        return self.get_neighbors(id, type, "out")
+
+    # =========================================================================
+    # Search Operations
+    # =========================================================================
+
+    def search_by_embedding(
+        self,
+        vector: list[float],
+        top_k: int = 5,
+        scope: str | None = None,
+        status: str | None = None,
+        threshold: float | None = None,
+        search_filter: dict | None = None,
+        user_name: str | None = None,
+        filter: dict | None = None,
+        knowledgebase_ids: list[str] | None = None,
+        **kwargs,
+    ) -> list[dict]:
+        """Search nodes by vector similarity using pgvector."""
+        user_name = user_name or self.user_name
+
+        # Build WHERE clause
+        conditions = ["embedding IS NOT NULL"]
+        params = []
+
+        if user_name:
+            conditions.append("user_name = %s")
+            params.append(user_name)
+
+        if scope:
+            conditions.append("properties->>'memory_type' = %s")
+            params.append(scope)
+
+        if status:
+            conditions.append("properties->>'status' = %s")
+            params.append(status)
+        else:
+            conditions.append("(properties->>'status' = 'activated' OR properties->>'status' IS NULL)")
+
+        if search_filter:
+            for k, v in search_filter.items():
+                conditions.append(f"properties->>'{k}' = %s")
+                params.append(str(v))
+
+        where_clause = " AND ".join(conditions)
+
+        # pgvector cosine distance: 1 - (a <=> b) gives similarity score
+        conn = self._get_conn()
+        try:
+            with conn.cursor() as cur:
+                cur.execute(f"""
+                    SELECT id, 1 - (embedding <=> %s::vector) as score
+                    FROM {self.schema}.memories
+                    WHERE {where_clause}
+                    ORDER BY embedding <=> %s::vector
+                    LIMIT %s
+                """, (vector, *params, vector, top_k))
+
+                results = []
+                for row in cur.fetchall():
+                    score = float(row[1])
+                    if threshold is None or score >= threshold:
+                        results.append({"id": row[0], "score": score})
+                return results
+        finally:
+            self._put_conn(conn)
+
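+    # Illustrative call (hypothetical vector and ids):
+    #   hits = store.search_by_embedding([0.1, 0.2, 0.3], top_k=2, scope="LongTermMemory")
+    #   # -> [{"id": "mem-1", "score": 0.91}, {"id": "mem-7", "score": 0.83}]
+    # The score is 1 - cosine distance (pgvector `<=>`), so higher means more similar,
+    # and `threshold` is a lower bound on that score. Note that search_filter keys are
+    # interpolated into the SQL text, so they must come from trusted code, not user input.
+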
= %s") + params.append(status) + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + if op == "=": + conditions.append(f"properties->>'{field}' = %s") + params.append(str(value)) + elif op == "in": + placeholders = ",".join(["%s"] * len(value)) + conditions.append(f"properties->>'{field}' IN ({placeholders})") + params.extend([str(v) for v in value]) + elif op in (">", ">=", "<", "<="): + conditions.append(f"(properties->>'{field}')::numeric {op} %s") + params.append(value) + elif op == "contains": + conditions.append(f"properties->'{field}' @> %s::jsonb") + params.append(json.dumps([value])) + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT id FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + status: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Get all memory items of a specific type.""" + user_name = kwargs.get("user_name") or self.user_name + + conditions = ["properties->>'memory_type' = %s", "user_name = %s"] + params = [scope, user_name] + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + + where_clause = " AND ".join(conditions) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False + ) -> list[dict]: + """Find isolated nodes (no edges).""" + user_name = self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at" + cur.execute(f""" + SELECT {cols} + FROM {self.schema}.memories m + LEFT JOIN {self.schema}.edges e1 ON m.id = e1.source_id + LEFT JOIN {self.schema}.edges e2 ON m.id = e2.target_id + WHERE m.properties->>'memory_type' = %s + AND m.user_name = %s + AND m.properties->>'status' = 'activated' + AND e1.id IS NULL + AND e2.id IS NULL + """, (scope, user_name)) + return [self._parse_row(row, False) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + # ========================================================================= + # Maintenance + # ========================================================================= + + def deduplicate_nodes(self) -> None: + """Not implemented - handled at application level.""" + pass + + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by specified fields. + + Args: + group_fields: Fields to group by, e.g., ["memory_type", "status"] + where_clause: Extra WHERE condition + params: Parameters for WHERE clause + user_name: User to filter by + + Returns: + list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'count': 10}, ...] 
+ """ + user_name = user_name or self.user_name + if not group_fields: + raise ValueError("group_fields cannot be empty") + + # Build SELECT and GROUP BY clauses + # Fields come from JSONB properties column + select_fields = ", ".join([ + f"properties->>'{field}' AS {field}" for field in group_fields + ]) + group_by = ", ".join([f"properties->>'{field}'" for field in group_fields]) + + # Build WHERE clause + conditions = [f"user_name = %s"] + query_params = [user_name] + + if where_clause: + # Parse simple where clause format + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause = where_clause[5:].strip() + if where_clause: + conditions.append(where_clause) + if params: + query_params.extend(params.values()) + + where_sql = " AND ".join(conditions) + + query = f""" + SELECT {select_fields}, COUNT(*) AS count + FROM {self.schema}.memories + WHERE {where_sql} + GROUP BY {group_by} + """ + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(query, query_params) + results = [] + for row in cur.fetchall(): + result = {} + for i, field in enumerate(group_fields): + result[field] = row[i] + result["count"] = row[len(group_fields)] + results.append(result) + return results + finally: + self._put_conn(conn) + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Not implemented.""" + return [] + + def merge_nodes(self, id1: str, id2: str) -> str: + """Not implemented.""" + raise NotImplementedError + + def clear(self, user_name: str | None = None) -> None: + """Clear all data for user.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get all node IDs for user + cur.execute(f""" + SELECT id FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + ids = [row[0] for row in cur.fetchall()] + + if ids: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids, ids)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + logger.info(f"Cleared all data for user {user_name}") + finally: + self._put_conn(conn) + + def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]: + """Export all data.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get nodes + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE user_name = %s + ORDER BY created_at DESC + """, (user_name,)) + nodes = [self._parse_row(row, include_embedding) for row in cur.fetchall()] + + # Get edges + node_ids = [n["id"] for n in nodes] + if node_ids: + cur.execute(f""" + SELECT source_id, target_id, edge_type + FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (node_ids, node_ids)) + edges = [ + {"source": row[0], "target": row[1], "type": row[2]} + for row in cur.fetchall() + ] + else: + edges = [] + + return { + "nodes": nodes, + "edges": edges, + "total_nodes": len(nodes), + "total_edges": len(edges), + } + finally: + self._put_conn(conn) + + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """Import graph data.""" + user_name = user_name or self.user_name + + for node in data.get("nodes", []): + self.add_node( + id=node["id"], + memory=node.get("memory", ""), + 
+    def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
+        """Import graph data."""
+        user_name = user_name or self.user_name
+
+        for node in data.get("nodes", []):
+            self.add_node(
+                id=node["id"],
+                memory=node.get("memory", ""),
+                metadata=node.get("metadata", {}),
+                user_name=user_name,
+            )
+
+        for edge in data.get("edges", []):
+            self.add_edge(
+                source_id=edge["source"],
+                target_id=edge["target"],
+                type=edge["type"],
+            )
+
+    def close(self):
+        """Close connection pool."""
+        if not self._pool_closed:
+            self._pool_closed = True
+            self.pool.closeall()
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index 903088a4c..b103acf3a 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -61,9 +61,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
         "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
         "nebular": APIConfig.get_nebular_config(user_id=user_id),
         "polardb": APIConfig.get_polardb_config(user_id=user_id),
+        "postgres": APIConfig.get_postgres_config(user_id=user_id),
     }
 
-    graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
+    # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
+    graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
     return GraphDBConfigFactory.model_validate(
         {
             "backend": graph_db_backend,
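A minimal sketch of how the backend selection above resolves (the helper name `resolve_graph_db_backend` is illustrative only, not part of the module):

```python
import os


def resolve_graph_db_backend() -> str:
    # GRAPH_DB_BACKEND takes precedence; the legacy NEO4J_BACKEND is the fallback;
    # the default remains "nebular", matching the previous behaviour.
    return os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()


# GRAPH_DB_BACKEND=postgres        -> "postgres"
# only NEO4J_BACKEND=neo4j set     -> "neo4j"
# neither variable set             -> "nebular"
```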
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 4541b118b..be1841232 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -390,18 +390,41 @@ def search_path_b():
         if not all_hits:
             return []
 
-        # merge and deduplicate
-        unique_ids = {r["id"] for r in all_hits if r.get("id")}
+        # merge and deduplicate, keeping highest score per ID
+        id_to_score = {}
+        for r in all_hits:
+            rid = r.get("id")
+            if rid:
+                score = r.get("score", 0.0)
+                if rid not in id_to_score or score > id_to_score[rid]:
+                    id_to_score[rid] = score
+
+        # Sort IDs by score (descending) to preserve ranking
+        sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True)
+
         node_dicts = (
             self.graph_store.get_nodes(
-                list(unique_ids),
+                sorted_ids,
                 include_embedding=self.include_embedding,
                 cube_name=cube_name,
                 user_name=user_name,
             )
             or []
         )
-        return [TextualMemoryItem.from_dict(n) for n in node_dicts]
+
+        # Restore score-based order and inject scores into metadata
+        id_to_node = {n.get("id"): n for n in node_dicts}
+        ordered_nodes = []
+        for rid in sorted_ids:
+            if rid in id_to_node:
+                node = id_to_node[rid]
+                # Inject similarity score as relativity
+                if "metadata" not in node:
+                    node["metadata"] = {}
+                node["metadata"]["relativity"] = id_to_score.get(rid, 0.0)
+                ordered_nodes.append(node)
+
+        return [TextualMemoryItem.from_dict(n) for n in ordered_nodes]
 
     def _bm25_recall(
         self,
@@ -483,15 +506,38 @@ def _fulltext_recall(
         if not all_hits:
             return []
 
-        # merge and deduplicate
-        unique_ids = {r["id"] for r in all_hits if r.get("id")}
+        # merge and deduplicate, keeping highest score per ID
+        id_to_score = {}
+        for r in all_hits:
+            rid = r.get("id")
+            if rid:
+                score = r.get("score", 0.0)
+                if rid not in id_to_score or score > id_to_score[rid]:
+                    id_to_score[rid] = score
+
+        # Sort IDs by score (descending) to preserve ranking
+        sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True)
+
         node_dicts = (
             self.graph_store.get_nodes(
-                list(unique_ids),
+                sorted_ids,
                 include_embedding=self.include_embedding,
                 cube_name=cube_name,
                 user_name=user_name,
            )
             or []
         )
-        return [TextualMemoryItem.from_dict(n) for n in node_dicts]
+
+        # Restore score-based order and inject scores into metadata
+        id_to_node = {n.get("id"): n for n in node_dicts}
+        ordered_nodes = []
+        for rid in sorted_ids:
+            if rid in id_to_node:
+                node = id_to_node[rid]
+                # Inject similarity score as relativity
+                if "metadata" not in node:
+                    node["metadata"] = {}
+                node["metadata"]["relativity"] = id_to_score.get(rid, 0.0)
+                ordered_nodes.append(node)
+
+        return [TextualMemoryItem.from_dict(n) for n in ordered_nodes]
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
index 861343e20..b8ab813dc 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py
@@ -78,7 +78,11 @@ def rerank(
         embeddings = [item.metadata.embedding for item in items_with_embeddings]
 
         if not embeddings:
-            return [(item, 0.5) for item in graph_results[:top_k]]
+            # Use relativity from recall stage if available, otherwise default to 0.5
+            return [
+                (item, getattr(item.metadata, "relativity", None) or 0.5)
+                for item in graph_results[:top_k]
+            ]
 
         # Step 2: Compute cosine similarities
         similarity_scores = batch_cosine_similarity(query_embedding, embeddings)
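A minimal, self-contained sketch of the merge-and-rank step that both recall.py hunks add (the helper name `merge_hits_by_score` and the sample hits are hypothetical; the real code additionally re-fetches the nodes from the graph store and stores each score as `metadata["relativity"]` so the reranker fallback above can reuse it):

```python
def merge_hits_by_score(hits: list[dict]) -> list[str]:
    """Deduplicate recall hits by id, keep the highest score per id, return ids sorted by score."""
    id_to_score: dict[str, float] = {}
    for hit in hits:
        rid = hit.get("id")
        if rid:
            score = hit.get("score", 0.0)
            if rid not in id_to_score or score > id_to_score[rid]:
                id_to_score[rid] = score
    return sorted(id_to_score, key=id_to_score.get, reverse=True)


# merge_hits_by_score([{"id": "a", "score": 0.4}, {"id": "b", "score": 0.9}, {"id": "a", "score": 0.7}])
# -> ["b", "a"]
```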