From 05ee090b319d792d39cca1f823adee369857107d Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Sat, 24 Jan 2026 06:11:56 -0800 Subject: [PATCH 1/7] fix: remove duplicate MOSMCPStdioServer class, use MOSMCPServer The MOSMCPStdioServer class was calling _setup_tools() which was not defined. Consolidated into MOSMCPServer which has the proper implementation. --- src/memos/api/mcp_serve.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py index ce2e41390..8f8e70311 100644 --- a/src/memos/api/mcp_serve.py +++ b/src/memos/api/mcp_serve.py @@ -122,15 +122,6 @@ def load_default_config(user_id="default_user"): return config, cube -class MOSMCPStdioServer: - def __init__(self): - self.mcp = FastMCP("MOS Memory System") - config, cube = load_default_config() - self.mos_core = MOS(config=config) - self.mos_core.register_mem_cube(cube) - self._setup_tools() - - class MOSMCPServer: """MCP Server that accepts an existing MOS instance.""" @@ -584,7 +575,6 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): raise ValueError(f"Unsupported transport: {transport}") -MOSMCPStdioServer.run = _run_mcp MOSMCPServer.run = _run_mcp @@ -610,5 +600,5 @@ def _run_mcp(self, transport: str = "stdio", **kwargs): args = parser.parse_args() # Create and run MCP server - server = MOSMCPStdioServer() + server = MOSMCPServer() server.run(transport=args.transport, host=args.host, port=args.port) From 56d59277c96571d839bfd1dbcf5a713627017470 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Sat, 24 Jan 2026 17:04:21 -0800 Subject: [PATCH 2/7] feat: add PostgreSQL + pgvector backend for graph database - Create PostgresGraphDB class with full BaseGraphDB implementation - Add PostgresGraphDBConfig with connection pooling support - Register postgres backend in GraphStoreFactory - Update APIConfig with get_postgres_config method - Support GRAPH_DB_BACKEND env var with neo4j fallback Replaces Neo4j dependency with native PostgreSQL using: - JSONB for flexible node properties - pgvector for embedding similarity search - Standard SQL for graph traversal --- src/memos/api/config.py | 34 +- src/memos/api/handlers/config_builders.py | 4 +- src/memos/configs/graph_db.py | 53 ++ src/memos/graph_dbs/factory.py | 2 + src/memos/graph_dbs/postgres.py | 769 ++++++++++++++++++ .../init_components_for_scheduler.py | 4 +- 6 files changed, 862 insertions(+), 4 deletions(-) create mode 100644 src/memos/graph_dbs/postgres.py diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a3bf25be0..ad017ad78 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -628,6 +628,30 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), } + @staticmethod + def get_postgres_config(user_id: str | None = None) -> dict[str, Any]: + """Get PostgreSQL + pgvector configuration for MemOS graph storage. + + Uses standard PostgreSQL with pgvector extension. 
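+        Connection settings are read from the POSTGRES_HOST, POSTGRES_PORT,
+        POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DB, MEMOS_SCHEMA,
+        EMBEDDING_DIMENSION and POSTGRES_MAX_CONN environment variables, with
+        the defaults shown below; MEMOS_USER_NAME (or the user_id argument)
+        sets the logical tenant name.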
+ Schema: memos.memories, memos.edges + """ + user_name = os.getenv("MEMOS_USER_NAME", "default") + if user_id: + user_name = f"memos_{user_id.replace('-', '')}" + + return { + "host": os.getenv("POSTGRES_HOST", "postgres"), + "port": int(os.getenv("POSTGRES_PORT", "5432")), + "user": os.getenv("POSTGRES_USER", "n8n"), + "password": os.getenv("POSTGRES_PASSWORD", ""), + "db_name": os.getenv("POSTGRES_DB", "n8n"), + "schema_name": os.getenv("MEMOS_SCHEMA", "memos"), + "user_name": user_name, + "use_multi_db": False, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "384")), + "maxconn": int(os.getenv("POSTGRES_MAX_CONN", "20")), + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" @@ -884,13 +908,16 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene if os.getenv("ENABLE_INTERNET", "false").lower() == "true" else None ) + postgres_config = APIConfig.get_postgres_config(user_id=user_id) graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, "polardb": polardb_config, + "postgres": postgres_config, } - graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() if graph_db_backend in graph_db_backend_map: # Create MemCube config @@ -958,18 +985,21 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": neo4j_config = APIConfig.get_neo4j_config(user_id="default") nebular_config = APIConfig.get_nebular_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") + postgres_config = APIConfig.get_postgres_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "nebular": nebular_config, "polardb": polardb_config, + "postgres": postgres_config, } internet_config = ( APIConfig.get_internet_config() if os.getenv("ENABLE_INTERNET", "false").lower() == "true" else None ) - graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")).lower() if graph_db_backend in graph_db_backend_map: return GeneralMemCubeConfig.model_validate( { diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index fce789e2a..2d82cb3ca 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -41,9 +41,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), + "postgres": APIConfig.get_postgres_config(user_id=user_id), } - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 3b4bace0e..7feda1570 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -211,6 +211,58 @@ def validate_config(self): 
return self +class PostgresGraphDBConfig(BaseConfig): + """ + PostgreSQL + pgvector configuration for MemOS. + + Uses standard PostgreSQL with pgvector extension for vector search. + Does NOT require Apache AGE or other graph extensions. + + Schema: + - memos_memories: Main table for memory nodes (id, memory, properties JSONB, embedding vector) + - memos_edges: Edge table for relationships (source_id, target_id, type) + + Example: + --- + host = "postgres" + port = 5432 + user = "n8n" + password = "secret" + db_name = "n8n" + schema_name = "memos" + user_name = "default" + """ + + host: str = Field(..., description="Database host") + port: int = Field(default=5432, description="Database port") + user: str = Field(..., description="Database user") + password: str = Field(..., description="Database password") + db_name: str = Field(..., description="Database name") + schema_name: str = Field(default="memos", description="Schema name for MemOS tables") + user_name: str | None = Field( + default=None, + description="Logical user/tenant ID for data isolation", + ) + use_multi_db: bool = Field( + default=False, + description="If False: use single database with logical isolation by user_name", + ) + embedding_dimension: int = Field(default=384, description="Dimension of vector embedding") + maxconn: int = Field( + default=20, + description="Maximum number of connections in the connection pool", + ) + + @model_validator(mode="after") + def validate_config(self): + """Validate config.""" + if not self.db_name: + raise ValueError("`db_name` must be provided") + if not self.use_multi_db and not self.user_name: + raise ValueError("In single-database mode, `user_name` must be provided") + return self + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -220,6 +272,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j-community": Neo4jCommunityGraphDBConfig, "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, + "postgres": PostgresGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index ec9cbcda0..c207e3190 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -6,6 +6,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB +from memos.graph_dbs.postgres import PostgresGraphDB class GraphStoreFactory(BaseGraphDB): @@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j-community": Neo4jCommunityGraphDB, "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, + "postgres": PostgresGraphDB, } @classmethod diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py new file mode 100644 index 000000000..d3c621059 --- /dev/null +++ b/src/memos/graph_dbs/postgres.py @@ -0,0 +1,769 @@ +""" +PostgreSQL + pgvector backend for MemOS. + +Simple implementation using standard PostgreSQL with pgvector extension. +No Apache AGE or other graph extensions required. 
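+Graph traversal (neighbors, paths, subgraphs) is implemented with plain SQL and
+recursive CTEs over the edges table; similarity search uses pgvector's cosine
+distance operator.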
+ +Tables: +- {schema}.memories: Memory nodes with JSONB properties and vector embeddings +- {schema}.edges: Relationships between memory nodes +""" + +import json +import time +from contextlib import suppress +from datetime import datetime +from typing import Any, Literal + +from memos.configs.graph_db import PostgresGraphDBConfig +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.log import get_logger + +logger = get_logger(__name__) + + +def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """Ensure metadata has proper datetime fields and normalized types.""" + now = datetime.utcnow().isoformat() + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +class PostgresGraphDB(BaseGraphDB): + """PostgreSQL + pgvector implementation of a graph memory store.""" + + @require_python_package( + import_name="psycopg2", + install_command="pip install psycopg2-binary", + install_link="https://pypi.org/project/psycopg2-binary/", + ) + def __init__(self, config: PostgresGraphDBConfig): + """Initialize PostgreSQL connection pool.""" + import psycopg2 + import psycopg2.pool + + self.config = config + self.schema = config.schema_name + self.user_name = config.user_name + self._pool_closed = False + + logger.info(f"Connecting to PostgreSQL: {config.host}:{config.port}/{config.db_name}") + + # Create connection pool + self.pool = psycopg2.pool.ThreadedConnectionPool( + minconn=2, + maxconn=config.maxconn, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + dbname=config.db_name, + connect_timeout=30, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + + # Initialize schema and tables + self._init_schema() + + def _get_conn(self): + """Get connection from pool with health check.""" + if self._pool_closed: + raise RuntimeError("Connection pool is closed") + + for attempt in range(3): + conn = None + try: + conn = self.pool.getconn() + if conn.closed != 0: + self.pool.putconn(conn, close=True) + continue + conn.autocommit = True + # Health check + with conn.cursor() as cur: + cur.execute("SELECT 1") + return conn + except Exception as e: + if conn: + with suppress(Exception): + self.pool.putconn(conn, close=True) + if attempt == 2: + raise RuntimeError(f"Failed to get connection: {e}") from e + time.sleep(0.1) + raise RuntimeError("Failed to get healthy connection") + + def _put_conn(self, conn): + """Return connection to pool.""" + if conn and not self._pool_closed: + try: + self.pool.putconn(conn) + except Exception: + with suppress(Exception): + conn.close() + + def _init_schema(self): + """Create schema and tables if they don't exist.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Create schema + cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") + + # Enable pgvector + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + + # Create memories table + dim = self.config.embedding_dimension + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.memories ( + id TEXT PRIMARY KEY, + memory TEXT NOT NULL DEFAULT '', + properties JSONB NOT NULL DEFAULT '{{}}', + embedding vector({dim}), + user_name TEXT, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + ) + """) + + # Create edges 
table + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.edges ( + id SERIAL PRIMARY KEY, + source_id TEXT NOT NULL, + target_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(source_id, target_id, edge_type) + ) + """) + + # Create indexes + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_user + ON {self.schema}.memories(user_name) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_props + ON {self.schema}.memories USING GIN(properties) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_memories_embedding + ON {self.schema}.memories USING ivfflat(embedding vector_cosine_ops) + WITH (lists = 100) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_source + ON {self.schema}.edges(source_id) + """) + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_edges_target + ON {self.schema}.edges(target_id) + """) + + logger.info(f"Schema {self.schema} initialized successfully") + except Exception as e: + logger.error(f"Failed to init schema: {e}") + raise + finally: + self._put_conn(conn) + + # ========================================================================= + # Node Management + # ========================================================================= + + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: + """Add a memory node.""" + user_name = user_name or self.user_name + metadata = _prepare_node_metadata(metadata.copy()) + + # Extract embedding + embedding = metadata.pop("embedding", None) + created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) + updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) + + # Serialize sources if present + if metadata.get("sources"): + metadata["sources"] = [ + json.dumps(s) if not isinstance(s, str) else s + for s in metadata["sources"] + ] + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, embedding, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s::vector, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = EXCLUDED.memory, + properties = EXCLUDED.properties, + embedding = EXCLUDED.embedding, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), embedding, user_name, created_at, updated_at)) + else: + cur.execute(f""" + INSERT INTO {self.schema}.memories + (id, memory, properties, user_name, created_at, updated_at) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + memory = EXCLUDED.memory, + properties = EXCLUDED.properties, + updated_at = EXCLUDED.updated_at + """, (id, memory, json.dumps(metadata), user_name, created_at, updated_at)) + finally: + self._put_conn(conn) + + def add_nodes_batch( + self, nodes: list[dict[str, Any]], user_name: str | None = None + ) -> None: + """Batch add memory nodes.""" + for node in nodes: + self.add_node( + id=node["id"], + memory=node["memory"], + metadata=node.get("metadata", {}), + user_name=user_name, + ) + + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: + """Update node fields.""" + user_name = user_name or self.user_name + if not fields: + return + + # Get current node + current = self.get_node(id, user_name=user_name) + if not current: + return + + # Merge properties + props = current.get("metadata", {}).copy() + embedding = fields.pop("embedding", None) + memory = fields.pop("memory", 
current.get("memory", "")) + props.update(fields) + props["updated_at"] = datetime.utcnow().isoformat() + + conn = self._get_conn() + try: + with conn.cursor() as cur: + if embedding: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, embedding = %s::vector, updated_at = NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), embedding, id, user_name)) + else: + cur.execute(f""" + UPDATE {self.schema}.memories + SET memory = %s, properties = %s, updated_at = NOW() + WHERE id = %s AND user_name = %s + """, (memory, json.dumps(props), id, user_name)) + finally: + self._put_conn(conn) + + def delete_node(self, id: str, user_name: str | None = None) -> None: + """Delete a node and its edges.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s OR target_id = %s + """, (id, id)) + # Delete node + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + finally: + self._put_conn(conn) + + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: + """Get a single node by ID.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE id = %s AND user_name = %s + """, (id, user_name)) + row = cur.fetchone() + if not row: + return None + return self._parse_row(row, include_embedding) + finally: + self._put_conn(conn) + + def get_nodes( + self, ids: list, include_embedding: bool = False, **kwargs + ) -> list[dict[str, Any]]: + """Get multiple nodes by IDs.""" + if not ids: + return [] + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE id = ANY(%s) AND user_name = %s + """, (ids, user_name)) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]: + """Parse database row to node dict.""" + props = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}") + props["created_at"] = row[3].isoformat() if row[3] else None + props["updated_at"] = row[4].isoformat() if row[4] else None + result = { + "id": row[0], + "memory": row[1] or "", + "metadata": props, + } + if include_embedding and len(row) > 5: + result["metadata"]["embedding"] = row[5] + return result + + # ========================================================================= + # Edge Management + # ========================================================================= + + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Create an edge between nodes.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + INSERT INTO {self.schema}.edges (source_id, target_id, edge_type) + VALUES (%s, %s, %s) + ON CONFLICT (source_id, target_id, edge_type) DO NOTHING + """, (source_id, target_id, type)) + finally: + self._put_conn(conn) + + def delete_edge( 
+ self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Delete an edge.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + """, (source_id, target_id, type)) + finally: + self._put_conn(conn) + + def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: + """Check if edge exists.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT 1 FROM {self.schema}.edges + WHERE source_id = %s AND target_id = %s AND edge_type = %s + LIMIT 1 + """, (source_id, target_id, type)) + return cur.fetchone() is not None + finally: + self._put_conn(conn) + + # ========================================================================= + # Graph Queries + # ========================================================================= + + def get_neighbors( + self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + ) -> list[str]: + """Get neighboring node IDs.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + if direction == "out": + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges + WHERE source_id = %s AND edge_type = %s + """, (id, type)) + elif direction == "in": + cur.execute(f""" + SELECT source_id FROM {self.schema}.edges + WHERE target_id = %s AND edge_type = %s + """, (id, type)) + else: # both + cur.execute(f""" + SELECT target_id FROM {self.schema}.edges WHERE source_id = %s AND edge_type = %s + UNION + SELECT source_id FROM {self.schema}.edges WHERE target_id = %s AND edge_type = %s + """, (id, type, id, type)) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: + """Get path between nodes using recursive CTE.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + WITH RECURSIVE path AS ( + SELECT source_id, target_id, ARRAY[source_id] as nodes, 1 as depth + FROM {self.schema}.edges + WHERE source_id = %s + UNION ALL + SELECT e.source_id, e.target_id, p.nodes || e.source_id, p.depth + 1 + FROM {self.schema}.edges e + JOIN path p ON e.source_id = p.target_id + WHERE p.depth < %s AND NOT e.source_id = ANY(p.nodes) + ) + SELECT nodes || target_id as full_path + FROM path + WHERE target_id = %s + ORDER BY depth + LIMIT 1 + """, (source_id, max_depth, target_id)) + row = cur.fetchone() + return row[0] if row else [] + finally: + self._put_conn(conn) + + def get_subgraph(self, center_id: str, depth: int = 2) -> list[str]: + """Get subgraph around center node.""" + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + WITH RECURSIVE subgraph AS ( + SELECT %s::text as node_id, 0 as level + UNION + SELECT CASE WHEN e.source_id = s.node_id THEN e.target_id ELSE e.source_id END, + s.level + 1 + FROM {self.schema}.edges e + JOIN subgraph s ON (e.source_id = s.node_id OR e.target_id = s.node_id) + WHERE s.level < %s + ) + SELECT DISTINCT node_id FROM subgraph + """, (center_id, depth)) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: + """Get ordered chain following relationship type.""" + return self.get_neighbors(id, type, "out") + + # ========================================================================= + # Search Operations + # 
========================================================================= + + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Search nodes by vector similarity using pgvector.""" + user_name = user_name or self.user_name + + # Build WHERE clause + conditions = ["embedding IS NOT NULL"] + params = [] + + if user_name: + conditions.append("user_name = %s") + params.append(user_name) + + if scope: + conditions.append("properties->>'memory_type' = %s") + params.append(scope) + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + else: + conditions.append("(properties->>'status' = 'activated' OR properties->>'status' IS NULL)") + + if search_filter: + for k, v in search_filter.items(): + conditions.append(f"properties->>'{k}' = %s") + params.append(str(v)) + + where_clause = " AND ".join(conditions) + + # pgvector cosine distance: 1 - (a <=> b) gives similarity score + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT id, 1 - (embedding <=> %s::vector) as score + FROM {self.schema}.memories + WHERE {where_clause} + ORDER BY embedding <=> %s::vector + LIMIT %s + """, (vector, *params, vector, top_k)) + + results = [] + for row in cur.fetchall(): + score = float(row[1]) + if threshold is None or score >= threshold: + results.append({"id": row[0], "score": score}) + return results + finally: + self._put_conn(conn) + + def get_by_metadata( + self, + filters: list[dict[str, Any]], + status: str | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + user_name_flag: bool = True, + ) -> list[str]: + """Get node IDs matching metadata filters.""" + user_name = user_name or self.user_name + + conditions = [] + params = [] + + if user_name_flag and user_name: + conditions.append("user_name = %s") + params.append(user_name) + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + + for f in filters: + field = f["field"] + op = f.get("op", "=") + value = f["value"] + + if op == "=": + conditions.append(f"properties->>'{field}' = %s") + params.append(str(value)) + elif op == "in": + placeholders = ",".join(["%s"] * len(value)) + conditions.append(f"properties->>'{field}' IN ({placeholders})") + params.extend([str(v) for v in value]) + elif op in (">", ">=", "<", "<="): + conditions.append(f"(properties->>'{field}')::numeric {op} %s") + params.append(value) + elif op == "contains": + conditions.append(f"properties->'{field}' @> %s::jsonb") + params.append(json.dumps([value])) + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(f""" + SELECT id FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [row[0] for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + status: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Get all memory items of a specific type.""" + user_name = kwargs.get("user_name") or self.user_name + + conditions = 
["properties->>'memory_type' = %s", "user_name = %s"] + params = [scope, user_name] + + if status: + conditions.append("properties->>'status' = %s") + params.append(status) + + where_clause = " AND ".join(conditions) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE {where_clause} + """, params) + return [self._parse_row(row, include_embedding) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False + ) -> list[dict]: + """Find isolated nodes (no edges).""" + user_name = self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at" + cur.execute(f""" + SELECT {cols} + FROM {self.schema}.memories m + LEFT JOIN {self.schema}.edges e1 ON m.id = e1.source_id + LEFT JOIN {self.schema}.edges e2 ON m.id = e2.target_id + WHERE m.properties->>'memory_type' = %s + AND m.user_name = %s + AND m.properties->>'status' = 'activated' + AND e1.id IS NULL + AND e2.id IS NULL + """, (scope, user_name)) + return [self._parse_row(row, False) for row in cur.fetchall()] + finally: + self._put_conn(conn) + + # ========================================================================= + # Maintenance + # ========================================================================= + + def deduplicate_nodes(self) -> None: + """Not implemented - handled at application level.""" + pass + + def detect_conflicts(self) -> list[tuple[str, str]]: + """Not implemented.""" + return [] + + def merge_nodes(self, id1: str, id2: str) -> str: + """Not implemented.""" + raise NotImplementedError + + def clear(self, user_name: str | None = None) -> None: + """Clear all data for user.""" + user_name = user_name or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get all node IDs for user + cur.execute(f""" + SELECT id FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + ids = [row[0] for row in cur.fetchall()] + + if ids: + # Delete edges + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids, ids)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories WHERE user_name = %s + """, (user_name,)) + logger.info(f"Cleared all data for user {user_name}") + finally: + self._put_conn(conn) + + def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]: + """Export all data.""" + user_name = kwargs.get("user_name") or self.user_name + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Get nodes + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + cur.execute(f""" + SELECT {cols} FROM {self.schema}.memories + WHERE user_name = %s + ORDER BY created_at DESC + """, (user_name,)) + nodes = [self._parse_row(row, include_embedding) for row in cur.fetchall()] + + # Get edges + node_ids = [n["id"] for n in nodes] + if node_ids: + cur.execute(f""" + SELECT source_id, target_id, edge_type + FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (node_ids, node_ids)) + edges = [ + {"source": row[0], "target": row[1], "type": row[2]} + for row in cur.fetchall() + ] + else: + edges = [] + + return { + "nodes": nodes, + "edges": edges, + 
"total_nodes": len(nodes), + "total_edges": len(edges), + } + finally: + self._put_conn(conn) + + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: + """Import graph data.""" + user_name = user_name or self.user_name + + for node in data.get("nodes", []): + self.add_node( + id=node["id"], + memory=node.get("memory", ""), + metadata=node.get("metadata", {}), + user_name=user_name, + ) + + for edge in data.get("edges", []): + self.add_edge( + source_id=edge["source"], + target_id=edge["target"], + type=edge["type"], + ) + + def close(self): + """Close connection pool.""" + if not self._pool_closed: + self._pool_closed = True + self.pool.closeall() diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 903088a4c..b103acf3a 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -61,9 +61,11 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), + "postgres": APIConfig.get_postgres_config(user_id=user_id), } - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars + graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, From a33f297079e9a99cd306c264805b97e407cab3d5 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Sat, 24 Jan 2026 17:10:11 -0800 Subject: [PATCH 3/7] feat: change embedding dimension to 768 (all-mpnet-base-v2) Match krolik schema embedding dimension for compatibility --- src/memos/configs/graph_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 7feda1570..5ce9faad1 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -247,7 +247,7 @@ class PostgresGraphDBConfig(BaseConfig): default=False, description="If False: use single database with logical isolation by user_name", ) - embedding_dimension: int = Field(default=384, description="Dimension of vector embedding") + embedding_dimension: int = Field(default=768, description="Dimension of vector embedding (768 for all-mpnet-base-v2)") maxconn: int = Field( default=20, description="Maximum number of connections in the connection pool", From 1a3514722e67b45b83ebcd7fd7b5453ccac68e57 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Tue, 27 Jan 2026 00:10:55 -0800 Subject: [PATCH 4/7] fix: add missing methods to PostgresGraphDB Add remove_oldest_memory and get_grouped_counts methods required by MemOS memory management functionality. 
--- src/memos/graph_dbs/postgres.py | 115 ++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py index d3c621059..f9065d718 100644 --- a/src/memos/graph_dbs/postgres.py +++ b/src/memos/graph_dbs/postgres.py @@ -181,6 +181,53 @@ def _init_schema(self): # Node Management # ========================================================================= + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all memories of a given type except the latest `keep_latest` entries. + + Args: + memory_type: Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest: Number of latest entries to keep. + user_name: User to filter by. + """ + user_name = user_name or self.user_name + keep_latest = int(keep_latest) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + # Find IDs to delete (older than the keep_latest entries) + cur.execute(f""" + WITH ranked AS ( + SELECT id, ROW_NUMBER() OVER (ORDER BY updated_at DESC) as rn + FROM {self.schema}.memories + WHERE user_name = %s + AND properties->>'memory_type' = %s + ) + SELECT id FROM ranked WHERE rn > %s + """, (user_name, memory_type, keep_latest)) + + ids_to_delete = [row[0] for row in cur.fetchall()] + + if ids_to_delete: + # Delete edges first + cur.execute(f""" + DELETE FROM {self.schema}.edges + WHERE source_id = ANY(%s) OR target_id = ANY(%s) + """, (ids_to_delete, ids_to_delete)) + + # Delete nodes + cur.execute(f""" + DELETE FROM {self.schema}.memories + WHERE id = ANY(%s) + """, (ids_to_delete,)) + + logger.info(f"Removed {len(ids_to_delete)} oldest {memory_type} memories for user {user_name}") + finally: + self._put_conn(conn) + def add_node( self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None ) -> None: @@ -667,6 +714,74 @@ def deduplicate_nodes(self) -> None: """Not implemented - handled at application level.""" pass + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by specified fields. + + Args: + group_fields: Fields to group by, e.g., ["memory_type", "status"] + where_clause: Extra WHERE condition + params: Parameters for WHERE clause + user_name: User to filter by + + Returns: + list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'count': 10}, ...] 
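+
+        Example (illustrative call; the instance name is assumed):
+            >>> db.get_grouped_counts(group_fields=["memory_type", "status"])
+            [{'memory_type': 'WorkingMemory', 'status': 'activated', 'count': 10}]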
+ """ + user_name = user_name or self.user_name + if not group_fields: + raise ValueError("group_fields cannot be empty") + + # Build SELECT and GROUP BY clauses + # Fields come from JSONB properties column + select_fields = ", ".join([ + f"properties->>'{field}' AS {field}" for field in group_fields + ]) + group_by = ", ".join([f"properties->>'{field}'" for field in group_fields]) + + # Build WHERE clause + conditions = [f"user_name = %s"] + query_params = [user_name] + + if where_clause: + # Parse simple where clause format + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause = where_clause[5:].strip() + if where_clause: + conditions.append(where_clause) + if params: + query_params.extend(params.values()) + + where_sql = " AND ".join(conditions) + + query = f""" + SELECT {select_fields}, COUNT(*) AS count + FROM {self.schema}.memories + WHERE {where_sql} + GROUP BY {group_by} + """ + + conn = self._get_conn() + try: + with conn.cursor() as cur: + cur.execute(query, query_params) + results = [] + for row in cur.fetchall(): + result = {} + for i, field in enumerate(group_fields): + result[field] = row[i] + result["count"] = row[len(group_fields)] + results.append(result) + return results + finally: + self._put_conn(conn) + def detect_conflicts(self) -> list[tuple[str, str]]: """Not implemented.""" return [] From e05a01d22d711513fc8762be7adaab75f58beafd Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Wed, 28 Jan 2026 04:47:46 -0800 Subject: [PATCH 5/7] fix(recall): preserve vector similarity ranking in search results The merge/deduplicate logic was converting hit IDs to a set, losing the score-based ordering from vector search. Now keeps highest score per ID and returns results sorted by similarity score (descending). Fixes both _vector_recall and _fulltext_recall methods. 
--- .../tree_text_memory/retrieve/recall.py | 62 ++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 4541b118b..be1841232 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -390,18 +390,41 @@ def search_path_b(): if not all_hits: return [] - # merge and deduplicate - unique_ids = {r["id"] for r in all_hits if r.get("id")} + # merge and deduplicate, keeping highest score per ID + id_to_score = {} + for r in all_hits: + rid = r.get("id") + if rid: + score = r.get("score", 0.0) + if rid not in id_to_score or score > id_to_score[rid]: + id_to_score[rid] = score + + # Sort IDs by score (descending) to preserve ranking + sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True) + node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), + sorted_ids, include_embedding=self.include_embedding, cube_name=cube_name, user_name=user_name, ) or [] ) - return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + # Restore score-based order and inject scores into metadata + id_to_node = {n.get("id"): n for n in node_dicts} + ordered_nodes = [] + for rid in sorted_ids: + if rid in id_to_node: + node = id_to_node[rid] + # Inject similarity score as relativity + if "metadata" not in node: + node["metadata"] = {} + node["metadata"]["relativity"] = id_to_score.get(rid, 0.0) + ordered_nodes.append(node) + + return [TextualMemoryItem.from_dict(n) for n in ordered_nodes] def _bm25_recall( self, @@ -483,15 +506,38 @@ def _fulltext_recall( if not all_hits: return [] - # merge and deduplicate - unique_ids = {r["id"] for r in all_hits if r.get("id")} + # merge and deduplicate, keeping highest score per ID + id_to_score = {} + for r in all_hits: + rid = r.get("id") + if rid: + score = r.get("score", 0.0) + if rid not in id_to_score or score > id_to_score[rid]: + id_to_score[rid] = score + + # Sort IDs by score (descending) to preserve ranking + sorted_ids = sorted(id_to_score.keys(), key=lambda x: id_to_score[x], reverse=True) + node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), + sorted_ids, include_embedding=self.include_embedding, cube_name=cube_name, user_name=user_name, ) or [] ) - return [TextualMemoryItem.from_dict(n) for n in node_dicts] + + # Restore score-based order and inject scores into metadata + id_to_node = {n.get("id"): n for n in node_dicts} + ordered_nodes = [] + for rid in sorted_ids: + if rid in id_to_node: + node = id_to_node[rid] + # Inject similarity score as relativity + if "metadata" not in node: + node["metadata"] = {} + node["metadata"]["relativity"] = id_to_score.get(rid, 0.0) + ordered_nodes.append(node) + + return [TextualMemoryItem.from_dict(n) for n in ordered_nodes] From 4ad5716c2d20e6398ee71b71750a2e444e2514b4 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Wed, 28 Jan 2026 11:46:15 -0800 Subject: [PATCH 6/7] fix(reranker): use recall relativity scores when embeddings unavailable When embeddings aren't available, the reranker was defaulting to 0.5 scores, ignoring the relativity scores set during the recall phase. Now uses item.metadata.relativity from the recall stage when available. 
Co-Authored-By: Claude Opus 4.5 --- .../memories/textual/tree_text_memory/retrieve/reranker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py index 861343e20..b8ab813dc 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/reranker.py @@ -78,7 +78,11 @@ def rerank( embeddings = [item.metadata.embedding for item in items_with_embeddings] if not embeddings: - return [(item, 0.5) for item in graph_results[:top_k]] + # Use relativity from recall stage if available, otherwise default to 0.5 + return [ + (item, getattr(item.metadata, "relativity", None) or 0.5) + for item in graph_results[:top_k] + ] # Step 2: Compute cosine similarities similarity_scores = batch_cosine_similarity(query_embedding, embeddings) From bf2b107e0cbaad0b8f178dc4ebdcca15b3b076f2 Mon Sep 17 00:00:00 2001 From: Anatoly Koptev Date: Wed, 28 Jan 2026 22:58:21 -0800 Subject: [PATCH 7/7] feat: add overlay pattern for Krolik security extensions - Add overlays/krolik/ with auth, rate-limit, admin API - Add Dockerfile.krolik for production builds - Add SYNC_UPSTREAM.md documentation - Keeps customizations separate from base MemOS for easy upstream sync --- SYNC_UPSTREAM.md | 160 +++++++++++ docker/Dockerfile.krolik | 65 +++++ overlays/README.md | 86 ++++++ overlays/krolik/api/middleware/__init__.py | 13 + overlays/krolik/api/middleware/auth.py | 268 +++++++++++++++++++ overlays/krolik/api/middleware/rate_limit.py | 200 ++++++++++++++ overlays/krolik/api/routers/__init__.py | 5 + overlays/krolik/api/routers/admin_router.py | 225 ++++++++++++++++ overlays/krolik/api/server_api_ext.py | 120 +++++++++ overlays/krolik/api/utils/__init__.py | 0 overlays/krolik/api/utils/api_keys.py | 197 ++++++++++++++ 11 files changed, 1339 insertions(+) create mode 100644 SYNC_UPSTREAM.md create mode 100644 docker/Dockerfile.krolik create mode 100644 overlays/README.md create mode 100644 overlays/krolik/api/middleware/__init__.py create mode 100644 overlays/krolik/api/middleware/auth.py create mode 100644 overlays/krolik/api/middleware/rate_limit.py create mode 100644 overlays/krolik/api/routers/__init__.py create mode 100644 overlays/krolik/api/routers/admin_router.py create mode 100644 overlays/krolik/api/server_api_ext.py create mode 100644 overlays/krolik/api/utils/__init__.py create mode 100644 overlays/krolik/api/utils/api_keys.py diff --git a/SYNC_UPSTREAM.md b/SYNC_UPSTREAM.md new file mode 100644 index 000000000..abe5cd886 --- /dev/null +++ b/SYNC_UPSTREAM.md @@ -0,0 +1,160 @@ +# Синхронизация с Upstream MemOS + +## Архитектура + +``` +┌─────────────────────────────────────────────────────────────┐ +│ MemTensor/MemOS (upstream) │ +│ Оригинал │ +└─────────────────────────┬───────────────────────────────────┘ + │ git fetch upstream + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ anatolykoptev/MemOS (fork) │ +│ ┌────────────────────┐ ┌─────────────────────────────┐ │ +│ │ src/memos/ │ │ overlays/krolik/ │ │ +│ │ (base MemOS) │ │ (auth, rate-limit, admin) │ │ +│ │ │ │ │ │ +│ │ ← syncs with │ │ ← НАШИ кастомизации │ │ +│ │ upstream │ │ (никогда не конфликтуют) │ │ +│ └────────────────────┘ └─────────────────────────────┘ │ +└─────────────────────────┬───────────────────────────────────┘ + │ Dockerfile.krolik + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 
krolik-server (production) │ +│ src/memos/ + overlays merged at build │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Регулярная синхронизация (еженедельно) + +```bash +cd ~/CascadeProjects/piternow_project/MemOS + +# 1. Получить изменения upstream +git fetch upstream + +# 2. Посмотреть что нового +git log --oneline upstream/main..main # Наши коммиты +git log --oneline main..upstream/main # Новое в upstream + +# 3. Merge upstream (overlays/ не затрагивается) +git checkout main +git merge upstream/main + +# 4. Если конфликты (редко, только в src/): +# - Разрешить конфликты +# - git add . +# - git commit + +# 5. Push в наш fork +git push origin main +``` + +## Обновление production (krolik-server) + +После синхронизации форка: + +```bash +cd ~/krolik-server + +# Пересобрать с новым MemOS +docker compose build --no-cache memos-api + +# Перезапустить +docker compose up -d memos-api + +# Проверить логи +docker logs -f memos-api +``` + +## Добавление новых фич в overlay + +```bash +# 1. Создать файл в overlays/krolik/ +vim overlays/krolik/api/middleware/new_feature.py + +# 2. Импортировать в server_api_ext.py +vim overlays/krolik/api/server_api_ext.py + +# 3. Commit в наш fork +git add overlays/ +git commit -m "feat(krolik): add new_feature middleware" +git push origin main +``` + +## Важные правила + +### ✅ Делать: +- Все кастомизации в `overlays/krolik/` +- Багфиксы в `src/` которые полезны upstream — создавать PR +- Регулярно синхронизировать с upstream + +### ❌ НЕ делать: +- Модифицировать файлы в `src/memos/` напрямую +- Форкать API в overlay вместо расширения +- Игнорировать обновления upstream > 2 недель + +## Структура overlays + +``` +overlays/ +└── krolik/ + └── api/ + ├── middleware/ + │ ├── __init__.py + │ ├── auth.py # API Key auth (PostgreSQL) + │ └── rate_limit.py # Redis sliding window + ├── routers/ + │ ├── __init__.py + │ └── admin_router.py # /admin/keys CRUD + ├── utils/ + │ ├── __init__.py + │ └── api_keys.py # Key generation + └── server_api_ext.py # Entry point +``` + +## Environment Variables (Krolik) + +```bash +# Authentication +AUTH_ENABLED=true +MASTER_KEY_HASH= +INTERNAL_SERVICE_SECRET= + +# Rate Limiting +RATE_LIMIT_ENABLED=true +RATE_LIMIT=100 +RATE_WINDOW_SEC=60 +REDIS_URL=redis://redis:6379 + +# PostgreSQL (for API keys) +POSTGRES_HOST=postgres +POSTGRES_PORT=5432 +POSTGRES_USER=memos +POSTGRES_PASSWORD= +POSTGRES_DB=memos + +# CORS +CORS_ORIGINS=https://krolik.hully.one,https://memos.hully.one +``` + +## Миграция из текущего krolik-server + +Текущий `krolik-server/services/memos-core/` содержит смешанный код. +После перехода на overlay pattern: + +1. **krolik-server** будет использовать `Dockerfile.krolik` из форка +2. **Локальные изменения** удаляются из krolik-server +3. **Все кастомизации** живут в `MemOS/overlays/krolik/` + +```yaml +# docker-compose.yml (krolik-server) +services: + memos-api: + build: + context: ../MemOS # Используем форк напрямую + dockerfile: docker/Dockerfile.krolik + # ... остальная конфигурация +``` diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik new file mode 100644 index 000000000..c475a6d30 --- /dev/null +++ b/docker/Dockerfile.krolik @@ -0,0 +1,65 @@ +# MemOS with Krolik Security Extensions +# +# This Dockerfile builds MemOS with authentication, rate limiting, and admin API. +# It uses the overlay pattern to keep customizations separate from base code. 
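+#
+# Example build from the repository root (command assumed, not part of the patch):
+#   docker build -f docker/Dockerfile.krolik -t memos-krolik .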
+ +FROM python:3.11-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + build-essential \ + libffi-dev \ + python3-dev \ + curl \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN groupadd -r memos && useradd -r -g memos -u 1000 memos + +WORKDIR /app + +# Use official Hugging Face +ENV HF_ENDPOINT=https://huggingface.co + +# Copy base MemOS source +COPY src/ ./src/ +COPY pyproject.toml ./ + +# Install base dependencies +RUN pip install --upgrade pip && \ + pip install --no-cache-dir poetry && \ + poetry config virtualenvs.create false && \ + poetry install --no-dev --extras "tree-mem mem-scheduler" + +# Install additional dependencies for Krolik +RUN pip install --no-cache-dir \ + sentence-transformers \ + torch \ + transformers \ + psycopg2-binary \ + redis + +# Apply Krolik overlay (AFTER base install to allow easy updates) +COPY overlays/krolik/ ./src/memos/ + +# Create data directory +RUN mkdir -p /data/memos && chown -R memos:memos /data/memos +RUN chown -R memos:memos /app + +# Set Python path +ENV PYTHONPATH=/app/src + +# Switch to non-root user +USER memos + +EXPOSE 8000 + +# Healthcheck +HEALTHCHECK --interval=30s --timeout=10s --retries=3 --start-period=60s \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Use extended entry point with security features +CMD ["gunicorn", "memos.api.server_api_ext:app", "--preload", "-w", "2", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--timeout", "120"] diff --git a/overlays/README.md b/overlays/README.md new file mode 100644 index 000000000..805821018 --- /dev/null +++ b/overlays/README.md @@ -0,0 +1,86 @@ +# MemOS Overlays + +Overlays are deployment-specific customizations that extend the base MemOS without modifying core files. + +## Structure + +``` +overlays/ +└── krolik/ # Deployment name + └── api/ + ├── middleware/ + │ ├── __init__.py + │ ├── auth.py # API Key authentication + │ └── rate_limit.py # Redis rate limiting + ├── routers/ + │ ├── __init__.py + │ └── admin_router.py # API key management + ├── utils/ + │ ├── __init__.py + │ └── api_keys.py # Key generation utilities + └── server_api_ext.py # Extended entry point +``` + +## How It Works + +1. **Base MemOS** provides core functionality (memory operations, embeddings, etc.) +2. **Overlays** add deployment-specific features without modifying base files +3. **Dockerfile** merges overlays on top of base during build + +## Dockerfile Usage + +```dockerfile +# Clone base MemOS +RUN git clone --depth 1 https://github.com/anatolykoptev/MemOS.git /app + +# Install base dependencies +RUN pip install -r /app/requirements.txt + +# Apply overlay (copies files into src/memos/) +RUN cp -r /app/overlays/krolik/* /app/src/memos/ + +# Use extended entry point +CMD ["gunicorn", "memos.api.server_api_ext:app", ...] +``` + +## Syncing with Upstream + +```bash +# 1. Fetch upstream changes +git fetch upstream + +# 2. Merge upstream into main (preserves overlays) +git merge upstream/main + +# 3. Resolve conflicts if any (usually none in overlays/) +git status + +# 4. Push to fork +git push origin main +``` + +## Adding New Overlays + +1. Create directory: `overlays//` +2. Add customizations following the same structure +3. Create `server_api_ext.py` as entry point +4. 
Update Dockerfile to use the new overlay + +## Security Features (krolik overlay) + +### API Key Authentication +- SHA-256 hashed keys stored in PostgreSQL +- Master key for admin operations +- Scoped permissions (read, write, admin) +- Internal service bypass for container-to-container + +### Rate Limiting +- Redis-based sliding window algorithm +- In-memory fallback for development +- Per-key or per-IP limiting +- Configurable via environment variables + +### Admin API +- `/admin/keys` - Create, list, revoke API keys +- `/admin/health` - Auth system status +- Protected by admin scope or master key diff --git a/overlays/krolik/api/middleware/__init__.py b/overlays/krolik/api/middleware/__init__.py new file mode 100644 index 000000000..64cbc5c60 --- /dev/null +++ b/overlays/krolik/api/middleware/__init__.py @@ -0,0 +1,13 @@ +"""Krolik middleware extensions for MemOS.""" + +from .auth import verify_api_key, require_scope, require_admin, require_read, require_write +from .rate_limit import RateLimitMiddleware + +__all__ = [ + "verify_api_key", + "require_scope", + "require_admin", + "require_read", + "require_write", + "RateLimitMiddleware", +] diff --git a/overlays/krolik/api/middleware/auth.py b/overlays/krolik/api/middleware/auth.py new file mode 100644 index 000000000..30349c9c4 --- /dev/null +++ b/overlays/krolik/api/middleware/auth.py @@ -0,0 +1,268 @@ +""" +API Key Authentication Middleware for MemOS. + +Validates API keys and extracts user context for downstream handlers. +Keys are validated against SHA-256 hashes stored in PostgreSQL. +""" + +import hashlib +import os +import time +from typing import Any + +from fastapi import Depends, HTTPException, Request, Security +from fastapi.security import APIKeyHeader + +import memos.log + +logger = memos.log.get_logger(__name__) + +# API key header configuration +API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False) + +# Environment configuration +AUTH_ENABLED = os.getenv("AUTH_ENABLED", "false").lower() == "true" +MASTER_KEY_HASH = os.getenv("MASTER_KEY_HASH") # SHA-256 hash of master key +INTERNAL_SERVICE_IPS = {"127.0.0.1", "::1", "memos-mcp", "moltbot", "clawdbot"} + +# Connection pool for auth queries (lazy init) +_auth_pool = None + + +def _get_auth_pool(): + """Get or create auth database connection pool.""" + global _auth_pool + if _auth_pool is not None: + return _auth_pool + + try: + import psycopg2.pool + + _auth_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=os.getenv("POSTGRES_HOST", "postgres"), + port=int(os.getenv("POSTGRES_PORT", "5432")), + user=os.getenv("POSTGRES_USER", "memos"), + password=os.getenv("POSTGRES_PASSWORD", ""), + dbname=os.getenv("POSTGRES_DB", "memos"), + connect_timeout=10, + ) + logger.info("Auth database pool initialized") + return _auth_pool + except Exception as e: + logger.error(f"Failed to initialize auth pool: {e}") + return None + + +def hash_api_key(key: str) -> str: + """Hash an API key using SHA-256.""" + return hashlib.sha256(key.encode()).hexdigest() + + +def validate_key_format(key: str) -> bool: + """Validate API key format: krlk_<64-hex>.""" + if not key or not key.startswith("krlk_"): + return False + hex_part = key[5:] # Remove 'krlk_' prefix + if len(hex_part) != 64: + return False + try: + int(hex_part, 16) + return True + except ValueError: + return False + + +def get_key_prefix(key: str) -> str: + """Extract prefix for key identification (first 12 chars).""" + return key[:12] if len(key) >= 12 else key + + +async def 
lookup_api_key(key_hash: str) -> dict[str, Any] | None: + """ + Look up API key in database. + + Returns dict with user_name, scopes, etc. or None if not found. + """ + pool = _get_auth_pool() + if not pool: + logger.warning("Auth pool not available, cannot validate key") + return None + + conn = None + try: + conn = pool.getconn() + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, user_name, scopes, expires_at, is_active + FROM api_keys + WHERE key_hash = %s + """, + (key_hash,), + ) + row = cur.fetchone() + + if not row: + return None + + key_id, user_name, scopes, expires_at, is_active = row + + # Check if key is active + if not is_active: + logger.warning(f"Inactive API key used: {key_hash[:16]}...") + return None + + # Check expiration + if expires_at and expires_at < time.time(): + logger.warning(f"Expired API key used: {key_hash[:16]}...") + return None + + # Update last_used_at + cur.execute( + "UPDATE api_keys SET last_used_at = NOW() WHERE id = %s", + (key_id,), + ) + conn.commit() + + return { + "id": str(key_id), + "user_name": user_name, + "scopes": scopes or ["read"], + } + except Exception as e: + logger.error(f"Database error during key lookup: {e}") + return None + finally: + if conn and pool: + pool.putconn(conn) + + +def is_internal_request(request: Request) -> bool: + """Check if request is from internal service.""" + client_host = request.client.host if request.client else None + + # Check internal IPs + if client_host in INTERNAL_SERVICE_IPS: + return True + + # Check internal header (for container-to-container) + internal_header = request.headers.get("X-Internal-Service") + if internal_header == os.getenv("INTERNAL_SERVICE_SECRET"): + return True + + return False + + +async def verify_api_key( + request: Request, + api_key: str | None = Security(API_KEY_HEADER), +) -> dict[str, Any]: + """ + Verify API key and return user context. + + This is the main dependency for protected endpoints. 
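+    Attach it with FastAPI's ``Depends``, e.g.
+    ``auth: dict = Depends(verify_api_key)`` in a route handler, and read
+    ``auth["user_name"]`` and ``auth["scopes"]`` downstream.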
+ + Returns: + dict with user_name, scopes, and is_master_key flag + + Raises: + HTTPException 401 if authentication fails + """ + # Skip auth if disabled + if not AUTH_ENABLED: + return { + "user_name": request.headers.get("X-User-Name", "default"), + "scopes": ["all"], + "is_master_key": False, + "auth_bypassed": True, + } + + # Allow internal services + if is_internal_request(request): + logger.debug(f"Internal request from {request.client.host}") + return { + "user_name": "internal", + "scopes": ["all"], + "is_master_key": False, + "is_internal": True, + } + + # Require API key + if not api_key: + raise HTTPException( + status_code=401, + detail="Missing API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Handle "Bearer" or "Token" prefix + if api_key.lower().startswith("bearer "): + api_key = api_key[7:] + elif api_key.lower().startswith("token "): + api_key = api_key[6:] + + # Check against master key first (has different format: mk_*) + key_hash = hash_api_key(api_key) + if MASTER_KEY_HASH and key_hash == MASTER_KEY_HASH: + logger.info("Master key authentication") + return { + "user_name": "admin", + "scopes": ["all"], + "is_master_key": True, + } + + # Validate format for regular API keys (krlk_*) + if not validate_key_format(api_key): + raise HTTPException( + status_code=401, + detail="Invalid API key format", + ) + + # Look up in database + key_data = await lookup_api_key(key_hash) + if not key_data: + logger.warning(f"Invalid API key attempt: {get_key_prefix(api_key)}...") + raise HTTPException( + status_code=401, + detail="Invalid or expired API key", + ) + + logger.debug(f"Authenticated user: {key_data['user_name']}") + return { + "user_name": key_data["user_name"], + "scopes": key_data["scopes"], + "is_master_key": False, + "api_key_id": key_data["id"], + } + + +def require_scope(required_scope: str): + """ + Dependency factory to require a specific scope. + + Usage: + @router.post("/admin/keys", dependencies=[Depends(require_scope("admin"))]) + """ + async def scope_checker( + auth: dict[str, Any] = Depends(verify_api_key), + ) -> dict[str, Any]: + scopes = auth.get("scopes", []) + + # "all" scope grants everything + if "all" in scopes or required_scope in scopes: + return auth + + raise HTTPException( + status_code=403, + detail=f"Insufficient permissions. Required scope: {required_scope}", + ) + + return scope_checker + + +# Convenience dependencies +require_read = require_scope("read") +require_write = require_scope("write") +require_admin = require_scope("admin") diff --git a/overlays/krolik/api/middleware/rate_limit.py b/overlays/krolik/api/middleware/rate_limit.py new file mode 100644 index 000000000..12ee84ef4 --- /dev/null +++ b/overlays/krolik/api/middleware/rate_limit.py @@ -0,0 +1,200 @@ +""" +Redis-based Rate Limiting Middleware. + +Implements sliding window rate limiting with Redis. +Falls back to in-memory limiting if Redis is unavailable. 
+""" + +import os +import time +from collections import defaultdict +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +import memos.log + +logger = memos.log.get_logger(__name__) + +# Configuration from environment +RATE_LIMIT = int(os.getenv("RATE_LIMIT", "100")) # Requests per window +RATE_WINDOW = int(os.getenv("RATE_WINDOW_SEC", "60")) # Window in seconds +REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379") + +# Redis client (lazy initialization) +_redis_client = None + +# In-memory fallback (per process) +_memory_store: dict[str, list[float]] = defaultdict(list) + + +def _get_redis(): + """Get or create Redis client.""" + global _redis_client + if _redis_client is not None: + return _redis_client + + try: + import redis + + _redis_client = redis.from_url(REDIS_URL, decode_responses=True) + _redis_client.ping() # Test connection + logger.info("Rate limiter connected to Redis") + return _redis_client + except Exception as e: + logger.warning(f"Redis not available for rate limiting: {e}") + return None + + +def _get_client_key(request: Request) -> str: + """ + Generate a unique key for rate limiting. + + Uses API key if available, otherwise falls back to IP. + """ + # Try to get API key from header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("krlk_"): + # Use first 20 chars of key as identifier + return f"ratelimit:key:{auth_header[:20]}" + + # Fall back to IP address + client_ip = request.client.host if request.client else "unknown" + + # Check for forwarded IP (behind proxy) + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + client_ip = forwarded.split(",")[0].strip() + + return f"ratelimit:ip:{client_ip}" + + +def _check_rate_limit_redis(key: str) -> tuple[bool, int, int]: + """ + Check rate limit using Redis sliding window. + + Returns: + (allowed, remaining, reset_time) + """ + redis_client = _get_redis() + if not redis_client: + return _check_rate_limit_memory(key) + + try: + now = time.time() + window_start = now - RATE_WINDOW + + pipe = redis_client.pipeline() + + # Remove old entries + pipe.zremrangebyscore(key, 0, window_start) + + # Count current entries + pipe.zcard(key) + + # Add current request + pipe.zadd(key, {str(now): now}) + + # Set expiry + pipe.expire(key, RATE_WINDOW + 1) + + results = pipe.execute() + current_count = results[1] + + remaining = max(0, RATE_LIMIT - current_count - 1) + reset_time = int(now + RATE_WINDOW) + + if current_count >= RATE_LIMIT: + return False, 0, reset_time + + return True, remaining, reset_time + + except Exception as e: + logger.warning(f"Redis rate limit error: {e}") + return _check_rate_limit_memory(key) + + +def _check_rate_limit_memory(key: str) -> tuple[bool, int, int]: + """ + Fallback in-memory rate limiting. + + Note: This is per-process and not distributed! 
+ """ + now = time.time() + window_start = now - RATE_WINDOW + + # Clean old entries + _memory_store[key] = [t for t in _memory_store[key] if t > window_start] + + current_count = len(_memory_store[key]) + + if current_count >= RATE_LIMIT: + reset_time = int(min(_memory_store[key]) + RATE_WINDOW) if _memory_store[key] else int(now + RATE_WINDOW) + return False, 0, reset_time + + # Add current request + _memory_store[key].append(now) + + remaining = RATE_LIMIT - current_count - 1 + reset_time = int(now + RATE_WINDOW) + + return True, remaining, reset_time + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Rate limiting middleware using sliding window algorithm. + + Adds headers: + - X-RateLimit-Limit: Maximum requests per window + - X-RateLimit-Remaining: Remaining requests + - X-RateLimit-Reset: Unix timestamp when the window resets + + Returns 429 Too Many Requests when limit is exceeded. + """ + + # Paths exempt from rate limiting + EXEMPT_PATHS = {"/health", "/openapi.json", "/docs", "/redoc"} + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # Skip rate limiting for exempt paths + if request.url.path in self.EXEMPT_PATHS: + return await call_next(request) + + # Skip OPTIONS requests (CORS preflight) + if request.method == "OPTIONS": + return await call_next(request) + + # Get rate limit key + key = _get_client_key(request) + + # Check rate limit + allowed, remaining, reset_time = _check_rate_limit_redis(key) + + if not allowed: + logger.warning(f"Rate limit exceeded for {key}") + return JSONResponse( + status_code=429, + content={ + "detail": "Too many requests. Please slow down.", + "retry_after": reset_time - int(time.time()), + }, + headers={ + "X-RateLimit-Limit": str(RATE_LIMIT), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(reset_time), + "Retry-After": str(reset_time - int(time.time())), + }, + ) + + # Process request + response = await call_next(request) + + # Add rate limit headers + response.headers["X-RateLimit-Limit"] = str(RATE_LIMIT) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Reset"] = str(reset_time) + + return response diff --git a/overlays/krolik/api/routers/__init__.py b/overlays/krolik/api/routers/__init__.py new file mode 100644 index 000000000..656114d7a --- /dev/null +++ b/overlays/krolik/api/routers/__init__.py @@ -0,0 +1,5 @@ +"""Krolik router extensions for MemOS.""" + +from .admin_router import router as admin_router + +__all__ = ["admin_router"] diff --git a/overlays/krolik/api/routers/admin_router.py b/overlays/krolik/api/routers/admin_router.py new file mode 100644 index 000000000..939e5101f --- /dev/null +++ b/overlays/krolik/api/routers/admin_router.py @@ -0,0 +1,225 @@ +""" +Admin Router for API Key Management. + +Protected by master key or admin scope. 
+""" + +import os +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +import memos.log +from memos.api.middleware.auth import require_scope, verify_api_key +from memos.api.utils.api_keys import ( + create_api_key_in_db, + generate_master_key, + list_api_keys, + revoke_api_key, +) + +logger = memos.log.get_logger(__name__) + +router = APIRouter(prefix="/admin", tags=["Admin"]) + + +# Request/Response models +class CreateKeyRequest(BaseModel): + user_name: str = Field(..., min_length=1, max_length=255) + scopes: list[str] = Field(default=["read"]) + description: str | None = Field(default=None, max_length=500) + expires_in_days: int | None = Field(default=None, ge=1, le=365) + + +class CreateKeyResponse(BaseModel): + message: str + key: str # Only returned once! + key_prefix: str + user_name: str + scopes: list[str] + + +class KeyListResponse(BaseModel): + message: str + keys: list[dict[str, Any]] + + +class RevokeKeyRequest(BaseModel): + key_id: str + + +class SimpleResponse(BaseModel): + message: str + success: bool = True + + +def _get_db_connection(): + """Get database connection for admin operations.""" + import psycopg2 + + return psycopg2.connect( + host=os.getenv("POSTGRES_HOST", "postgres"), + port=int(os.getenv("POSTGRES_PORT", "5432")), + user=os.getenv("POSTGRES_USER", "memos"), + password=os.getenv("POSTGRES_PASSWORD", ""), + dbname=os.getenv("POSTGRES_DB", "memos"), + ) + + +@router.post( + "/keys", + response_model=CreateKeyResponse, + summary="Create a new API key", + dependencies=[Depends(require_scope("admin"))], +) +def create_key( + request: CreateKeyRequest, + auth: dict = Depends(verify_api_key), +): + """ + Create a new API key for a user. + + Requires admin scope or master key. + + **WARNING**: The API key is only returned once. Store it securely! + """ + try: + conn = _get_db_connection() + try: + api_key = create_api_key_in_db( + conn=conn, + user_name=request.user_name, + scopes=request.scopes, + description=request.description, + expires_in_days=request.expires_in_days, + created_by=auth.get("user_name", "unknown"), + ) + + logger.info( + f"API key created for user '{request.user_name}' by '{auth.get('user_name')}'" + ) + + return CreateKeyResponse( + message="API key created successfully. Store this key securely - it won't be shown again!", + key=api_key.key, + key_prefix=api_key.key_prefix, + user_name=request.user_name, + scopes=request.scopes, + ) + finally: + conn.close() + except Exception as e: + logger.error(f"Failed to create API key: {e}") + raise HTTPException(status_code=500, detail="Failed to create API key") + + +@router.get( + "/keys", + response_model=KeyListResponse, + summary="List API keys", + dependencies=[Depends(require_scope("admin"))], +) +def list_keys( + user_name: str | None = None, + auth: dict = Depends(verify_api_key), +): + """ + List all API keys (admin) or keys for a specific user. + + Note: Actual key values are never returned, only prefixes. 
+ """ + try: + conn = _get_db_connection() + try: + keys = list_api_keys(conn, user_name=user_name) + return KeyListResponse( + message=f"Found {len(keys)} key(s)", + keys=keys, + ) + finally: + conn.close() + except Exception as e: + logger.error(f"Failed to list API keys: {e}") + raise HTTPException(status_code=500, detail="Failed to list API keys") + + +@router.delete( + "/keys/{key_id}", + response_model=SimpleResponse, + summary="Revoke an API key", + dependencies=[Depends(require_scope("admin"))], +) +def revoke_key( + key_id: str, + auth: dict = Depends(verify_api_key), +): + """ + Revoke an API key by ID. + + The key will be deactivated but not deleted (for audit purposes). + """ + try: + conn = _get_db_connection() + try: + success = revoke_api_key(conn, key_id) + if success: + logger.info(f"API key {key_id} revoked by '{auth.get('user_name')}'") + return SimpleResponse(message="API key revoked successfully") + else: + raise HTTPException(status_code=404, detail="API key not found or already revoked") + finally: + conn.close() + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to revoke API key: {e}") + raise HTTPException(status_code=500, detail="Failed to revoke API key") + + +@router.post( + "/generate-master-key", + response_model=dict, + summary="Generate a new master key", + dependencies=[Depends(require_scope("admin"))], +) +def generate_new_master_key( + auth: dict = Depends(verify_api_key), +): + """ + Generate a new master key. + + **WARNING**: Store the key securely! Add MASTER_KEY_HASH to your .env file. + """ + if not auth.get("is_master_key"): + raise HTTPException( + status_code=403, + detail="Only master key can generate new master keys", + ) + + key, key_hash = generate_master_key() + + logger.warning("New master key generated - update MASTER_KEY_HASH in .env") + + return { + "message": "Master key generated. Add MASTER_KEY_HASH to your .env file!", + "key": key, + "key_hash": key_hash, + "env_line": f"MASTER_KEY_HASH={key_hash}", + } + + +@router.get( + "/health", + summary="Admin health check", +) +def admin_health(): + """Health check for admin endpoints.""" + auth_enabled = os.getenv("AUTH_ENABLED", "false").lower() == "true" + master_key_configured = bool(os.getenv("MASTER_KEY_HASH")) + + return { + "status": "ok", + "auth_enabled": auth_enabled, + "master_key_configured": master_key_configured, + } diff --git a/overlays/krolik/api/server_api_ext.py b/overlays/krolik/api/server_api_ext.py new file mode 100644 index 000000000..85b9411af --- /dev/null +++ b/overlays/krolik/api/server_api_ext.py @@ -0,0 +1,120 @@ +""" +Extended Server API for Krolik deployment. + +This module extends the base MemOS server_api with: +- API Key Authentication (PostgreSQL-backed) +- Redis Rate Limiting +- Admin API for key management +- Security Headers + +Usage in Dockerfile: + # Copy overlays after base installation + COPY overlays/krolik/ /app/src/memos/ + + # Use this as entrypoint instead of server_api + CMD ["gunicorn", "memos.api.server_api_ext:app", ...] 
+""" + +import logging +import os + +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +# Import base routers from MemOS +from memos.api.routers.server_router import router as server_router + +# Import Krolik extensions +from memos.api.middleware.rate_limit import RateLimitMiddleware +from memos.api.routers.admin_router import router as admin_router + +# Try to import exception handlers (may vary between MemOS versions) +try: + from memos.api.exceptions import APIExceptionHandler + HAS_EXCEPTION_HANDLER = True +except ImportError: + HAS_EXCEPTION_HANDLER = False + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Add security headers to all responses.""" + + async def dispatch(self, request: Request, call_next) -> Response: + response = await call_next(request) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" + return response + + +# Create FastAPI app +app = FastAPI( + title="MemOS Server REST APIs (Krolik Extended)", + description="MemOS API with authentication, rate limiting, and admin endpoints.", + version="2.0.3-krolik", +) + +# CORS configuration +CORS_ORIGINS = os.getenv("CORS_ORIGINS", "").split(",") +CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS if origin.strip()] + +if not CORS_ORIGINS: + CORS_ORIGINS = [ + "https://krolik.hully.one", + "https://memos.hully.one", + "http://localhost:3000", + ] + +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ORIGINS, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "X-API-Key", "X-User-Name"], +) + +# Security headers +app.add_middleware(SecurityHeadersMiddleware) + +# Rate limiting (before auth to protect against brute force) +RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true" +if RATE_LIMIT_ENABLED: + app.add_middleware(RateLimitMiddleware) + logger.info("Rate limiting enabled") + +# Include routers +app.include_router(server_router) +app.include_router(admin_router) + +# Exception handlers +if HAS_EXCEPTION_HANDLER: + from fastapi import HTTPException + app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) + app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) + app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler) + app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "version": "2.0.3-krolik", + "auth_enabled": os.getenv("AUTH_ENABLED", "false").lower() == "true", + "rate_limit_enabled": RATE_LIMIT_ENABLED, + } + + +if __name__ == "__main__": + import uvicorn + uvicorn.run("memos.api.server_api_ext:app", host="0.0.0.0", port=8000, workers=1) diff --git a/overlays/krolik/api/utils/__init__.py b/overlays/krolik/api/utils/__init__.py new file mode 
100644 index 000000000..e69de29bb diff --git a/overlays/krolik/api/utils/api_keys.py b/overlays/krolik/api/utils/api_keys.py new file mode 100644 index 000000000..559ddd355 --- /dev/null +++ b/overlays/krolik/api/utils/api_keys.py @@ -0,0 +1,197 @@ +""" +API Key Management Utilities. + +Provides functions for generating, validating, and managing API keys. +""" + +import hashlib +import os +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta + + +@dataclass +class APIKey: + """Represents a generated API key.""" + + key: str # Full key (only available at creation time) + key_hash: str # SHA-256 hash (stored in database) + key_prefix: str # First 12 chars for identification + + +def generate_api_key() -> APIKey: + """ + Generate a new API key. + + Format: krlk_<64-hex-chars> + + Returns: + APIKey with key, hash, and prefix + """ + # Generate 32 random bytes = 64 hex chars + random_bytes = secrets.token_bytes(32) + hex_part = random_bytes.hex() + + key = f"krlk_{hex_part}" + key_hash = hashlib.sha256(key.encode()).hexdigest() + key_prefix = key[:12] + + return APIKey(key=key, key_hash=key_hash, key_prefix=key_prefix) + + +def hash_key(key: str) -> str: + """Hash an API key using SHA-256.""" + return hashlib.sha256(key.encode()).hexdigest() + + +def validate_key_format(key: str) -> bool: + """ + Validate API key format. + + Valid format: krlk_<64-hex-chars> + """ + if not key or not isinstance(key, str): + return False + + if not key.startswith("krlk_"): + return False + + hex_part = key[5:] + if len(hex_part) != 64: + return False + + try: + int(hex_part, 16) + return True + except ValueError: + return False + + +def generate_master_key() -> tuple[str, str]: + """ + Generate a master key for admin operations. + + Returns: + Tuple of (key, hash) + """ + random_bytes = secrets.token_bytes(32) + key = f"mk_{random_bytes.hex()}" + key_hash = hashlib.sha256(key.encode()).hexdigest() + return key, key_hash + + +def create_api_key_in_db( + conn, + user_name: str, + scopes: list[str] | None = None, + description: str | None = None, + expires_in_days: int | None = None, + created_by: str | None = None, +) -> APIKey: + """ + Create a new API key and store in database. + + Args: + conn: Database connection + user_name: Owner of the key + scopes: List of scopes (default: ["read"]) + description: Human-readable description + expires_in_days: Days until expiration (None = never) + created_by: Who created this key + + Returns: + APIKey with the generated key (only time it's available!) + """ + api_key = generate_api_key() + + expires_at = None + if expires_in_days: + expires_at = datetime.utcnow() + timedelta(days=expires_in_days) + + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO api_keys (key_hash, key_prefix, user_name, scopes, description, expires_at, created_by) + VALUES (%s, %s, %s, %s, %s, %s, %s) + RETURNING id + """, + ( + api_key.key_hash, + api_key.key_prefix, + user_name, + scopes or ["read"], + description, + expires_at, + created_by, + ), + ) + conn.commit() + + return api_key + + +def revoke_api_key(conn, key_id: str) -> bool: + """ + Revoke an API key by ID. + + Returns: + True if key was revoked, False if not found + """ + with conn.cursor() as cur: + cur.execute( + "UPDATE api_keys SET is_active = false WHERE id = %s AND is_active = true", + (key_id,), + ) + conn.commit() + return cur.rowcount > 0 + + +def list_api_keys(conn, user_name: str | None = None) -> list[dict]: + """ + List API keys (without exposing the actual keys). 
+ + Args: + conn: Database connection + user_name: Filter by user (None = all users) + + Returns: + List of key metadata dicts + """ + with conn.cursor() as cur: + if user_name: + cur.execute( + """ + SELECT id, key_prefix, user_name, scopes, description, + created_at, last_used_at, expires_at, is_active + FROM api_keys + WHERE user_name = %s + ORDER BY created_at DESC + """, + (user_name,), + ) + else: + cur.execute( + """ + SELECT id, key_prefix, user_name, scopes, description, + created_at, last_used_at, expires_at, is_active + FROM api_keys + ORDER BY created_at DESC + """ + ) + + rows = cur.fetchall() + return [ + { + "id": str(row[0]), + "key_prefix": row[1], + "user_name": row[2], + "scopes": row[3], + "description": row[4], + "created_at": row[5].isoformat() if row[5] else None, + "last_used_at": row[6].isoformat() if row[6] else None, + "expires_at": row[7].isoformat() if row[7] else None, + "is_active": row[8], + } + for row in rows + ]
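The middleware and admin utilities above read and write an `api_keys` table, but no migration ships in this series. The sketch below is a bootstrap script for the layout the queries appear to expect; the column types (UUID `id`, `TEXT[]` scopes, `TIMESTAMPTZ` timestamps) and the index are assumptions, not part of the patch.

```python
"""Bootstrap sketch for the api_keys table assumed by the krolik overlay.

Column names are inferred from middleware/auth.py and utils/api_keys.py;
types and defaults are assumptions.
"""

import os

import psycopg2

DDL = """
CREATE TABLE IF NOT EXISTS api_keys (
    id            UUID PRIMARY KEY DEFAULT gen_random_uuid(),  -- assumes PostgreSQL 13+
    key_hash      TEXT NOT NULL UNIQUE,       -- SHA-256 hex of the full key
    key_prefix    TEXT NOT NULL,              -- first 12 chars, identification only
    user_name     TEXT NOT NULL,
    scopes        TEXT[] NOT NULL DEFAULT ARRAY['read'],
    description   TEXT,
    created_by    TEXT,
    created_at    TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    last_used_at  TIMESTAMPTZ,
    expires_at    TIMESTAMPTZ,
    is_active     BOOLEAN NOT NULL DEFAULT TRUE
);
CREATE INDEX IF NOT EXISTS idx_api_keys_user_name ON api_keys (user_name);
"""


def main() -> None:
    # Reuse the same connection settings the auth middleware reads from the environment.
    conn = psycopg2.connect(
        host=os.getenv("POSTGRES_HOST", "postgres"),
        port=int(os.getenv("POSTGRES_PORT", "5432")),
        user=os.getenv("POSTGRES_USER", "memos"),
        password=os.getenv("POSTGRES_PASSWORD", ""),
        dbname=os.getenv("POSTGRES_DB", "memos"),
    )
    try:
        with conn, conn.cursor() as cur:  # connection context manager commits on success
            cur.execute(DDL)
    finally:
        conn.close()


if __name__ == "__main__":
    main()
```

Storing scopes as a PostgreSQL array keeps psycopg2's list adaptation working without JSON encoding, and `TIMESTAMPTZ` keeps the expiry comparison in `middleware/auth.py` timezone-safe.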
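For reference, an end-to-end sketch of the admin flow introduced above: create a scoped key with the master key, then call an endpoint with it. The base URL, user name, and master key value are placeholders, not values from this series.

```python
"""Usage sketch for the krolik admin API (placeholder URL, user, and master key)."""

import requests

BASE = "http://localhost:8000"  # example deployment URL
MASTER_KEY = "mk_..."  # output of generate_master_key(); its hash goes in MASTER_KEY_HASH

# 1. Create a read/write key for a user (requires admin scope or the master key).
resp = requests.post(
    f"{BASE}/admin/keys",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"user_name": "alice", "scopes": ["read", "write"], "expires_in_days": 90},
    timeout=10,
)
resp.raise_for_status()
api_key = resp.json()["key"]  # returned exactly once - store it securely

# 2. Call an endpoint with the new key; the middleware adds rate-limit headers.
r = requests.get(
    f"{BASE}/admin/health",
    headers={"Authorization": f"Bearer {api_key}"},
    timeout=10,
)
print(r.status_code, r.headers.get("X-RateLimit-Remaining"), r.json())
```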