Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
## Description

<!--
Please include a summary of the changes below;
Fill in the issue number that this PR addresses (if applicable);
Fill in the related MemOS-Docs repository issue or PR link (if applicable);
Mention the person who will review this PR (if you know who it is);
Replace (summary), (issue), (docs-issue-or-pr-link), and (reviewer) with the appropriate information.
Please include a summary of the change, the problem it solves, the implementation approach, and relevant context. List any dependencies required for this change.

请在下方填写更改的摘要;
填写此 PR 解决的问题编号(如果适用);
填写相关的 MemOS-Docs 仓库 issue 或 PR 链接(如果适用);
提及将审查此 PR 的人(如果您知道是谁);
替换 (summary)、(issue)、(docs-issue-or-pr-link) 和 (reviewer) 为适当的信息。
-->
Related Issue (Required): Fixes @issue_number

Summary: (summary)
## Type of change

Fix: #(issue)
Please delete options that are not relevant.

Docs Issue/PR: (docs-issue-or-pr-link)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (does not change functionality, e.g. code style improvements, linting)
- [ ] Documentation update

Reviewer: @(reviewer)
## How Has This Been Tested?

## Checklist:
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration

- [ ] Unit Test
- [ ] Test Script Or Test Steps (please provide)
- [ ] Pipeline Automated API Test (please provide)

## Checklist

- [ ] I have performed a self-review of my own code | 我已自行检查了自己的代码
- [ ] I have commented my code in hard-to-understand areas | 我已在难以理解的地方对代码进行了注释
- [ ] I have added tests that prove my fix is effective or that my feature works | 我已添加测试以证明我的修复有效或功能正常
- [ ] I have created related documentation issue/PR in [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) (if applicable) | 我已在 [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) 中创建了相关的文档 issue/PR(如果适用)
- [ ] I have linked the issue to this PR (if applicable) | 我已将 issue 链接到此 PR(如果适用)
- [ ] I have mentioned the person who will review this PR | 我已提及将审查此 PR 的人

## Reviewer Checklist
- [ ] closes #xxxx (Replace xxxx with the GitHub issue number)
- [ ] Made sure Checks passed
- [ ] Tests have been provided
3 changes: 2 additions & 1 deletion docker/requirements-full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
ollama==0.4.9
ollama==0.5.0
onnxruntime==1.22.1
openai==1.97.0
openapi-pydantic==0.5.1
Expand Down Expand Up @@ -184,3 +184,4 @@ py-key-value-aio==0.2.8
py-key-value-shared==0.2.8
PyJWT==2.10.1
pytest==9.0.2
alibabacloud-oss-v2==1.2.2
3 changes: 2 additions & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ mdurl==0.1.2
more-itertools==10.8.0
neo4j==5.28.1
numpy==2.3.4
ollama==0.4.9
ollama==0.5.0
openai==1.109.1
openapi-pydantic==0.5.1
orjson==3.11.4
Expand Down Expand Up @@ -123,3 +123,4 @@ uvicorn==0.38.0
uvloop==0.22.1; sys_platform != 'win32'
watchfiles==1.1.1
websockets==15.0.1
alibabacloud-oss-v2==1.2.2
139 changes: 134 additions & 5 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
]
dependencies = [
"openai (>=1.77.0,<2.0.0)",
"ollama (>=0.4.8,<0.5.0)",
"ollama (>=0.5.0,<0.5.1)",
"transformers (>=4.51.3,<5.0.0)",
"tenacity (>=9.1.2,<10.0.0)", # Error handling and retrying library
"fastapi[all] (>=0.115.12,<0.116.0)", # Web framework for building APIs
Expand Down Expand Up @@ -97,6 +97,11 @@ pref-mem = [
"datasketch (>=1.6.5,<2.0.0)", # MinHash library
]

# SkillMemory
skill-mem = [
"alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",
]

# All optional dependencies
# Allow users to install with `pip install MemoryOS[all]`
all = [
Expand All @@ -123,6 +128,7 @@ all = [
"volcengine-python-sdk (>=4.0.4,<5.0.0)",
"nltk (>=3.9.1,<4.0.0)",
"rake-nltk (>=1.0.6,<1.1.0)",
"alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",

# Uncategorized dependencies
]
Expand Down
34 changes: 34 additions & 0 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,35 @@ def get_reader_config() -> dict[str, Any]:
}

@staticmethod
def get_oss_config() -> dict[str, Any] | None:
"""Get OSS configuration and validate connection."""

config = {
"endpoint": os.getenv("OSS_ENDPOINT", "http://oss-cn-shanghai.aliyuncs.com"),
"access_key_id": os.getenv("OSS_ACCESS_KEY_ID", ""),
"access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET", ""),
"region": os.getenv("OSS_REGION", ""),
"bucket_name": os.getenv("OSS_BUCKET_NAME", ""),
}

