Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/memos/api/handlers/memory_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,44 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory
)


def handle_get_memory_by_ids(
    memory_ids: list[str], naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
    """
    Handler for getting multiple memories by their IDs.

    Looks up textual memories first, then (when a preference store is
    configured) both preference collections. Each backend lookup is
    best-effort: a failure is logged and treated as "no results" so one
    broken store cannot fail the whole request.

    Args:
        memory_ids: IDs of the memories to retrieve.
        naive_mem_cube: Cube exposing ``text_mem`` and optionally ``pref_mem``.

    Returns:
        GetMemoryResponse with formatted memories under ``data["memories"]``.
    """
    import logging

    _log = logging.getLogger(__name__)

    # Best-effort textual lookup; `or []` also normalizes a None return.
    try:
        memories = naive_mem_cube.text_mem.get_by_ids(memory_ids=memory_ids) or []
    except Exception:
        # Was silently swallowed before — keep best-effort, but leave a trace.
        _log.warning("text_mem.get_by_ids failed for ids=%s", memory_ids, exc_info=True)
        memories = []

    if naive_mem_cube.pref_mem is not None:
        for collection_name in ("explicit_preference", "implicit_preference"):
            try:
                result = naive_mem_cube.pref_mem.get_by_ids_with_collection_name(
                    collection_name, memory_ids
                )
            except Exception:
                _log.warning(
                    "pref_mem lookup failed for collection=%s", collection_name, exc_info=True
                )
                continue
            if result:
                memories.extend(result)

    formatted = [
        format_memory_item(item, save_sources=False) for item in memories if item is not None
    ]

    return GetMemoryResponse(
        message="Memories retrieved successfully", code=200, data={"memories": formatted}
    )


def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
Expand Down
8 changes: 8 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def get_memory_by_id(memory_id: str):
)


@router.get("/get_memory_by_ids", summary="Get memory by ids", response_model=GetMemoryResponse)
def get_memory_by_ids(memory_ids: list[str]):
    # Thin routing layer: delegates straight to the memory handler.
    # NOTE(review): FastAPI interprets a bare ``list[str]`` parameter as a
    # JSON request body, which is unusual for a GET endpoint. To accept
    # ``?memory_ids=a&memory_ids=b`` query strings it must be declared as
    # ``memory_ids: list[str] = Query(...)`` — confirm intended client usage.
    return handlers.memory_handler.handle_get_memory_by_ids(
        memory_ids=memory_ids,
        naive_mem_cube=naive_mem_cube,
    )


@router.post(
"/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse
)
Expand Down
54 changes: 47 additions & 7 deletions src/memos/mem_reader/read_skill_memory/process_skill_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@
logger = get_logger(__name__)


def add_id_to_mysql(memory_id: str, mem_cube_id: str):
    """Register a (mem_cube_id, memory_id) pair with the skills MySQL service.

    Temporary bridge while skill ids are mirrored into MySQL.

    Args:
        memory_id: ID of the skill memory to register (sent as ``skillId``).
        mem_cube_id: Cube the memory belongs to (sent as ``memCubeId``).

    Returns:
        The decoded JSON response on success, or ``None`` when the service is
        not configured or the request/decoding fails.
    """
    # TODO: tmp function, deprecate soon
    import requests

    skill_mysql_url = os.getenv("SKILLS_MYSQL_URL", "")
    skill_mysql_bearer = os.getenv("SKILLS_MYSQL_BEARER", "")

    if not skill_mysql_url or not skill_mysql_bearer:
        logger.warning("SKILLS_MYSQL_URL or SKILLS_MYSQL_BEARER is not set")
        return None

    headers = {"Authorization": skill_mysql_bearer, "Content-Type": "application/json"}
    data = {"memCubeId": mem_cube_id, "skillId": memory_id}
    try:
        # timeout: an unresponsive bridge must not block memory ingestion forever;
        # raise_for_status: surface 4xx/5xx instead of JSON-decoding an error page.
        response = requests.post(skill_mysql_url, headers=headers, json=data, timeout=10)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logger.warning(f"Error adding id to mysql: {e}")
        return None


