12 changes: 12 additions & 0 deletions src/memos/api/config.py
@@ -440,6 +440,18 @@ def get_embedder_config() -> dict[str, Any]:
"model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
"headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")),
"base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"),
"backup_client": os.getenv("MOS_EMBEDDER_BACKUP_CLIENT", "false").lower()
== "true",
"backup_base_url": os.getenv(
"MOS_EMBEDDER_BACKUP_API_BASE", "http://openai.com"
),
"backup_api_key": os.getenv("MOS_EMBEDDER_BACKUP_API_KEY", "sk-xxxx"),
"backup_headers_extra": json.loads(
os.getenv("MOS_EMBEDDER_BACKUP_HEADERS_EXTRA", "{}")
),
"backup_model_name_or_path": os.getenv(
"MOS_EMBEDDER_BACKUP_MODEL", "text-embedding-3-large"
),
},
}
else: # ollama
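For context, a minimal sketch of driving the new backup settings purely through environment variables; the import path follows src/memos/api/config.py, and every value below is a placeholder, not part of this diff:

import os

from memos.api.config import get_embedder_config  # import path assumed from the file path above

# Placeholder values for illustration only.
os.environ["MOS_EMBEDDER_BACKUP_CLIENT"] = "true"
os.environ["MOS_EMBEDDER_BACKUP_API_BASE"] = "https://backup.example.com/v1"
os.environ["MOS_EMBEDDER_BACKUP_API_KEY"] = "sk-backup-placeholder"
os.environ["MOS_EMBEDDER_BACKUP_MODEL"] = "text-embedding-3-large"

embedder_config = get_embedder_config()
# The returned dict now carries backup_client, backup_base_url, backup_api_key,
# backup_headers_extra and backup_model_name_or_path alongside the primary settings.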
17 changes: 17 additions & 0 deletions src/memos/configs/embedder.py
@@ -58,6 +58,23 @@ class UniversalAPIEmbedderConfig(BaseEmbedderConfig):
base_url: str | None = Field(
default=None, description="Optional base URL for custom or proxied endpoint"
)
backup_client: bool = Field(
default=False,
description="Whether to use backup client",
)
backup_base_url: str | None = Field(
default=None, description="Optional backup base URL for custom or proxied endpoint"
)
backup_api_key: str | None = Field(
default=None, description="Optional backup API key for the embedding provider"
)
backup_headers_extra: dict[str, Any] | None = Field(
default=None,
description="Extra headers for the backup embedding model",
)
backup_model_name_or_path: str | None = Field(
default=None, description="Optional backup model name or path"
)


class EmbedderConfigFactory(BaseConfig):
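A minimal sketch of constructing this config with the new backup fields set directly; the import path follows the file path above, and provider/api_key are assumed to be existing fields of the class:

from memos.configs.embedder import UniversalAPIEmbedderConfig  # assumed import path

cfg = UniversalAPIEmbedderConfig(
    provider="openai",                                # existing field, assumed
    api_key="sk-primary-placeholder",                 # existing field, assumed
    base_url="https://api.openai.com/v1",
    backup_client=True,                               # opt in to the fallback client added here
    backup_base_url="https://backup.example.com/v1",
    backup_api_key="sk-backup-placeholder",
    backup_model_name_or_path="text-embedding-3-large",
)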
73 changes: 67 additions & 6 deletions src/memos/embedders/universal_api.py
@@ -1,3 +1,7 @@
import asyncio
import os
import time

from openai import AzureOpenAI as AzureClient
from openai import OpenAI as OpenAIClient

@@ -29,23 +33,80 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
)
else:
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
self.use_backup_client = config.backup_client
if self.use_backup_client:
self.backup_client = OpenAIClient(
api_key=config.backup_api_key,
base_url=config.backup_base_url,
default_headers=config.backup_headers_extra
if config.backup_headers_extra
else None,
)

@timed_with_status(
log_prefix="model_timed_embedding",
log_extra_args={"model_name_or_path": "text-embedding-3-large"},
log_extra_args=lambda self, texts: {
"model_name_or_path": "text-embedding-3-large",
"text_len": len(texts),
"text_content": texts,
},
)
def embed(self, texts: list[str]) -> list[list[float]]:
if isinstance(texts, str):
texts = [texts]
# Truncate texts if max_tokens is configured
texts = self._truncate_texts(texts)

logger.info(f"Embeddings request with input: {texts}")
if self.provider == "openai" or self.provider == "azure":
try:
response = self.client.embeddings.create(
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
input=texts,

async def _create_embeddings():
return self.client.embeddings.create(
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
input=texts,
)

init_time = time.time()
response = asyncio.run(
asyncio.wait_for(
_create_embeddings(), timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5))
)
)
logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds")
logger.info(f"Embeddings request response: {response}")
return [r.embedding for r in response.data]
except Exception as e:
raise Exception(f"Embeddings request ended with error: {e}") from e
logger.warning(
f"Embeddings request ended with {type(e).__name__} error: {e}, try backup client"
)
if self.use_backup_client:
try:

async def _create_embeddings_backup():
return self.backup_client.embeddings.create(
model=getattr(
self.config,
"backup_model_name_or_path",
"text-embedding-3-large",
),
input=texts,
)

init_time = time.time()
response = asyncio.run(
asyncio.wait_for(
_create_embeddings_backup(),
timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5)),
)
)
logger.info(
f"Backup embeddings request succeeded with {time.time() - init_time} seconds"
)
logger.info(f"Backup embeddings request response: {response}")
return [r.embedding for r in response.data]
except Exception as e:
raise ValueError(f"Backup embeddings request ended with error: {e}") from e
else:
raise ValueError(f"Embeddings request ended with error: {e}") from e
else:
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
84 changes: 67 additions & 17 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
@@ -79,12 +79,11 @@ def _split_large_memory_item(
chunks = self.chunker.chunk(item_text)
split_items = []

for chunk in chunks:
def _create_chunk_item(chunk):
# Chunk objects have a 'text' attribute
chunk_text = chunk.text
if not chunk_text or not chunk_text.strip():
continue

return None
# Create a new memory item for each chunk, preserving original metadata
split_item = self._make_memory_item(
value=chunk_text,
Expand All @@ -98,8 +97,17 @@ def _split_large_memory_item(
key=item.metadata.key,
sources=item.metadata.sources or [],
background=item.metadata.background or "",
need_embed=False,
)
split_items.append(split_item)
return split_item

# Use a thread pool to process the chunks in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
for future in concurrent.futures.as_completed(futures):
split_item = future.result()
if split_item is not None:
split_items.append(split_item)

return split_items if split_items else [item]
except Exception as e:
@@ -127,15 +135,41 @@ def _concat_multi_modal_memories(

# Split large memory items before processing
processed_items = []
for item in all_memory_items:
item_text = item.memory or ""
item_tokens = self._count_tokens(item_text)
if item_tokens > max_tokens:
# Split the large item into multiple chunks
split_items = self._split_large_memory_item(item, max_tokens)
processed_items.extend(split_items)
else:
processed_items.append(item)
# Controls whether large memory items are chunked in parallel
parallel_chunking = True

if parallel_chunking:
# Chunk large memory items in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_item = {
executor.submit(self._split_large_memory_item, item, max_tokens): item
for item in all_memory_items
if (item.memory or "") and self._count_tokens(item.memory) > max_tokens
}
processed_items.extend(
[
item
for item in all_memory_items
if not (
(item.memory or "") and self._count_tokens(item.memory) > max_tokens
)
]
)
# collect split items from futures
for future in concurrent.futures.as_completed(future_to_item):
split_items = future.result()
processed_items.extend(split_items)
else:
# Chunk large memory items serially
for item in all_memory_items:
item_text = item.memory or ""
item_tokens = self._count_tokens(item_text)
if item_tokens > max_tokens:
# Split the large item into multiple chunks
split_items = self._split_large_memory_item(item, max_tokens)
processed_items.extend(split_items)
else:
processed_items.append(item)

# If only one item after processing, return as-is
if len(processed_items) == 1:
@@ -797,13 +831,29 @@ def _process_multi_modal_data(
if isinstance(scene_data_info, list):
# Parse each message in the list
all_memory_items = []
for msg in scene_data_info:
items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
all_memory_items.extend(items)
# Use a thread pool to parse each message in parallel
with ContextThreadPoolExecutor(max_workers=30) as executor:
futures = [
executor.submit(
self.multi_modal_parser.parse,
msg,
info,
mode="fast",
need_emb=False,
**kwargs,
)
for msg in scene_data_info
]
for future in concurrent.futures.as_completed(futures):
try:
items = future.result()
all_memory_items.extend(items)
except Exception as e:
logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
else:
# Parse as single message
all_memory_items = self.multi_modal_parser.parse(
scene_data_info, info, mode="fast", **kwargs
scene_data_info, info, mode="fast", need_emb=False, **kwargs
)
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
if mode == "fast":
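The hunks above all apply the same fan-out pattern: submit one task per chunk or message to a thread pool, then gather results with as_completed. A self-contained sketch of that pattern, with a placeholder worker standing in for the chunking/parsing calls:

import concurrent.futures

def process_item(item: str) -> str:
    # Placeholder for per-item work such as chunking text or parsing a message.
    return item.upper()

items = ["first message", "second message", "third message"]
results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(process_item, item) for item in items]
    for future in concurrent.futures.as_completed(futures):
        try:
            results.append(future.result())
        except Exception as exc:
            # Mirrors the error handling in the parallel-parsing hunk above.
            print(f"parallel task failed: {exc}")

Note that as_completed yields futures in completion order, not submission order, so results collected this way are unordered relative to the input.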
3 changes: 2 additions & 1 deletion src/memos/mem_reader/read_multi_modal/assistant_parser.py
@@ -210,6 +210,7 @@ def parse_fast(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
need_emb = kwargs.get("need_emb", True)
if not isinstance(message, dict):
logger.warning(f"[AssistantParser] Expected dict, got {type(message)}")
return []
@@ -290,7 +291,7 @@ def parse_fast(
status="activated",
tags=["mode:fast"],
key=_derive_key(line),
embedding=self.embedder.embed([line])[0],
embedding=self.embedder.embed([line])[0] if need_emb else None,
usage=[],
sources=sources,
background="",
src/memos/mem_reader/read_multi_modal/multi_modal_parser.py
@@ -149,6 +149,7 @@ def parse(
logger.warning(f"[MultiModalParser] No parser found for message: {message}")
return []

logger.info(f"[{parser.__class__.__name__}] Parsing message in {mode} mode: {message}")
# Parse using the appropriate parser
try:
return parser.parse(message, info, mode=mode, **kwargs)
3 changes: 2 additions & 1 deletion src/memos/mem_reader/read_multi_modal/user_parser.py
@@ -151,6 +151,7 @@ def parse_fast(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
need_emb = kwargs.get("need_emb", True)
if not isinstance(message, dict):
logger.warning(f"[UserParser] Expected dict, got {type(message)}")
return []
@@ -192,7 +193,7 @@ def parse_fast(
status="activated",
tags=["mode:fast"],
key=_derive_key(line),
embedding=self.embedder.embed([line])[0],
embedding=self.embedder.embed([line])[0] if need_emb else None,
usage=[],
sources=sources,
background="",
3 changes: 2 additions & 1 deletion src/memos/mem_reader/simple_struct.py
@@ -198,6 +198,7 @@ def _make_memory_item(
background: str = "",
type_: str = "fact",
confidence: float = 0.99,
need_embed: bool = True,
**kwargs,
) -> TextualMemoryItem:
"""construct memory item"""
@@ -213,7 +214,7 @@ def _make_memory_item(
status="activated",
tags=tags or [],
key=key if key is not None else derive_key(value),
embedding=self.embedder.embed([value])[0],
embedding=self.embedder.embed([value])[0] if need_embed else None,
usage=[],
sources=sources or [],
background=background,
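With the new need_emb/need_embed flags, fast-mode parsing above leaves each item's embedding unset (None) so that vectors can be computed later in one batch instead of one request per line. A hedged sketch of what such a deferred batch step could look like; the helper and the metadata.embedding attribute path are assumptions, not code from this PR:

def fill_missing_embeddings(items, embedder):
    # Collect items whose embedding was skipped during fast parsing.
    pending = [item for item in items if item.metadata.embedding is None]
    if not pending:
        return items
    # One batched embed() call instead of a call per chunk/line.
    vectors = embedder.embed([item.memory for item in pending])
    for item, vector in zip(pending, vectors):
        item.metadata.embedding = vector
    return items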
6 changes: 5 additions & 1 deletion src/memos/multi_mem_cube/single_cube.py
@@ -2,6 +2,7 @@

import json
import os
import time
import traceback

from dataclasses import dataclass
@@ -790,7 +791,7 @@ def _process_text_mem(
extract_mode,
add_req.mode,
)

init_time = time.time()
# Extract memories
memories_local = self.mem_reader.get_memory(
[add_req.messages],
Expand All @@ -804,6 +805,9 @@ def _process_text_mem(
mode=extract_mode,
user_name=user_context.mem_cube_id,
)
self.logger.info(
f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}"
)
flattened_local = [mm for m in memories_local for mm in m]

# Explicitly set source_doc_id to metadata if present in info