# Validate that all required fields have values
required_fields = [
"endpoint",
"access_key_id",
"access_key_secret",
"region",
"bucket_name",
]
missing_fields = [field for field in required_fields if not config.get(field)]

if missing_fields:
logger.warning(
f"OSS configuration incomplete. Missing fields: {', '.join(missing_fields)}"
)
return None

return config

def get_internet_config() -> dict[str, Any]:
"""Get embedder configuration."""
reader_config = APIConfig.get_reader_config()
Expand Down Expand Up @@ -746,6 +775,11 @@ def get_product_default_config() -> dict[str, Any]:
).split(",")
if h.strip()
],
"oss_config": APIConfig.get_oss_config(),
"skills_dir_config": {
"skills_oss_dir": os.getenv("SKILLS_OSS_DIR", "skill_memory/"),
"skills_local_dir": os.getenv("SKILLS_LOCAL_DIR", "/tmp/skill_memory/"),
},
},
},
"enable_textual_memory": True,
Expand Down
3 changes: 3 additions & 0 deletions src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def init_server() -> dict[str, Any]:
)
logger.debug("Searcher created")

# Set searcher to mem_reader
mem_reader.set_searcher(searcher)

# Initialize feedback server
feedback_server = SimpleMemFeedback(
llm=llm,
Expand Down
13 changes: 12 additions & 1 deletion src/memos/api/handlers/formatters_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,17 @@ def post_process_textual_mem(
fact_mem = [
mem
for mem in text_formatted_mem
if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
if mem["metadata"]["memory_type"]
in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"]
]
tool_mem = [
mem
for mem in text_formatted_mem
if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
]
skill_mem = [
mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "SkillMemory"
]

memories_result["text_mem"].append(
{
Expand All @@ -134,6 +138,13 @@ def post_process_textual_mem(
"total_nodes": len(tool_mem),
}
)
memories_result["skill_mem"].append(
{
"cube_id": mem_cube_id,
"memories": skill_mem,
"total_nodes": len(skill_mem),
}
)
return memories_result


Expand Down
43 changes: 42 additions & 1 deletion src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,48 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory
)


def handle_get_memory_by_ids(
memory_ids: list[str], naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
"""
Handler for getting multiple memories by their IDs.

Retrieves multiple memories and formats them as a list of dictionaries.
"""
try:
memories = naive_mem_cube.text_mem.get_by_ids(memory_ids=memory_ids)
except Exception:
memories = []

# Ensure memories is not None
if memories is None:
memories = []

if naive_mem_cube.pref_mem is not None:
collection_names = ["explicit_preference", "implicit_preference"]
for collection_name in collection_names:
try:
result = naive_mem_cube.pref_mem.get_by_ids_with_collection_name(
collection_name, memory_ids
)
if result is not None:
memories.extend(result)
except Exception:
continue

memories = [
format_memory_item(item, save_sources=False) for item in memories if item is not None
]

return GetMemoryResponse(
message="Memories retrieved successfully", code=200, data={"memories": memories}
)


def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": []}
results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []}
memories = naive_mem_cube.text_mem.get_all(
user_name=get_mem_req.mem_cube_id,
user_id=get_mem_req.user_id,
Expand All @@ -226,6 +264,8 @@ def handle_get_memories(

if not get_mem_req.include_tool_memory:
results["tool_mem"] = []
if not get_mem_req.include_skill_memory:
results["skill_mem"] = []

preferences: list[TextualMemoryItem] = []

Expand Down Expand Up @@ -270,6 +310,7 @@ def handle_get_memories(
"text_mem": results.get("text_mem", []),
"pref_mem": results.get("pref_mem", []),
"tool_mem": results.get("tool_mem", []),
"skill_mem": results.get("skill_mem", []),
}

return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results)
Expand Down
17 changes: 15 additions & 2 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,18 @@ class APISearchRequest(BaseRequest):
description="Number of tool memories to retrieve (top-K). Default: 6.",
)

include_skill_memory: bool = Field(
True,
description="Whether to retrieve skill memories along with general memories. "
"If enabled, the system will automatically recall skill memories "
"relevant to the query. Default: True.",
)
skill_mem_top_k: int = Field(
3,
ge=0,
description="Number of skill memories to retrieve (top-K). Default: 3.",
)

# ==== Filter conditions ====
# TODO: maybe add detailed description later
filter: dict[str, Any] | None = Field(
Expand Down Expand Up @@ -393,7 +405,7 @@ class APISearchRequest(BaseRequest):
# Internal field for search memory type
search_memory_type: str = Field(
"All",
description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory",
description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, SkillMemory",
)

# ==== Context ====
Expand Down Expand Up @@ -772,7 +784,8 @@ class GetMemoryRequest(BaseRequest):
mem_cube_id: str = Field(..., description="Cube ID")
user_id: str | None = Field(None, description="User ID")
include_preference: bool = Field(True, description="Whether to return preference memory")
include_tool_memory: bool = Field(False, description="Whether to return tool memory")
include_tool_memory: bool = Field(True, description="Whether to return tool memory")
include_skill_memory: bool = Field(True, description="Whether to return skill memory")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
page: int | None = Field(
None,
Expand Down
8 changes: 8 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def get_memory_by_id(memory_id: str):
)


@router.get("/get_memory_by_ids", summary="Get memory by ids", response_model=GetMemoryResponse)
def get_memory_by_ids(memory_ids: list[str]):
return handlers.memory_handler.handle_get_memory_by_ids(
memory_ids=memory_ids,
naive_mem_cube=naive_mem_cube,
)


@router.post(
"/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse
)
Expand Down
9 changes: 9 additions & 0 deletions src/memos/configs/mem_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
"If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES environment variable.",
)

oss_config: dict[str, Any] | None = Field(
default=None,
description="OSS configuration for the MemReader",
)
skills_dir_config: dict[str, Any] | None = Field(
default=None,
description="Skills directory for the MemReader",
)


class StrategyStructMemReaderConfig(BaseMemReaderConfig):
"""StrategyStruct MemReader configuration class."""
Expand Down
29 changes: 13 additions & 16 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,30 +1130,27 @@ def get_nodes(
- Assumes all provided IDs are valid and exist.
- Returns empty list if input is empty.
"""
logger.info(f"get_nodes ids:{ids},user_name:{user_name}")
if not ids:
return []

# Build WHERE clause using agtype_access_operator like get_node method
where_conditions = []
params = []

for id_val in ids:
where_conditions.append(
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype"
)
params.append(self.format_param_value(id_val))

where_clause = " OR ".join(where_conditions)
# Build WHERE clause using IN operator with agtype array
# Use ANY operator with array for better performance
placeholders = ",".join(["%s"] * len(ids))
params = [self.format_param_value(id_val) for id_val in ids]

query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
WHERE ({where_clause})
WHERE ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = ANY(ARRAY[{placeholders}]::agtype[])
"""

user_name = user_name if user_name else self.config.user_name
query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(self.format_param_value(user_name))
# Only add user_name filter if provided
if user_name is not None:
query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
params.append(self.format_param_value(user_name))

logger.info(f"get_nodes query:{query},params:{params}")

conn = None
try:
Expand Down Expand Up @@ -4313,7 +4310,7 @@ def _build_user_name_and_kb_ids_conditions_sql(
user_name_conditions = []
effective_user_name = user_name if user_name else default_user_name

if effective_user_name and default_user_name != "xxx":
if effective_user_name:
user_name_conditions.append(
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype"
)
Expand Down
Loading