@require_python_package(
import_name="alibabacloud_oss_v2",
install_command="pip install alibabacloud-oss-v2",
Expand Down Expand Up @@ -108,7 +129,14 @@ def _split_task_chunk_by_llm(llm: BaseLLM, messages: MessageList) -> dict[str, M
for item in response_json:
task_name = item["task_name"]
message_indices = item["message_indices"]
for start, end in message_indices:
for indices in message_indices:
# Validate that indices is a list/tuple with exactly 2 elements
if not isinstance(indices, list | tuple) or len(indices) != 2:
logger.warning(
f"Invalid message indices format for task '{task_name}': {indices}, skipping"
)
continue
start, end = indices
task_chunks.setdefault(task_name, []).extend(messages[start : end + 1])
return task_chunks

Expand All @@ -125,7 +153,7 @@ def _extract_skill_memory_by_llm(
"procedure": mem["metadata"]["procedure"],
"experience": mem["metadata"]["experience"],
"preference": mem["metadata"]["preference"],
"example": mem["metadata"]["example"],
"examples": mem["metadata"]["examples"],
"tags": mem["metadata"]["tags"],
"scripts": mem["metadata"].get("scripts"),
"others": mem["metadata"]["others"],
Expand Down Expand Up @@ -153,7 +181,10 @@ def _extract_skill_memory_by_llm(
# Call LLM to extract skill memory with retry logic
for attempt in range(3):
try:
response_text = llm.generate(prompt)
# Only pass model_name_or_path if SKILLS_LLM is set
skills_llm = os.getenv("SKILLS_LLM", None)
llm_kwargs = {"model_name_or_path": skills_llm} if skills_llm else {}
response_text = llm.generate(prompt, **llm_kwargs)
# Clean up response (remove markdown code blocks if present)
response_text = response_text.strip()
response_text = response_text.replace("```json", "").replace("```", "").strip()
Expand Down Expand Up @@ -195,7 +226,7 @@ def _recall_related_skill_memories(
query = _rewrite_query(task_type, messages, llm, rewrite_query)
related_skill_memories = searcher.search(
query,
top_k=10,
top_k=5,
memory_type="SkillMemory",
info=info,
include_skill_memory=True,
Expand Down Expand Up @@ -326,11 +357,11 @@ def _write_skills_to_file(
skill_md_content += f"- {pref}\n"

# Add Examples section only if there are items
examples = skill_memory.get("example", [])
examples = skill_memory.get("examples", [])
if examples:
skill_md_content += "\n## Examples\n"
for idx, example in enumerate(examples, 1):
skill_md_content += f"\n### Example {idx}\n{example}\n"
skill_md_content += f"\n### Example {idx}\n```markdown\n{example}\n```\n"

# Add scripts reference if present
scripts = skill_memory.get("scripts")
Expand Down Expand Up @@ -444,7 +475,7 @@ def create_skill_memory_item(
procedure=skill_memory.get("procedure", ""),
experience=skill_memory.get("experience", []),
preference=skill_memory.get("preference", []),
example=skill_memory.get("example", []),
examples=skill_memory.get("examples", []),
scripts=skill_memory.get("scripts"),
others=skill_memory.get("others"),
url=skill_memory.get("url", ""),
Expand Down Expand Up @@ -501,6 +532,9 @@ def process_skill_memory_fine(
messages = _add_index_to_message(messages)

task_chunks = _split_task_chunk_by_llm(llm, messages)
if not task_chunks:
logger.warning("No task chunks found")
return []

# recall - get related skill memories for each task separately (parallel)
related_skill_memories_by_task = {}
Expand Down Expand Up @@ -647,4 +681,10 @@ def process_skill_memory_fine(
logger.warning(f"Error creating skill memory item: {e}")
continue

    # TODO: deprecate this function and its call here
for skill_memory in skill_memory_items:
add_id_to_mysql(
memory_id=skill_memory.id, mem_cube_id=kwargs.get("user_name", info.get("user_id", ""))
)

return skill_memory_items
4 changes: 3 additions & 1 deletion src/memos/memories/textual/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem
"""

@abstractmethod
def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
def get_by_ids(
self, memory_ids: list[str], user_name: str | None = None
) -> list[TextualMemoryItem]:
"""Get memories by their IDs.
Args:
memory_ids (list[str]): List of memory IDs to retrieve.
Expand Down
3 changes: 2 additions & 1 deletion src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem
def get_by_ids(
    self, memory_ids: list[str], user_name: str | None = None
) -> list[TextualMemoryItem]:
    """Fetch the graph nodes matching *memory_ids*, optionally scoped to *user_name*."""
    return self.graph_store.get_nodes(ids=memory_ids, user_name=user_name)

def get_all(
self,
Expand Down
Loading