diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index a3bf25be0..024de4af5 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -440,6 +440,18 @@ def get_embedder_config() -> dict[str, Any]:
                 "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
                 "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")),
                 "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"),
+                "backup_client": os.getenv("MOS_EMBEDDER_BACKUP_CLIENT", "false").lower()
+                == "true",
+                "backup_base_url": os.getenv(
+                    "MOS_EMBEDDER_BACKUP_API_BASE", "http://openai.com"
+                ),
+                "backup_api_key": os.getenv("MOS_EMBEDDER_BACKUP_API_KEY", "sk-xxxx"),
+                "backup_headers_extra": json.loads(
+                    os.getenv("MOS_EMBEDDER_BACKUP_HEADERS_EXTRA", "{}")
+                ),
+                "backup_model_name_or_path": os.getenv(
+                    "MOS_EMBEDDER_BACKUP_MODEL", "text-embedding-3-large"
+                ),
             },
         }
     else:  # ollama
diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py
index c2e648247..050043ab0 100644
--- a/src/memos/configs/embedder.py
+++ b/src/memos/configs/embedder.py
@@ -58,6 +58,23 @@ class UniversalAPIEmbedderConfig(BaseEmbedderConfig):
     base_url: str | None = Field(
         default=None, description="Optional base URL for custom or proxied endpoint"
     )
+    backup_client: bool = Field(
+        default=False,
+        description="Whether to use backup client",
+    )
+    backup_base_url: str | None = Field(
+        default=None, description="Optional backup base URL for custom or proxied endpoint"
+    )
+    backup_api_key: str | None = Field(
+        default=None, description="Optional backup API key for the embedding provider"
+    )
+    backup_headers_extra: dict[str, Any] | None = Field(
+        default=None,
+        description="Extra headers for the backup embedding model",
+    )
+    backup_model_name_or_path: str | None = Field(
+        default=None, description="Optional backup model name or path"
+    )
 
 
 class EmbedderConfigFactory(BaseConfig):
diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index 60bae15a5..d2bdf9318 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -1,3 +1,7 @@
+import asyncio
+import os
+import time
+
 from openai import AzureOpenAI as AzureClient
 from openai import OpenAI as OpenAIClient
 
@@ -29,23 +33,80 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
             )
         else:
             raise ValueError(f"Embeddings unsupported provider: {self.provider}")
+        self.use_backup_client = config.backup_client
+        if self.use_backup_client:
+            self.backup_client = OpenAIClient(
+                api_key=config.backup_api_key,
+                base_url=config.backup_base_url,
+                default_headers=config.backup_headers_extra
+                if config.backup_headers_extra
+                else None,
+            )
 
     @timed_with_status(
         log_prefix="model_timed_embedding",
-        log_extra_args={"model_name_or_path": "text-embedding-3-large"},
+        log_extra_args=lambda self, texts: {
+            "model_name_or_path": "text-embedding-3-large",
+            "text_len": len(texts),
+            "text_content": texts,
+        },
     )
     def embed(self, texts: list[str]) -> list[list[float]]:
+        if isinstance(texts, str):
+            texts = [texts]
         # Truncate texts if max_tokens is configured
         texts = self._truncate_texts(texts)
-
+        logger.info(f"Embeddings request with input: {texts}")
         if self.provider == "openai" or self.provider == "azure":
             try:
-                response = self.client.embeddings.create(
-                    model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
-                    input=texts,
+
+                async def _create_embeddings():
+                    return self.client.embeddings.create(
+                        model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
+                        input=texts,
+                    )
+
+                init_time = time.time()
+                response = asyncio.run(
+                    asyncio.wait_for(
+                        _create_embeddings(), timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5))
+                    )
                 )
+                logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds")
+                logger.info(f"Embeddings request response: {response}")
                 return [r.embedding for r in response.data]
             except Exception as e:
-                raise Exception(f"Embeddings request ended with error: {e}") from e
+                logger.warning(
+                    f"Embeddings request ended with {type(e).__name__} error: {e}, try backup client"
+                )
+                if self.use_backup_client:
+                    try:
+
+                        async def _create_embeddings_backup():
+                            return self.backup_client.embeddings.create(
+                                model=getattr(
+                                    self.config,
+                                    "backup_model_name_or_path",
+                                    "text-embedding-3-large",
+                                ),
+                                input=texts,
+                            )
+
+                        init_time = time.time()
+                        response = asyncio.run(
+                            asyncio.wait_for(
+                                _create_embeddings_backup(),
+                                timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5)),
+                            )
+                        )
+                        logger.info(
+                            f"Backup embeddings request succeeded with {time.time() - init_time} seconds"
+                        )
+                        logger.info(f"Backup embeddings request response: {response}")
+                        return [r.embedding for r in response.data]
+                    except Exception as e:
+                        raise ValueError(f"Backup embeddings request ended with error: {e}") from e
+                else:
+                    raise ValueError(f"Embeddings request ended with error: {e}") from e
         else:
             raise ValueError(f"Embeddings unsupported provider: {self.provider}")
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 9edcd0a55..5579e3a9e 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -79,12 +79,11 @@ def _split_large_memory_item(
             chunks = self.chunker.chunk(item_text)
 
             split_items = []
-            for chunk in chunks:
+            def _create_chunk_item(chunk):
                 # Chunk objects have a 'text' attribute
                 chunk_text = chunk.text
                 if not chunk_text or not chunk_text.strip():
-                    continue
-
+                    return None
                 # Create a new memory item for each chunk, preserving original metadata
                 split_item = self._make_memory_item(
                     value=chunk_text,
@@ -98,8 +97,17 @@ def _split_large_memory_item(
                     key=item.metadata.key,
                     sources=item.metadata.sources or [],
                     background=item.metadata.background or "",
+                    need_embed=False,
                 )
-                split_items.append(split_item)
+                return split_item
+
+            # Use thread pool to parallel process chunks
+            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+                futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
+                for future in concurrent.futures.as_completed(futures):
+                    split_item = future.result()
+                    if split_item is not None:
+                        split_items.append(split_item)
 
             return split_items if split_items else [item]
         except Exception as e:
@@ -127,15 +135,41 @@ def _concat_multi_modal_memories(
 
         # Split large memory items before processing
        processed_items = []
-        for item in all_memory_items:
-            item_text = item.memory or ""
-            item_tokens = self._count_tokens(item_text)
-            if item_tokens > max_tokens:
-                # Split the large item into multiple chunks
-                split_items = self._split_large_memory_item(item, max_tokens)
-                processed_items.extend(split_items)
-            else:
-                processed_items.append(item)
+        # control whether to parallel chunk large memory items
+        parallel_chunking = True
+
+        if parallel_chunking:
+            # parallel chunk large memory items
+            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+                future_to_item = {
+                    executor.submit(self._split_large_memory_item, item, max_tokens): item
+                    for item in all_memory_items
+                    if (item.memory or "") and self._count_tokens(item.memory) > max_tokens
+                }
+                processed_items.extend(
+                    [
+                        item
+                        for item in all_memory_items
+                        if not (
+                            (item.memory or "") and self._count_tokens(item.memory) > max_tokens
+                        )
+                    ]
+                )
+                # collect split items from futures
+                for future in concurrent.futures.as_completed(future_to_item):
+                    split_items = future.result()
+                    processed_items.extend(split_items)
+        else:
+            # serial chunk large memory items
+            for item in all_memory_items:
+                item_text = item.memory or ""
+                item_tokens = self._count_tokens(item_text)
+                if item_tokens > max_tokens:
+                    # Split the large item into multiple chunks
+                    split_items = self._split_large_memory_item(item, max_tokens)
+                    processed_items.extend(split_items)
+                else:
+                    processed_items.append(item)
 
         # If only one item after processing, return as-is
         if len(processed_items) == 1:
@@ -797,13 +831,29 @@ def _process_multi_modal_data(
         if isinstance(scene_data_info, list):
             # Parse each message in the list
             all_memory_items = []
-            for msg in scene_data_info:
-                items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
-                all_memory_items.extend(items)
+            # Use thread pool to parse each message in parallel
+            with ContextThreadPoolExecutor(max_workers=30) as executor:
+                futures = [
+                    executor.submit(
+                        self.multi_modal_parser.parse,
+                        msg,
+                        info,
+                        mode="fast",
+                        need_emb=False,
+                        **kwargs,
+                    )
+                    for msg in scene_data_info
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    try:
+                        items = future.result()
+                        all_memory_items.extend(items)
+                    except Exception as e:
+                        logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
         else:
             # Parse as single message
             all_memory_items = self.multi_modal_parser.parse(
-                scene_data_info, info, mode="fast", **kwargs
+                scene_data_info, info, mode="fast", need_emb=False, **kwargs
             )
         fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
         if mode == "fast":
diff --git a/src/memos/mem_reader/read_multi_modal/assistant_parser.py b/src/memos/mem_reader/read_multi_modal/assistant_parser.py
index 3519216d2..89d4fec7f 100644
--- a/src/memos/mem_reader/read_multi_modal/assistant_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/assistant_parser.py
@@ -210,6 +210,7 @@ def parse_fast(
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
+        need_emb = kwargs.get("need_emb", True)
         if not isinstance(message, dict):
             logger.warning(f"[AssistantParser] Expected dict, got {type(message)}")
             return []
@@ -290,7 +291,7 @@ def parse_fast(
                 status="activated",
                 tags=["mode:fast"],
                 key=_derive_key(line),
-                embedding=self.embedder.embed([line])[0],
+                embedding=self.embedder.embed([line])[0] if need_emb else None,
                 usage=[],
                 sources=sources,
                 background="",
diff --git a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py
index 2c8140419..808410e65 100644
--- a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py
@@ -149,6 +149,7 @@ def parse(
             logger.warning(f"[MultiModalParser] No parser found for message: {message}")
             return []
 
+        logger.info(f"[{parser.__class__.__name__}] Parsing message in {mode} mode: {message}")
         # Parse using the appropriate parser
         try:
             return parser.parse(message, info, mode=mode, **kwargs)
diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py
index 1c9afab65..1ab48c82e 100644
--- a/src/memos/mem_reader/read_multi_modal/user_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/user_parser.py
@@ -151,6 +151,7 @@ def parse_fast(
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
+        need_emb = kwargs.get("need_emb", True)
         if not isinstance(message, dict):
             logger.warning(f"[UserParser] Expected dict, got {type(message)}")
             return []
@@ -192,7 +193,7 @@ def parse_fast(
                 status="activated",
                 tags=["mode:fast"],
                 key=_derive_key(line),
-                embedding=self.embedder.embed([line])[0],
+                embedding=self.embedder.embed([line])[0] if need_emb else None,
                 usage=[],
                 sources=sources,
                 background="",
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 783da763e..6f8bff4ad 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -198,6 +198,7 @@ def _make_memory_item(
         background: str = "",
         type_: str = "fact",
         confidence: float = 0.99,
+        need_embed: bool = True,
         **kwargs,
     ) -> TextualMemoryItem:
         """construct memory item"""
@@ -213,7 +214,7 @@ def _make_memory_item(
             status="activated",
             tags=tags or [],
             key=key if key is not None else derive_key(value),
-            embedding=self.embedder.embed([value])[0],
+            embedding=self.embedder.embed([value])[0] if need_embed else None,
             usage=[],
             sources=sources or [],
             background=background,
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 426cf32be..f12ff1a1d 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -2,6 +2,7 @@
 
 import json
 import os
+import time
 import traceback
 
 from dataclasses import dataclass
@@ -790,7 +791,7 @@ def _process_text_mem(
             extract_mode,
             add_req.mode,
         )
-
+        init_time = time.time()
         # Extract memories
         memories_local = self.mem_reader.get_memory(
             [add_req.messages],
@@ -804,6 +805,9 @@ def _process_text_mem(
             mode=extract_mode,
             user_name=user_context.mem_cube_id,
         )
+        self.logger.info(
+            f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}"
+        )
        flattened_local = [mm for m in memories_local for mm in m]
 
         # Explicitly set source_doc_id to metadata if present in info
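Note: the backup embedding client added above is driven entirely by environment variables. Below is a minimal sketch of how it might be enabled; the MOS_EMBEDDER_BACKUP_* and MOS_EMBEDDER_TIMEOUT variable names come from this patch, while the endpoint URL, key, model, and header values are placeholder assumptions.

import os

# Enable the backup client and point it at a secondary endpoint (placeholder values).
os.environ["MOS_EMBEDDER_BACKUP_CLIENT"] = "true"
os.environ["MOS_EMBEDDER_BACKUP_API_BASE"] = "https://backup.example.com/v1"
os.environ["MOS_EMBEDDER_BACKUP_API_KEY"] = "sk-backup-xxxx"
os.environ["MOS_EMBEDDER_BACKUP_MODEL"] = "text-embedding-3-large"
os.environ["MOS_EMBEDDER_BACKUP_HEADERS_EXTRA"] = '{"X-Routing": "backup"}'
# Per-request timeout in seconds, read by the patched embed() via os.getenv.
os.environ["MOS_EMBEDDER_TIMEOUT"] = "5"

When MOS_EMBEDDER_BACKUP_CLIENT is left at its default of "false", embed() keeps surfacing the primary-client failure, now re-raised as a ValueError instead of a bare Exception.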
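Similarly, the new need_emb / need_embed flags let the fast parsing path skip per-item embedding calls and leave embedding unset. A hypothetical illustration (the keyword names come from the patch; the reader instance, message shape, and info fields are assumptions):

# Hypothetical usage; `reader` is assumed to be an already-configured
# multi-modal struct mem reader instance from this repository.
items = reader.multi_modal_parser.parse(
    {"role": "user", "content": "hello"},         # assumed message shape
    info={"user_id": "u1", "session_id": "s1"},   # assumed info fields
    mode="fast",
    need_emb=False,  # keyword added by this patch: skip embedding, leave it None
)
assert all(item.metadata.embedding is None for item in items)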