diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 4c2d6ca16..1d21d77c2 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,28 +1,28 @@ ## Description - +Related Issue (Required): Fixes @issue_number -Summary: (summary) +## Type of change -Fix: #(issue) +Please delete options that are not relevant. -Docs Issue/PR: (docs-issue-or-pr-link) +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] Refactor (does not change functionality, e.g. code style improvements, linting) +- [ ] Documentation update -Reviewer: @(reviewer) +## How Has This Been Tested? -## Checklist: +Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration + +- [ ] Unit Test +- [ ] Test Script Or Test Steps (please provide) +- [ ] Pipeline Automated API Test (please provide) + +## Checklist - [ ] I have performed a self-review of my own code | 我已自行检查了自己的代码 - [ ] I have commented my code in hard-to-understand areas | 我已在难以理解的地方对代码进行了注释 @@ -30,3 +30,8 @@ Reviewer: @(reviewer) - [ ] I have created related documentation issue/PR in [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) (if applicable) | 我已在 [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) 中创建了相关的文档 issue/PR(如果适用) - [ ] I have linked the issue to this PR (if applicable) | 我已将 issue 链接到此 PR(如果适用) - [ ] I have mentioned the person who will review this PR | 我已提及将审查此 PR 的人 + +## Reviewer Checklist +- [ ] closes #xxxx (Replace xxxx with the GitHub issue number) +- [ ] Made sure Checks passed +- [ ] Tests have been provided diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt index be9ed2068..a14257a76 100644 --- a/docker/requirements-full.txt +++ b/docker/requirements-full.txt @@ -89,7 +89,7 @@ nvidia-cusparselt-cu12==0.6.3 nvidia-nccl-cu12==2.26.2 nvidia-nvjitlink-cu12==12.6.85 nvidia-nvtx-cu12==12.6.77 -ollama==0.4.9 +ollama==0.5.0 onnxruntime==1.22.1 openai==1.97.0 openapi-pydantic==0.5.1 @@ -184,3 +184,4 @@ py-key-value-aio==0.2.8 py-key-value-shared==0.2.8 PyJWT==2.10.1 pytest==9.0.2 +alibabacloud-oss-v2==1.2.2 diff --git a/docker/requirements.txt b/docker/requirements.txt index f89617c10..340f4e140 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -54,7 +54,7 @@ mdurl==0.1.2 more-itertools==10.8.0 neo4j==5.28.1 numpy==2.3.4 -ollama==0.4.9 +ollama==0.5.0 openai==1.109.1 openapi-pydantic==0.5.1 orjson==3.11.4 @@ -123,3 +123,4 @@ uvicorn==0.38.0 uvloop==0.22.1; sys_platform != 'win32' watchfiles==1.1.1 websockets==15.0.1 +alibabacloud-oss-v2==1.2.2 diff --git a/examples/extras/nli_e2e_example.py b/examples/extras/nli_e2e_example.py new file mode 100644 index 000000000..087cceec7 --- /dev/null +++ b/examples/extras/nli_e2e_example.py @@ -0,0 +1,104 @@ +import sys +import threading +import time + +import requests +import uvicorn + +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.server.serve import app + + +# Config +PORT = 32534 + + +def run_server(): + print(f"Starting server on port {PORT}...") + # Using a separate thread for the server + uvicorn.run(app, host="127.0.0.1", port=PORT, log_level="info") + + +def main(): + print("Initializing E2E Test...") + + # Start server thread + server_thread = 
threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Wait for server to be up + print("Waiting for server to initialize (this may take time if downloading model)...") + client = NLIClient(base_url=f"http://127.0.0.1:{PORT}") + + # Poll until server is ready + start_time = time.time() + ready = False + + # Wait up to 5 minutes for model download and initialization + timeout = 300 + + while time.time() - start_time < timeout: + try: + # Check if docs endpoint is accessible + resp = requests.get(f"http://127.0.0.1:{PORT}/docs", timeout=1) + if resp.status_code == 200: + ready = True + break + except requests.ConnectionError: + pass + except Exception: + # Ignore other errors during startup + pass + + time.sleep(2) + print(".", end="", flush=True) + + print("\n") + if not ready: + print("Server failed to start in time.") + sys.exit(1) + + print("Server is up! Sending request...") + + # Test Data + source = "I like apples" + targets = ["I like apples", "I hate apples", "Paris is a city"] + + try: + results = client.compare_one_to_many(source, targets) + print("-" * 30) + print(f"Source: {source}") + print("Targets & Results:") + for t, r in zip(targets, results, strict=False): + print(f" - '{t}': {r.value}") + print("-" * 30) + + # Basic Validation + passed = True + if results[0].value != "Duplicate": + print(f"FAILURE: Expected Duplicate for '{targets[0]}', got {results[0].value}") + passed = False + + if results[1].value != "Contradiction": + print(f"FAILURE: Expected Contradiction for '{targets[1]}', got {results[1].value}") + passed = False + + if results[2].value != "Unrelated": + print(f"FAILURE: Expected Unrelated for '{targets[2]}', got {results[2].value}") + passed = False + + if passed: + print("\nSUCCESS: Logic verification passed!") + else: + print("\nFAILURE: Unexpected results!") + + except Exception as e: + print(f"Error during request: {e}") + sys.exit(1) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nTest interrupted.") diff --git a/poetry.lock b/poetry.lock index fb818e665..ba31d1a31 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,6 +12,23 @@ files = [ {file = "absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9"}, ] +[[package]] +name = "alibabacloud-oss-v2" +version = "1.2.2" +description = "Alibaba Cloud OSS (Object Storage Service) SDK V2 for Python" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"skill-mem\" or extra == \"all\"" +files = [ + {file = "alibabacloud_oss_v2-1.2.2-py3-none-any.whl", hash = "sha256:d138d1bdb38da6cc20d96b96faaeb099062a710a7f3d50f4b4b39a8cfcbdc120"}, +] + +[package.dependencies] +crcmod-plus = ">=2.1.0" +pycryptodome = ">=3.4.7" +requests = ">=2.18.4" + [[package]] name = "annotated-types" version = "0.7.0" @@ -582,6 +599,65 @@ mypy = ["bokeh", "contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.15.0)", " test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] +[[package]] +name = "crcmod-plus" +version = "2.3.1" +description = "CRC generator - modernized" +optional = true +python-versions = ">=3.9" +groups = ["main"] +markers = "extra == \"skill-mem\" or extra == \"all\"" +files = [ + {file = "crcmod_plus-2.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:466d5fb9a05549a401164a2ba46a560779f7240f43f0b864e9fd277c5c12133a"}, + {file = 
"crcmod_plus-2.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b31f039c440d59b808d1d90afbfd90ad901dc6e4a81d32a0fefa8d2c118064b9"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:24088832717435fc94d948e3140518c5a19fea99d1f6180b3396320398aca4c1"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e5632576426e78c51ad4ed0569650e397f282cec2751862f3fd8a88dd9d5019a"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0313488db8e9048deee987f04859b9ad46c8e6fa26385fb1d3e481c771530961"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c1d8ae3ed019e9c164f1effee61cbc509ca39695738f7556fc0685e4c9218c86"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-win32.whl", hash = "sha256:bb54ac5623938726f4e92c18af0ccd9d119011e1821e949440bbfd24552ca539"}, + {file = "crcmod_plus-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:79c58a3118e0c95cedffb48745fa1071982f8ba84309267b6020c2fffdbfaea7"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:b7e35e0f7d93d7571c2c9c3d6760e456999ea4c1eae5ead6acac247b5a79e469"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6853243120db84677b94b625112116f0ef69cd581741d20de58dce4c34242654"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:17735bc4e944d552ea18c8609fc6d08a5e64ee9b29cc216ba4d623754029cc3a"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8ac755040a2a35f43ab331978c48a9acb4ff64b425f282a296be467a410f00c3"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bdcfb838ca093ca673a3bbb37f62d1e5ec7182e00cc5ee2d00759f9f9f8ab11"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9166bc3c9b5e7b07b4e6854cac392b4a451b31d58d3950e48c140ab7b5d05394"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-win32.whl", hash = "sha256:cb99b694cce5c862560cf332a8b5e793620e28f0de3726995608bbd6f9b6e09a"}, + {file = "crcmod_plus-2.3.1-cp311-abi3-win_amd64.whl", hash = "sha256:82b0f7e968c430c5a80fe0fc59e75cb54f2e84df2ed0cee5a3ff9cadfbf8a220"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:fcb7a64648d70cac0a90c23bc6c58de6c13b28a0841c742039ba8528e23f51d1"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:abcf3ac30e41a58dd8d2659930e357d2fd47ab4fabb52382698ed1003c9a2598"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:693d2791af64aaf4467efe1473e02acd0ef1da229100262f29198f3ad59d42f8"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab075292b41b33be4d2f349e1139ea897023c3ebffc28c0d4c2ed7f2b31f1bce"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ccdc48e0af53c68304d60bbccfd5f51aed9979b5721016c3e097d51e0692b35e"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:283d23e4f13629413e6c963ffcc49c6166c9829b1e4ec6488e0d3703bd218dce"}, + {file = "crcmod_plus-2.3.1-cp313-cp313t-win32.whl", hash = "sha256:53319d2e9697a8d68260709aa61987fb89c49dd02b7f585b82c578659c1922b6"}, + {file = 
"crcmod_plus-2.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:c9ebd256f792ef01a1d0335419f679e7501d4fdf132a5206168c5269fcea65d0"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:52abc724f5232eddbe565c258878123337339bf9cfe9ac9c154e38557b8affc5"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8b0e644395d68bbfb576ee28becb69d962b173fa648ce269aec260f538841fa9"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:07962695c53eedf3c9f0bacb2d7d6c00064394d4c88c0eb7d5b082808812fe82"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:43acb79630192f91e60ec5b979a0e1fc2a4734182ce8b37d657f11fcd27c1f86"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:52aacdfc0f04510c9c0e6ecf7c09528543cb00f4d4edd0871be8c9b8e03f2c08"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ac4ce5a423f3ccf143a42ce6af4661e2f806f09a6124c24996689b3457f1afcb"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-win32.whl", hash = "sha256:cf2df1058d6bf674c8b7b6f56c7ecdc0479707c81860f032abf69526f0111f70"}, + {file = "crcmod_plus-2.3.1-cp314-cp314t-win_amd64.whl", hash = "sha256:ba925ca53a1e00233a1b93380a46c0e821f6b797a19fc401aec85219cd85fd6f"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:22600072de422632531e92d7675faf223a5b2548d45c5cd6f77ec4575339900f"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f940704e359607b47b4a8e98c4d0f453f15bea039eb183cd0ffb14a8268fea78"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f939fc1f7d143962a8fbed2305ce5931627fea1ea3a7f1865c04dbba9d41bf67"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3c6e8c7cf7ef49bcae7d3293996f82edde98e5fa202752ae58bf37a0289d35d"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:728f68d0e3049ba23978aaf277f3eb405dd21e78be6ba96382739ba09bba473c"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3829ed0dba48765f9b4139cb70b9bdf6553d2154302d9e3de6377556357892f"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-win32.whl", hash = "sha256:855fcbd07c3eb9162c701c1c7ed1a8b5a5f7b1e8c2dd3fd8ed2273e2f141ecc9"}, + {file = "crcmod_plus-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:5422081be6403b6fba736c544e79c68410307f7a1a8ac1925b421a5c6f4591d3"}, + {file = "crcmod_plus-2.3.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9397324da1be2729f894744d9031a21ed97584c17fb0289e69e0c3c60916fc5f"}, + {file = "crcmod_plus-2.3.1-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:073c7a3b832652e66c41c8b8705eaecda704d1cbe850b9fa05fdee36cd50745a"}, + {file = "crcmod_plus-2.3.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e5f4c62553f772ea7ae12d9484801b752622c9c288e49ee7ea34a20b94e4920"}, + {file = "crcmod_plus-2.3.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5e80a9860f66f339956f540d86a768f4fe8c8bfcb139811f14be864425c48d64"}, + {file = "crcmod_plus-2.3.1.tar.gz", hash = "sha256:732ffe3c3ce3ef9b272e1827d8fb894590c4d6ff553f2a2b41ae30f4f94b0f5d"}, +] + +[package.extras] +dev = ["pytest"] + [[package]] name = 
"cryptography" version = "45.0.5" @@ -2853,14 +2929,14 @@ markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64 [[package]] name = "ollama" -version = "0.4.9" +version = "0.5.0" description = "The official Python client for Ollama." optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "ollama-0.4.9-py3-none-any.whl", hash = "sha256:18c8c85358c54d7f73d6a66cda495b0e3ba99fdb88f824ae470d740fbb211a50"}, - {file = "ollama-0.4.9.tar.gz", hash = "sha256:5266d4d29b5089a01489872b8e8f980f018bccbdd1082b3903448af1d5615ce7"}, + {file = "ollama-0.5.0-py3-none-any.whl", hash = "sha256:625371de663ccb48f14faa49bd85ae409da5e40d84cab42366371234b4dbaf68"}, + {file = "ollama-0.5.0.tar.gz", hash = "sha256:ed6a343b64de22f69309ac930d8ac12b46775aebe21cbb91b859b99f59c53fa7"}, ] [package.dependencies] @@ -3507,6 +3583,58 @@ files = [ ] markers = {main = "extra == \"mem-reader\" or extra == \"all\" or platform_python_implementation != \"PyPy\"", eval = "platform_python_implementation == \"PyPy\""} +[[package]] +name = "pycryptodome" +version = "3.23.0" +description = "Cryptographic library for Python" +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main"] +markers = "extra == \"skill-mem\" or extra == \"all\"" +files = [ + {file = "pycryptodome-3.23.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a176b79c49af27d7f6c12e4b178b0824626f40a7b9fed08f712291b6d54bf566"}, + {file = "pycryptodome-3.23.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:573a0b3017e06f2cffd27d92ef22e46aa3be87a2d317a5abf7cc0e84e321bd75"}, + {file = "pycryptodome-3.23.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:63dad881b99ca653302b2c7191998dd677226222a3f2ea79999aa51ce695f720"}, + {file = "pycryptodome-3.23.0-cp27-cp27m-win32.whl", hash = "sha256:b34e8e11d97889df57166eda1e1ddd7676da5fcd4d71a0062a760e75060514b4"}, + {file = "pycryptodome-3.23.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:7ac1080a8da569bde76c0a104589c4f414b8ba296c0b3738cf39a466a9fb1818"}, + {file = "pycryptodome-3.23.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:6fe8258e2039eceb74dfec66b3672552b6b7d2c235b2dfecc05d16b8921649a8"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:0011f7f00cdb74879142011f95133274741778abba114ceca229adbf8e62c3e4"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:90460fc9e088ce095f9ee8356722d4f10f86e5be06e2354230a9880b9c549aae"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4764e64b269fc83b00f682c47443c2e6e85b18273712b98aa43bcb77f8570477"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb8f24adb74984aa0e5d07a2368ad95276cf38051fe2dc6605cbcf482e04f2a7"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d97618c9c6684a97ef7637ba43bdf6663a2e2e77efe0f863cce97a76af396446"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9a53a4fe5cb075075d515797d6ce2f56772ea7e6a1e5e4b96cf78a14bac3d265"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:763d1d74f56f031788e5d307029caef067febf890cd1f8bf61183ae142f1a77b"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:954af0e2bd7cea83ce72243b14e4fb518b18f0c1649b576d114973e2073b273d"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-win32.whl", hash = "sha256:257bb3572c63ad8ba40b89f6fc9d63a2a628e9f9708d31ee26560925ebe0210a"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6501790c5b62a29fcb227bd6b62012181d886a767ce9ed03b303d1f22eb5c625"}, + {file = "pycryptodome-3.23.0-cp313-cp313t-win_arm64.whl", hash = "sha256:9a77627a330ab23ca43b48b130e202582e91cc69619947840ea4d2d1be21eb39"}, + {file = "pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:187058ab80b3281b1de11c2e6842a357a1f71b42cb1e15bce373f3d238135c27"}, + {file = "pycryptodome-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cfb5cd445280c5b0a4e6187a7ce8de5a07b5f3f897f235caa11f1f435f182843"}, + {file = "pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67bd81fcbe34f43ad9422ee8fd4843c8e7198dd88dd3d40e6de42ee65fbe1490"}, + {file = "pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8987bd3307a39bc03df5c8e0e3d8be0c4c3518b7f044b0f4c15d1aa78f52575"}, + {file = "pycryptodome-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa0698f65e5b570426fc31b8162ed4603b0c2841cbb9088e2b01641e3065915b"}, + {file = "pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:53ecbafc2b55353edcebd64bf5da94a2a2cdf5090a6915bcca6eca6cc452585a"}, + {file = "pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:156df9667ad9f2ad26255926524e1c136d6664b741547deb0a86a9acf5ea631f"}, + {file = "pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:dea827b4d55ee390dc89b2afe5927d4308a8b538ae91d9c6f7a5090f397af1aa"}, + {file = "pycryptodome-3.23.0-cp37-abi3-win32.whl", hash = "sha256:507dbead45474b62b2bbe318eb1c4c8ee641077532067fec9c1aa82c31f84886"}, + {file = "pycryptodome-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:c75b52aacc6c0c260f204cbdd834f76edc9fb0d8e0da9fbf8352ef58202564e2"}, + {file = "pycryptodome-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:11eeeb6917903876f134b56ba11abe95c0b0fd5e3330def218083c7d98bbcb3c"}, + {file = "pycryptodome-3.23.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:350ebc1eba1da729b35ab7627a833a1a355ee4e852d8ba0447fafe7b14504d56"}, + {file = "pycryptodome-3.23.0-pp27-pypy_73-win32.whl", hash = "sha256:93837e379a3e5fd2bb00302a47aee9fdf7940d83595be3915752c74033d17ca7"}, + {file = "pycryptodome-3.23.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ddb95b49df036ddd264a0ad246d1be5b672000f12d6961ea2c267083a5e19379"}, + {file = "pycryptodome-3.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e95564beb8782abfd9e431c974e14563a794a4944c29d6d3b7b5ea042110b4"}, + {file = "pycryptodome-3.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e15c081e912c4b0d75632acd8382dfce45b258667aa3c67caf7a4d4c13f630"}, + {file = "pycryptodome-3.23.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7fc76bf273353dc7e5207d172b83f569540fc9a28d63171061c42e361d22353"}, + {file = "pycryptodome-3.23.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:45c69ad715ca1a94f778215a11e66b7ff989d792a4d63b68dc586a1da1392ff5"}, + {file = "pycryptodome-3.23.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = 
"sha256:865d83c906b0fc6a59b510deceee656b6bc1c4fa0d82176e2b77e97a420a996a"}, + {file = "pycryptodome-3.23.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89d4d56153efc4d81defe8b65fd0821ef8b2d5ddf8ed19df31ba2f00872b8002"}, + {file = "pycryptodome-3.23.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3f2d0aaf8080bda0587d58fc9fe4766e012441e2eed4269a77de6aea981c8be"}, + {file = "pycryptodome-3.23.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64093fc334c1eccfd3933c134c4457c34eaca235eeae49d69449dc4728079339"}, + {file = "pycryptodome-3.23.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ce64e84a962b63a47a592690bdc16a7eaf709d2c2697ababf24a0def566899a6"}, + {file = "pycryptodome-3.23.0.tar.gz", hash = "sha256:447700a657182d60338bab09fdb27518f8856aecd80ae4c6bdddb67ff5da44ef"}, +] + [[package]] name = "pydantic" version = "2.11.7" @@ -6234,14 +6362,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["alibabacloud-oss-v2", "cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "langchain-text-splitters", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] pref-mem = ["datasketch", "pymilvus"] +skill-mem = ["alibabacloud-oss-v2"] tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "22bfcac5ed0be1e3aea294e3da96ff1a4bd9d7b62865ad827e1508f5ade6b708" +content-hash = "faff240c05a74263a404e8d9324ffd2f342cb4f0a4c1f5455b87349f6ccc61a5" diff --git a/pyproject.toml b/pyproject.toml index 3fbe4ced4..88e577fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.3" +version = "2.0.4" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" @@ -37,7 +37,7 @@ classifiers = [ ] dependencies = [ "openai (>=1.77.0,<2.0.0)", - "ollama (>=0.4.8,<0.5.0)", + "ollama (>=0.5.0,<0.5.1)", "transformers (>=4.51.3,<5.0.0)", "tenacity (>=9.1.2,<10.0.0)", # Error handling and retrying library "fastapi[all] (>=0.115.12,<0.116.0)", # Web framework for building APIs @@ -97,6 +97,11 @@ pref-mem = [ "datasketch (>=1.6.5,<2.0.0)", # MinHash library ] +# SkillMemory +skill-mem = [ + "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -123,6 +128,7 @@ all = [ "volcengine-python-sdk (>=4.0.4,<5.0.0)", "nltk (>=3.9.1,<4.0.0)", "rake-nltk (>=1.0.6,<1.1.0)", + "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", # Uncategorized dependencies ] diff --git a/src/memos/__init__.py b/src/memos/__init__.py index 3c764db79..2d946cfbb 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.3" +__version__ = "2.0.4" from memos.configs.mem_cube import GeneralMemCubeConfig from 
memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/config.py b/src/memos/api/config.py index a3bf25be0..d27c391ab 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -440,6 +440,18 @@ def get_embedder_config() -> dict[str, Any]: "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), "headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")), "base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"), + "backup_client": os.getenv("MOS_EMBEDDER_BACKUP_CLIENT", "false").lower() + == "true", + "backup_base_url": os.getenv( + "MOS_EMBEDDER_BACKUP_API_BASE", "http://openai.com" + ), + "backup_api_key": os.getenv("MOS_EMBEDDER_BACKUP_API_KEY", "sk-xxxx"), + "backup_headers_extra": json.loads( + os.getenv("MOS_EMBEDDER_BACKUP_HEADERS_EXTRA", "{}") + ), + "backup_model_name_or_path": os.getenv( + "MOS_EMBEDDER_BACKUP_MODEL", "text-embedding-3-large" + ), }, } else: # ollama @@ -467,6 +479,35 @@ def get_reader_config() -> dict[str, Any]: } @staticmethod + def get_oss_config() -> dict[str, Any] | None: + """Get OSS configuration and validate connection.""" + + config = { + "endpoint": os.getenv("OSS_ENDPOINT", "http://oss-cn-shanghai.aliyuncs.com"), + "access_key_id": os.getenv("OSS_ACCESS_KEY_ID", ""), + "access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET", ""), + "region": os.getenv("OSS_REGION", ""), + "bucket_name": os.getenv("OSS_BUCKET_NAME", ""), + } + + # Validate that all required fields have values + required_fields = [ + "endpoint", + "access_key_id", + "access_key_secret", + "region", + "bucket_name", + ] + missing_fields = [field for field in required_fields if not config.get(field)] + + if missing_fields: + logger.warning( + f"OSS configuration incomplete. Missing fields: {', '.join(missing_fields)}" + ) + return None + + return config + def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" reader_config = APIConfig.get_reader_config() @@ -509,6 +550,13 @@ def get_internet_config() -> dict[str, Any]: }, } + @staticmethod + def get_nli_config() -> dict[str, Any]: + """Get NLI model configuration.""" + return { + "base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532"), + } + @staticmethod def get_neo4j_community_config(user_id: str | None = None) -> dict[str, Any]: """Get Neo4j community configuration.""" @@ -746,6 +794,11 @@ def get_product_default_config() -> dict[str, Any]: ).split(",") if h.strip() ], + "oss_config": APIConfig.get_oss_config(), + "skills_dir_config": { + "skills_oss_dir": os.getenv("SKILLS_OSS_DIR", "skill_memory/"), + "skills_local_dir": os.getenv("SKILLS_LOCAL_DIR", "/tmp/skill_memory/"), + }, }, }, "enable_textual_memory": True, diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 812cf2793..8292e027b 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -110,6 +110,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An Raises: HTTPException: If chat fails """ + self.logger.info(f"[ChatHandler] Chat Req is: {chat_req}") try: # Resolve readable cube IDs (for search) readable_cube_ids = chat_req.readable_cube_ids or [chat_req.user_id] @@ -241,6 +242,7 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: Raises: HTTPException: If stream initialization fails """ + self.logger.info(f"[ChatHandler] Chat Req is: {chat_req}") try: def generate_chat_response() -> Generator[str, None, None]: @@ -422,6 +424,7 
@@ def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> Stre Raises: HTTPException: If stream initialization fails """ + self.logger.info(f"[ChatHandler] Chat Req is: {chat_req}") try: def generate_chat_response() -> Generator[str, None, None]: @@ -585,6 +588,8 @@ def generate_chat_response() -> Generator[str, None, None]: # get internet reference internet_reference = self._get_internet_reference( search_response.data.get("text_mem")[0]["memories"] + if search_response.data.get("text_mem") + else [] ) yield f"data: {json.dumps({'type': 'reference', 'data': reference}, ensure_ascii=False)}\n\n" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index bfbd6271d..13dd92189 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -18,6 +18,7 @@ build_internet_retriever_config, build_llm_config, build_mem_reader_config, + build_nli_client_config, build_pref_adder_config, build_pref_extractor_config, build_pref_retriever_config, @@ -48,6 +49,7 @@ if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory +from memos.extras.nli_model.client import NLIClient from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -161,6 +163,7 @@ def init_server() -> dict[str, Any]: llm_config = build_llm_config() chat_llm_config = build_chat_llm_config() embedder_config = build_embedder_config() + nli_client_config = build_nli_client_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -186,6 +189,7 @@ def init_server() -> dict[str, Any]: else None ) embedder = EmbedderFactory.from_config(embedder_config) + nli_client = NLIClient(base_url=nli_client_config["base_url"]) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) reranker = RerankerFactory.from_config(reranker_config) @@ -307,6 +311,9 @@ def init_server() -> dict[str, Any]: ) logger.debug("Searcher created") + # Set searcher to mem_reader + mem_reader.set_searcher(searcher) + # Initialize feedback server feedback_server = SimpleMemFeedback( llm=llm, @@ -385,4 +392,5 @@ def init_server() -> dict[str, Any]: "feedback_server": feedback_server, "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, + "nli_client": nli_client, } diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index fce789e2a..ed673977a 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -188,3 +188,13 @@ def build_pref_retriever_config() -> dict[str, Any]: Validated retriever configuration dictionary """ return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_nli_client_config() -> dict[str, Any]: + """ + Build NLI client configuration. 
+ + Returns: + NLI client configuration dictionary + """ + return APIConfig.get_nli_config() diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 6e1d9d1b6..cecc42c6c 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -112,13 +112,17 @@ def post_process_textual_mem( fact_mem = [ mem for mem in text_formatted_mem - if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"] + if mem["metadata"]["memory_type"] + in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] ] tool_mem = [ mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"] ] + skill_mem = [ + mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "SkillMemory" + ] memories_result["text_mem"].append( { @@ -134,6 +138,13 @@ def post_process_textual_mem( "total_nodes": len(tool_mem), } ) + memories_result["skill_mem"].append( + { + "cube_id": mem_cube_id, + "memories": skill_mem, + "total_nodes": len(skill_mem), + } + ) return memories_result diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index d2aa2b204..e8bc5b640 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -210,10 +210,45 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory ) +def handle_get_memory_by_ids( + memory_ids: list[str], naive_mem_cube: NaiveMemCube +) -> GetMemoryResponse: + """ + Handler for getting multiple memories by their IDs. + + Retrieves multiple memories and formats them as a list of dictionaries. + """ + try: + memories = naive_mem_cube.text_mem.get_by_ids(memory_ids=memory_ids) + except Exception: + memories = [] + + # Ensure memories is not None + if memories is None: + memories = [] + + if naive_mem_cube.pref_mem is not None: + collection_names = ["explicit_preference", "implicit_preference"] + for collection_name in collection_names: + try: + result = naive_mem_cube.pref_mem.get_by_ids_with_collection_name( + collection_name, memory_ids + ) + if result is not None: + result = [format_memory_item(item, save_sources=False) for item in result] + memories.extend(result) + except Exception: + continue + + return GetMemoryResponse( + message="Memories retrieved successfully", code=200, data={"memories": memories} + ) + + def handle_get_memories( get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube ) -> GetMemoryResponse: - results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": []} + results: dict[str, Any] = {"text_mem": [], "pref_mem": [], "tool_mem": [], "skill_mem": []} memories = naive_mem_cube.text_mem.get_all( user_name=get_mem_req.mem_cube_id, user_id=get_mem_req.user_id, @@ -226,6 +261,8 @@ def handle_get_memories( if not get_mem_req.include_tool_memory: results["tool_mem"] = [] + if not get_mem_req.include_skill_memory: + results["skill_mem"] = [] preferences: list[TextualMemoryItem] = [] @@ -270,6 +307,7 @@ def handle_get_memories( "text_mem": results.get("text_mem", []), "pref_mem": results.get("pref_mem", []), "tool_mem": results.get("tool_mem", []), + "skill_mem": results.get("skill_mem", []), } return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index e5af52f87..93eff185b 100644 --- 
a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,12 +5,12 @@ using dependency injection for better modularity and testability. """ -import time +import copy +import math from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies -from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -58,32 +58,41 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse """ self.logger.info(f"[SearchHandler] Search Req is: {search_req}") - # Increase recall pool if deduplication is enabled to ensure diversity - original_top_k = search_req.top_k - if search_req.dedup == "sim": - search_req.top_k = original_top_k * 5 + # Use deepcopy to avoid modifying the original request object + search_req_local = copy.deepcopy(search_req) + original_top_k = search_req_local.top_k + + # Expand top_k for deduplication (5x to ensure enough candidates) + if search_req_local.dedup in ("sim", "mmr"): + search_req_local.top_k = original_top_k * 5 + + # Create new searcher with include_embedding for MMR deduplication + searcher_to_use = self.searcher + if search_req_local.dedup == "mmr": + text_mem = getattr(self.naive_mem_cube, "text_mem", None) + if text_mem is not None: + # Create new searcher instance with include_embedding=True + searcher_to_use = text_mem.get_searcher( + manual_close_internet=not getattr(self.searcher, "internet_retriever", None), + moscube=False, + process_llm=getattr(self.mem_reader, "llm", None), + ) + # Override include_embedding for this searcher + if hasattr(searcher_to_use, "graph_retriever"): + searcher_to_use.graph_retriever.include_embedding = True - cube_view = self._build_cube_view(search_req) + # Search and deduplicate + cube_view = self._build_cube_view(search_req_local, searcher_to_use) + results = cube_view.search_memories(search_req_local) - results = cube_view.search_memories(search_req) - if search_req.dedup == "sim": + if search_req_local.dedup == "sim": results = self._dedup_text_memories(results, original_top_k) self._strip_embeddings(results) - # Restore original top_k for downstream logic or response metadata - search_req.top_k = original_top_k - - start_time = time.time() - text_mem = results["text_mem"] - results["text_mem"] = rerank_knowledge_mem( - self.reranker, - query=search_req.query, - text_mem=text_mem, - top_k=original_top_k, - file_mem_proportion=0.5, - ) - rerank_time = time.time() - start_time + elif search_req_local.dedup == "mmr": + pref_top_k = getattr(search_req_local, "pref_top_k", 6) + results = self._mmr_dedup_text_memories(results, original_top_k, pref_top_k) + self._strip_embeddings(results) - self.logger.info(f"[Knowledge_replace_memory_time] Rerank time: {rerank_time} seconds") self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" ) @@ -140,6 +149,205 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di bucket["memories"] = [flat[i][1] for i in selected_indices] return results + def _mmr_dedup_text_memories( + self, results: dict[str, Any], text_top_k: int, pref_top_k: int = 6 + ) -> dict[str, Any]: + """ + MMR-based deduplication with progressive penalty for high similarity. + + Performs deduplication on both text_mem and preference memories together. 
+ Other memory types (tool_mem, etc.) are not modified. + + Args: + results: Search results containing text_mem and preference buckets + text_top_k: Target number of text memories to return per bucket + pref_top_k: Target number of preference memories to return per bucket + + Algorithm: + 1. Prefill top 5 by relevance + 2. MMR selection: balance relevance vs diversity + 3. Re-sort by original relevance for better generation quality + """ + text_buckets = results.get("text_mem", []) + pref_buckets = results.get("preference", []) + + # Early return if no memories to deduplicate + if not text_buckets and not pref_buckets: + return results + + # Flatten all memories with their type and scores + # flat structure: (memory_type, bucket_idx, mem, score) + flat: list[tuple[str, int, dict[str, Any], float]] = [] + + # Flatten text memories + for bucket_idx, bucket in enumerate(text_buckets): + for mem in bucket.get("memories", []): + score = mem.get("metadata", {}).get("relativity", 0.0) + flat.append(("text", bucket_idx, mem, float(score) if score is not None else 0.0)) + + # Flatten preference memories + for bucket_idx, bucket in enumerate(pref_buckets): + for mem in bucket.get("memories", []): + score = mem.get("metadata", {}).get("relativity", 0.0) + flat.append( + ("preference", bucket_idx, mem, float(score) if score is not None else 0.0) + ) + + if len(flat) <= 1: + return results + + # Get or compute embeddings + embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat]) + if embeddings is None: + self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings") + documents = [mem.get("memory", "") for _, _, mem, _ in flat] + embeddings = self.searcher.embedder.embed(documents) + + # Compute similarity matrix using NumPy-optimized method + # Returns numpy array but compatible with list[i][j] indexing + similarity_matrix = cosine_similarity_matrix(embeddings) + + # Initialize selection tracking for both text and preference + text_indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(text_buckets))} + pref_indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(pref_buckets))} + + for flat_index, (mem_type, bucket_idx, _, _) in enumerate(flat): + if mem_type == "text": + text_indices_by_bucket[bucket_idx].append(flat_index) + elif mem_type == "preference": + pref_indices_by_bucket[bucket_idx].append(flat_index) + + selected_global: list[int] = [] + text_selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(text_buckets))} + pref_selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(pref_buckets))} + selected_texts: set[str] = set() # Track exact text content to avoid duplicates + + # Phase 1: Prefill top N by relevance + # Use the smaller of text_top_k and pref_top_k for prefill count + prefill_top_n = min(2, text_top_k, pref_top_k) if pref_buckets else min(2, text_top_k) + ordered_by_relevance = sorted(range(len(flat)), key=lambda idx: flat[idx][3], reverse=True) + for idx in ordered_by_relevance[: len(flat)]: + if len(selected_global) >= prefill_top_n: + break + mem_type, bucket_idx, mem, _ = flat[idx] + + # Skip if exact text already exists in selected set + mem_text = mem.get("memory", "").strip() + if mem_text in selected_texts: + continue + + # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter) + if SearchHandler._is_text_highly_similar_optimized( + idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9 + ): + continue + + # Check bucket capacity with correct 
top_k for each type + if mem_type == "text" and len(text_selected_by_bucket[bucket_idx]) < text_top_k: + selected_global.append(idx) + text_selected_by_bucket[bucket_idx].append(idx) + selected_texts.add(mem_text) + elif mem_type == "preference" and len(pref_selected_by_bucket[bucket_idx]) < pref_top_k: + selected_global.append(idx) + pref_selected_by_bucket[bucket_idx].append(idx) + selected_texts.add(mem_text) + + # Phase 2: MMR selection for remaining slots + lambda_relevance = 0.8 + similarity_threshold = 0.9 # Apply exponential penalty above this similarity threshold + alpha_exponential = 10.0 # Exponential penalty coefficient + remaining = set(range(len(flat))) - set(selected_global) + + while remaining: + best_idx: int | None = None + best_mmr: float | None = None + + for idx in remaining: + mem_type, bucket_idx, mem, _ = flat[idx] + + # Check bucket capacity with correct top_k for each type + if ( + mem_type == "text" and len(text_selected_by_bucket[bucket_idx]) >= text_top_k + ) or ( + mem_type == "preference" + and len(pref_selected_by_bucket[bucket_idx]) >= pref_top_k + ): + continue + + # Check if exact text already exists - if so, skip this candidate entirely + mem_text = mem.get("memory", "").strip() + if mem_text in selected_texts: + continue # Skip duplicate text, don't participate in MMR competition + + # Skip if highly similar (Dice + TF-IDF + 2-gram combined, with embedding filter) + if SearchHandler._is_text_highly_similar_optimized( + idx, mem_text, selected_global, similarity_matrix, flat, threshold=0.9 + ): + continue # Skip highly similar text, don't participate in MMR competition + + relevance = flat[idx][3] + max_sim = ( + 0.0 + if not selected_global + else max(similarity_matrix[idx][j] for j in selected_global) + ) + + # Exponential penalty for similarity above the threshold (0.9) + if max_sim > similarity_threshold: + penalty_multiplier = math.exp( + alpha_exponential * (max_sim - similarity_threshold) + ) + diversity = max_sim * penalty_multiplier + else: + diversity = max_sim + + mmr_score = lambda_relevance * relevance - (1.0 - lambda_relevance) * diversity + + if best_mmr is None or mmr_score > best_mmr: + best_mmr = mmr_score + best_idx = idx + + if best_idx is None: + break + + mem_type, bucket_idx, mem, _ = flat[best_idx] + + # Add to selected set and track text + mem_text = mem.get("memory", "").strip() + selected_global.append(best_idx) + selected_texts.add(mem_text) + + if mem_type == "text": + text_selected_by_bucket[bucket_idx].append(best_idx) + elif mem_type == "preference": + pref_selected_by_bucket[bucket_idx].append(best_idx) + remaining.remove(best_idx) + + # Early termination: all buckets are full + text_all_full = all( + len(text_selected_by_bucket[b_idx]) >= min(text_top_k, len(bucket_indices)) + for b_idx, bucket_indices in text_indices_by_bucket.items() + ) + pref_all_full = all( + len(pref_selected_by_bucket[b_idx]) >= min(pref_top_k, len(bucket_indices)) + for b_idx, bucket_indices in pref_indices_by_bucket.items() + ) + if text_all_full and pref_all_full: + break + + # Phase 3: Re-sort by original relevance and fill back to buckets + for bucket_idx, bucket in enumerate(text_buckets): + selected_indices = text_selected_by_bucket.get(bucket_idx, []) + selected_indices = sorted(selected_indices, key=lambda i: flat[i][3], reverse=True) + bucket["memories"] = [flat[i][2] for i in selected_indices] + + for bucket_idx, bucket in enumerate(pref_buckets): + selected_indices = pref_selected_by_bucket.get(bucket_idx, []) + selected_indices = sorted(selected_indices,
key=lambda i: flat[i][3], reverse=True) + bucket["memories"] = [flat[i][2] for i in selected_indices] + + return results + @staticmethod def _is_unrelated( index: int, @@ -180,6 +388,168 @@ def _strip_embeddings(results: dict[str, Any]) -> None: if "embedding" in metadata: metadata["embedding"] = [] + @staticmethod + def _dice_similarity(text1: str, text2: str) -> float: + """ + Calculate Dice coefficient (character-level, fastest). + + Dice = 2 * |A ∩ B| / (|A| + |B|) + Speed: O(n + m), ~0.05-0.1ms per comparison + + Args: + text1: First text string + text2: Second text string + + Returns: + Dice similarity score between 0.0 and 1.0 + """ + if not text1 or not text2: + return 0.0 + + chars1 = set(text1) + chars2 = set(text2) + + intersection = len(chars1 & chars2) + return 2 * intersection / (len(chars1) + len(chars2)) + + @staticmethod + def _bigram_similarity(text1: str, text2: str) -> float: + """ + Calculate character-level 2-gram Jaccard similarity. + + Speed: O(n + m), ~0.1-0.2ms per comparison + Considers local order (more strict than Dice). + + Args: + text1: First text string + text2: Second text string + + Returns: + Jaccard similarity score between 0.0 and 1.0 + """ + if not text1 or not text2: + return 0.0 + + # Generate 2-grams + bigrams1 = {text1[i : i + 2] for i in range(len(text1) - 1)} if len(text1) >= 2 else {text1} + bigrams2 = {text2[i : i + 2] for i in range(len(text2) - 1)} if len(text2) >= 2 else {text2} + + intersection = len(bigrams1 & bigrams2) + union = len(bigrams1 | bigrams2) + + return intersection / union if union > 0 else 0.0 + + @staticmethod + def _tfidf_similarity(text1: str, text2: str) -> float: + """ + Calculate TF-IDF cosine similarity (character-level, no sklearn). + + Speed: O(n + m), ~0.3-0.5ms per comparison + Considers character frequency weighting. + + Args: + text1: First text string + text2: Second text string + + Returns: + Cosine similarity score between 0.0 and 1.0 + """ + if not text1 or not text2: + return 0.0 + + from collections import Counter + + # Character frequency (TF) + tf1 = Counter(text1) + tf2 = Counter(text2) + + # All unique characters (vocabulary) + vocab = set(tf1.keys()) | set(tf2.keys()) + + # Simple IDF proxy: characters shared by both texts get weight 1.0, + # characters unique to one text get a higher weight of 1.5 + idf = {char: (1.0 if char in tf1 and char in tf2 else 1.5) for char in vocab} + + # TF-IDF vectors + vec1 = {char: tf1.get(char, 0) * idf[char] for char in vocab} + vec2 = {char: tf2.get(char, 0) * idf[char] for char in vocab} + + # Cosine similarity + dot_product = sum(vec1[char] * vec2[char] for char in vocab) + norm1 = math.sqrt(sum(v * v for v in vec1.values())) + norm2 = math.sqrt(sum(v * v for v in vec2.values())) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return dot_product / (norm1 * norm2) + + @staticmethod + def _is_text_highly_similar_optimized( + candidate_idx: int, + candidate_text: str, + selected_global: list[int], + similarity_matrix, + flat: list, + threshold: float = 0.9, + ) -> bool: + """ + Multi-algorithm text similarity check with embedding pre-filtering. + + Strategy: + 1. Only compare with the single most embedding-similar selected item (not all of them) + 2. Only perform text comparison if embedding similarity > 0.9 + 3. Use weighted combination of three algorithms: + - Dice (40%): Fastest, character-level set similarity + - TF-IDF (35%): Considers character frequency weighting + - 2-gram (25%): Considers local character order + + Combined formula: + combined_score = 0.40 * dice + 0.35 * tfidf + 0.25 * bigram + + This reduces comparisons from O(N) to O(1) per candidate, with embedding pre-filtering. + Expected speedup: 100-200x compared to LCS approach. + + Args: + candidate_idx: Index of candidate memory in flat list + candidate_text: Text content of candidate memory + selected_global: List of already selected memory indices + similarity_matrix: Precomputed embedding similarity matrix + flat: Flat list of all memories + threshold: Combined similarity threshold (default 0.9) + + Returns: + True if candidate is highly similar to any selected memory + """ + if not selected_global: + return False + + # Find the already-selected memory with highest embedding similarity + max_sim_idx = max(selected_global, key=lambda j: similarity_matrix[candidate_idx][j]) + max_sim = similarity_matrix[candidate_idx][max_sim_idx] + + # If the highest embedding similarity is 0.9 or below, skip text comparison entirely + if max_sim <= 0.9: + return False + + # Get text of most similar memory + most_similar_mem = flat[max_sim_idx][2] + most_similar_text = most_similar_mem.get("memory", "").strip() + + # Calculate three similarity scores + dice_sim = SearchHandler._dice_similarity(candidate_text, most_similar_text) + tfidf_sim = SearchHandler._tfidf_similarity(candidate_text, most_similar_text) + bigram_sim = SearchHandler._bigram_similarity(candidate_text, most_similar_text) + + # Weighted combination: Dice (40%) + TF-IDF (35%) + 2-gram (25%) + # Dice has highest weight (fastest and most reliable) + # TF-IDF considers frequency (handles repeated characters well) + # 2-gram considers order (catches local pattern similarity) + combined_score = 0.40 * dice_sim + 0.35 * tfidf_sim + 0.25 * bigram_sim + + return combined_score >= threshold + def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: """ Normalize target cube ids from search_req. @@ -192,8 +562,9 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: return [search_req.user_id] - def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: + def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView: cube_ids = self._resolve_cube_ids(search_req) + searcher_to_use = searcher if searcher is not None else self.searcher if len(cube_ids) == 1: cube_id = cube_ids[0] @@ -203,7 +574,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_reader=self.mem_reader, mem_scheduler=self.mem_scheduler, logger=self.logger, - searcher=self.searcher, + searcher=searcher_to_use, deepsearch_agent=self.deepsearch_agent, ) else: @@ -214,7 +585,7 @@ def _build_cube_view(self, search_req: APISearchRequest) -> MemCubeView: mem_reader=self.mem_reader, mem_scheduler=self.mem_scheduler, logger=self.logger, - searcher=self.searcher, + searcher=searcher_to_use, deepsearch_agent=self.deepsearch_agent, ) for cube_id in cube_ids diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index b2f8a9fa3..d8fa784a3 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -319,11 +319,11 @@ class APISearchRequest(BaseRequest): description="Number of textual memories to retrieve (top-K).
Default: 10.", ) - dedup: Literal["no", "sim"] | None = Field( - None, + dedup: Literal["no", "sim", "mmr"] | None = Field( + "mmr", description=( "Optional dedup option for textual memories. " - "Use 'no' for no dedup, 'sim' for similarity dedup. " + "Use 'no' for no dedup, 'sim' for similarity dedup, 'mmr' for MMR-based dedup. " "If None, default exact-text dedup is applied." ), ) @@ -358,6 +358,18 @@ class APISearchRequest(BaseRequest): description="Number of tool memories to retrieve (top-K). Default: 6.", ) + include_skill_memory: bool = Field( + True, + description="Whether to retrieve skill memories along with general memories. " + "If enabled, the system will automatically recall skill memories " + "relevant to the query. Default: True.", + ) + skill_mem_top_k: int = Field( + 3, + ge=0, + description="Number of skill memories to retrieve (top-K). Default: 3.", + ) + # ==== Filter conditions ==== # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( @@ -393,7 +405,7 @@ class APISearchRequest(BaseRequest): # Internal field for search memory type search_memory_type: str = Field( "All", - description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory", + description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, SkillMemory", ) # ==== Context ==== @@ -772,7 +784,8 @@ class GetMemoryRequest(BaseRequest): mem_cube_id: str = Field(..., description="Cube ID") user_id: str | None = Field(None, description="User ID") include_preference: bool = Field(True, description="Whether to return preference memory") - include_tool_memory: bool = Field(False, description="Whether to return tool memory") + include_tool_memory: bool = Field(True, description="Whether to return tool memory") + include_skill_memory: bool = Field(True, description="Whether to return skill memory") filter: dict[str, Any] | None = Field(None, description="Filter for the memory") page: int | None = Field( None, diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 86b75d73e..736c328ac 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -320,6 +320,14 @@ def get_memory_by_id(memory_id: str): ) +@router.post("/get_memory_by_ids", summary="Get memory by ids", response_model=GetMemoryResponse) +def get_memory_by_ids(memory_ids: list[str]): + return handlers.memory_handler.handle_get_memory_by_ids( + memory_ids=memory_ids, + naive_mem_cube=naive_mem_cube, + ) + + @router.post( "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse ) diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index c2e648247..050043ab0 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -58,6 +58,23 @@ class UniversalAPIEmbedderConfig(BaseEmbedderConfig): base_url: str | None = Field( default=None, description="Optional base URL for custom or proxied endpoint" ) + backup_client: bool = Field( + default=False, + description="Whether to use backup client", + ) + backup_base_url: str | None = Field( + default=None, description="Optional backup base URL for custom or proxied endpoint" + ) + backup_api_key: str | None = Field( + default=None, description="Optional backup API key for the embedding provider" + ) + backup_headers_extra: dict[str, Any] | None = Field( + default=None, + description="Extra 
headers for the backup embedding model", + ) + backup_model_name_or_path: str | None = Field( + default=None, description="Optional backup model name or path" + ) class EmbedderConfigFactory(BaseConfig): diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index eaaa71461..4bd7953c0 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -57,6 +57,15 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): "If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES environment variable.", ) + oss_config: dict[str, Any] | None = Field( + default=None, + description="OSS configuration for the MemReader", + ) + skills_dir_config: dict[str, Any] | None = Field( + default=None, + description="Skills directory for the MemReader", + ) + class StrategyStructMemReaderConfig(BaseMemReaderConfig): """StrategyStruct MemReader configuration class.""" diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 60bae15a5..538d913ea 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -1,3 +1,7 @@ +import asyncio +import os +import time + from openai import AzureOpenAI as AzureClient from openai import OpenAI as OpenAIClient @@ -29,23 +33,80 @@ def __init__(self, config: UniversalAPIEmbedderConfig): ) else: raise ValueError(f"Embeddings unsupported provider: {self.provider}") + self.use_backup_client = config.backup_client + if self.use_backup_client: + self.backup_client = OpenAIClient( + api_key=config.backup_api_key, + base_url=config.backup_base_url, + default_headers=config.backup_headers_extra + if config.backup_headers_extra + else None, + ) @timed_with_status( log_prefix="model_timed_embedding", - log_extra_args={"model_name_or_path": "text-embedding-3-large"}, + log_extra_args=lambda self, texts: { + "model_name_or_path": "text-embedding-3-large", + "text_len": len(texts), + "text_content": texts, + }, ) def embed(self, texts: list[str]) -> list[list[float]]: + if isinstance(texts, str): + texts = [texts] # Truncate texts if max_tokens is configured texts = self._truncate_texts(texts) - + logger.info(f"Embeddings request with input: {texts}") if self.provider == "openai" or self.provider == "azure": try: - response = self.client.embeddings.create( - model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), - input=texts, + + async def _create_embeddings(): + return self.client.embeddings.create( + model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), + input=texts, + ) + + init_time = time.time() + response = asyncio.run( + asyncio.wait_for( + _create_embeddings(), timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5)) + ) ) + logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds") + logger.info(f"Embeddings request response: {response}") return [r.embedding for r in response.data] except Exception as e: - raise Exception(f"Embeddings request ended with error: {e}") from e + if self.use_backup_client: + logger.warning( + f"Embeddings request ended with {type(e).__name__} error: {e}, try backup client" + ) + try: + + async def _create_embeddings_backup(): + return self.backup_client.embeddings.create( + model=getattr( + self.config, + "backup_model_name_or_path", + "text-embedding-3-large", + ), + input=texts, + ) + + init_time = time.time() + response = asyncio.run( + asyncio.wait_for( + _create_embeddings_backup(), + timeout=int(os.getenv("MOS_EMBEDDER_TIMEOUT", 5)), + ) + ) + logger.info( + 
f"Backup embeddings request succeeded with {time.time() - init_time} seconds" + ) + logger.info(f"Backup embeddings request response: {response}") + return [r.embedding for r in response.data] + except Exception as e: + raise ValueError(f"Backup embeddings request ended with error: {e}") from e + else: + raise ValueError(f"Embeddings request ended with error: {e}") from e else: raise ValueError(f"Embeddings unsupported provider: {self.provider}") diff --git a/src/memos/extras/__init__.py b/src/memos/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/extras/nli_model/__init__.py b/src/memos/extras/nli_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py new file mode 100644 index 000000000..a02dae9f6 --- /dev/null +++ b/src/memos/extras/nli_model/client.py @@ -0,0 +1,61 @@ +import logging + +import requests + +from memos.extras.nli_model.types import NLIResult + + +logger = logging.getLogger(__name__) + + +class NLIClient: + """ + Client for interacting with the deployed NLI model service. + """ + + def __init__(self, base_url: str = "http://localhost:32532"): + self.base_url = base_url.rstrip("/") + self.session = requests.Session() + + def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: + """ + Compare one source text against multiple target memories using the NLI service. + + Args: + source: The new memory content. + targets: List of existing memory contents to compare against. + + Returns: + List of NLIResult corresponding to each target. + """ + if not targets: + return [] + + url = f"{self.base_url}/compare_one_to_many" + # Match schemas.CompareRequest + payload = {"source": source, "targets": targets} + + try: + response = self.session.post(url, json=payload, timeout=30) + response.raise_for_status() + data = response.json() + + # Match schemas.CompareResponse + results_str = data.get("results", []) + + results = [] + for res_str in results_str: + try: + results.append(NLIResult(res_str)) + except ValueError: + logger.warning( + f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" + ) + results.append(NLIResult.UNRELATED) + + return results + + except requests.RequestException as e: + logger.error(f"[NLIClient] Request failed: {e}") + # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. + return [NLIResult.UNRELATED] * len(targets) diff --git a/src/memos/extras/nli_model/server/README.md b/src/memos/extras/nli_model/server/README.md new file mode 100644 index 000000000..f6886e0e4 --- /dev/null +++ b/src/memos/extras/nli_model/server/README.md @@ -0,0 +1,69 @@ +# NLI Model Server + +This directory contains the standalone server for the Natural Language Inference (NLI) model used by MemOS. + +## Prerequisites + +- Python 3.10+ +- CUDA-capable GPU (Recommended for performance) +- `torch` and `transformers` libraries (required for the server) + +## Running the Server + +You can run the server using the module syntax from the project root to ensure imports work correctly. + +### 1. Basic Start +```bash +python -m memos.extras.nli_model.server.serve +``` + +### 2. Configuration +You can configure the server by editing config.py: + +- `HOST`: The host to bind to (default: `0.0.0.0`) +- `PORT`: The port to bind to (default: `32532`) +- `NLI_DEVICE`: The device to run the model on. 
+ - `cuda` (Default, uses cuda:0 if available, else fallback to mps/cpu)
+ - `cuda:0` (Specific GPU)
+ - `mps` (Apple Silicon)
+ - `cpu` (CPU)
+
+## API Usage
+
+### Compare One to Many
+**POST** `/compare_one_to_many`
+
+**Request Body:**
+```json
+{
+  "source": "I just ate an apple.",
+  "targets": [
+    "I ate a fruit.",
+    "I hate apples.",
+    "The sky is blue."
+  ]
+}
+```
+
+**Response:**
+```json
+{
+  "results": [
+    "Duplicate",      // Entailment
+    "Contradiction",  // Contradiction
+    "Unrelated"       // Neutral
+  ]
+}
+```
+
+## Testing
+
+An end-to-end example script is provided to verify the server's functionality. This script starts the server locally and runs a client request to verify the NLI logic.
+
+### End-to-End Test
+
+Run the example script from the project root:
+
+```bash
+python examples/extras/nli_e2e_example.py
+```
diff --git a/src/memos/extras/nli_model/server/__init__.py b/src/memos/extras/nli_model/server/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py
new file mode 100644
index 000000000..d2e12175d
--- /dev/null
+++ b/src/memos/extras/nli_model/server/config.py
@@ -0,0 +1,23 @@
+import logging
+
+
+NLI_MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
+
+# Configuration
+# You can set the device directly here.
+# Examples:
+#   - "cuda"   : Use default GPU (cuda:0) if available, else auto-fallback
+#   - "cuda:0" : Use specific GPU
+#   - "mps"    : Use Apple Silicon GPU (if available)
+#   - "cpu"    : Use CPU
+NLI_DEVICE = "cuda"
+NLI_MODEL_HOST = "0.0.0.0"
+NLI_MODEL_PORT = 32532
+
+# Configure logging for NLI Server
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s | %(name)s - %(levelname)s - %(message)s",
+    handlers=[logging.StreamHandler(), logging.FileHandler("nli_server.log")],
+)
+logger = logging.getLogger("nli_server")
diff --git a/src/memos/extras/nli_model/server/handler.py b/src/memos/extras/nli_model/server/handler.py
new file mode 100644
index 000000000..3e98ddeb0
--- /dev/null
+++ b/src/memos/extras/nli_model/server/handler.py
@@ -0,0 +1,186 @@
+import re
+
+from memos.extras.nli_model.server.config import NLI_MODEL_NAME, logger
+from memos.extras.nli_model.types import NLIResult
+
+
+# Placeholder for lazy imports
+torch = None
+AutoModelForSequenceClassification = None
+AutoTokenizer = None
+
+
+def _map_label_to_result(raw: str) -> NLIResult:
+    t = raw.lower()
+    if "entail" in t:
+        return NLIResult.DUPLICATE
+    if "contrad" in t or "refut" in t:
+        return NLIResult.CONTRADICTION
+    # Neutral or unknown
+    return NLIResult.UNRELATED
+
+
+def _clean_temporal_markers(s: str) -> str:
+    # Remove temporal/aspect markers that might cause contradiction
+    # Chinese markers (simple replace is usually okay as they are characters)
+    zh_markers = ["刚刚", "曾经", "正在", "目前", "现在"]
+    for m in zh_markers:
+        s = s.replace(m, "")
+
+    # English markers (need word boundaries to avoid "snow" -> "s")
+    en_markers = ["just", "once", "currently", "now"]
+    pattern = r"\b(" + "|".join(en_markers) + r")\b"
+    s = re.sub(pattern, "", s, flags=re.IGNORECASE)
+
+    # Cleanup extra spaces
+    s = re.sub(r"\s+", " ", s).strip()
+    return s
+
+
+class NLIHandler:
+    """
+    NLI Model Handler for inference.
+    Requires `torch` and `transformers` to be installed.
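+
+    A minimal usage sketch (illustrative only; the example sentences, device
+    choice, and expected labels are assumptions, while the class and method
+    names come from this module):
+
+        handler = NLIHandler(device="cpu")
+        results = handler.compare_one_to_many(
+            "The meeting starts at 9am.",
+            ["The meeting begins at 9am.", "The meeting was cancelled."],
+        )
+        # Expected (assuming typical NLI behaviour):
+        # [NLIResult.DUPLICATE, NLIResult.CONTRADICTION]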
+ """ + + def __init__(self, device: str = "cpu", use_fp16: bool = True, use_compile: bool = True): + global torch, AutoModelForSequenceClassification, AutoTokenizer + try: + import torch + + from transformers import AutoModelForSequenceClassification, AutoTokenizer + except ImportError as e: + raise ImportError( + "NLIHandler requires 'torch' and 'transformers'. " + "Please install them via 'pip install torch transformers' or use the requirements.txt." + ) from e + + self.device = self._resolve_device(device) + logger.info(f"Final resolved device: {self.device}") + + # Set defaults based on device if not explicitly provided + is_cuda = "cuda" in self.device + if not is_cuda: + use_fp16 = False + use_compile = False + + self.tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME) + + model_kwargs = {} + if use_fp16 and is_cuda: + model_kwargs["torch_dtype"] = torch.float16 + + self.model = AutoModelForSequenceClassification.from_pretrained( + NLI_MODEL_NAME, **model_kwargs + ).to(self.device) + self.model.eval() + + self.id2label = {int(k): v for k, v in self.model.config.id2label.items()} + self.softmax = torch.nn.Softmax(dim=-1).to(self.device) + + if use_compile and hasattr(torch, "compile"): + logger.info("Compiling model with torch.compile...") + self.model = torch.compile(self.model) + + def _resolve_device(self, device: str) -> str: + d = device.strip().lower() + + has_cuda = torch.cuda.is_available() + has_mps = torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False + + if d == "cpu": + return "cpu" + + if d.startswith("cuda"): + if has_cuda: + if d == "cuda": + return "cuda:0" + return d + + # Fallback if CUDA not available + if has_mps: + logger.warning( + f"Device '{device}' requested but CUDA not available. Fallback to MPS." + ) + return "mps" + + logger.warning( + f"Device '{device}' requested but CUDA/MPS not available. Fallback to CPU." + ) + return "cpu" + + if d == "mps": + if has_mps: + return "mps" + + logger.warning(f"Device '{device}' requested but MPS not available. Fallback to CPU.") + return "cpu" + + # Fallback / Auto-detect for other cases (e.g. "gpu" or unknown) + if has_cuda: + return "cuda:0" + if has_mps: + return "mps" + + return "cpu" + + def predict_batch(self, premises: list[str], hypotheses: list[str]) -> list[NLIResult]: + # Clean inputs + premises = [_clean_temporal_markers(p) for p in premises] + hypotheses = [_clean_temporal_markers(h) for h in hypotheses] + + # Batch tokenize with padding + inputs = self.tokenizer( + premises, hypotheses, return_tensors="pt", truncation=True, max_length=512, padding=True + ).to(self.device) + with torch.no_grad(): + out = self.model(**inputs) + probs = self.softmax(out.logits) + + results = [] + for p in probs: + idx = int(torch.argmax(p).item()) + res = self.id2label.get(idx, str(idx)) + results.append(_map_label_to_result(res)) + return results + + def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: + """ + Compare one source text against multiple target memories efficiently using batch processing. + Performs bidirectional checks (Source <-> Target) for each pair. 
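+
+        Args:
+            source: The new memory text to compare.
+            targets: Existing memory texts to compare against.
+
+        Returns:
+            One NLIResult per target, merged from both directions: a
+            contradiction in either direction wins, then entailment in either
+            direction maps to Duplicate, otherwise the pair is Unrelated.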
+ """ + if not targets: + return [] + + n = len(targets) + # Construct batch: + # First n pairs: Source -> Target_i + # Next n pairs: Target_i -> Source + premises = [source] * n + targets + hypotheses = targets + [source] * n + + # Run single large batch inference + raw_results = self.predict_batch(premises, hypotheses) + + # Split results back + results_ab = raw_results[:n] + results_ba = raw_results[n:] + + final_results = [] + for i in range(n): + res_ab = results_ab[i] + res_ba = results_ba[i] + + # 1. Any Contradiction -> Contradiction (Sensitive detection, filtered by LLM later) + if res_ab == NLIResult.CONTRADICTION or res_ba == NLIResult.CONTRADICTION: + final_results.append(NLIResult.CONTRADICTION) + + # 2. Any Entailment -> Duplicate (as per user requirement) + elif res_ab == NLIResult.DUPLICATE or res_ba == NLIResult.DUPLICATE: + final_results.append(NLIResult.DUPLICATE) + + # 3. Otherwise (Both Neutral) -> Unrelated + else: + final_results.append(NLIResult.UNRELATED) + + return final_results diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py new file mode 100644 index 000000000..0ed9eae65 --- /dev/null +++ b/src/memos/extras/nli_model/server/serve.py @@ -0,0 +1,44 @@ +from contextlib import asynccontextmanager + +import uvicorn + +from fastapi import FastAPI, HTTPException + +from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT +from memos.extras.nli_model.server.handler import NLIHandler +from memos.extras.nli_model.types import CompareRequest, CompareResponse + + +# Global handler instance +nli_handler: NLIHandler | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global nli_handler + nli_handler = NLIHandler(device=NLI_DEVICE) + yield + # Clean up if needed + nli_handler = None + + +app = FastAPI(lifespan=lifespan) + + +@app.post("/compare_one_to_many", response_model=CompareResponse) +async def compare_one_to_many(request: CompareRequest): + if nli_handler is None: + raise HTTPException(status_code=503, detail="Model not loaded") + try: + results = nli_handler.compare_one_to_many(request.source, request.targets) + return CompareResponse(results=results) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +def start_server(host: str = "0.0.0.0", port: int = 32532): + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + start_server(host=NLI_MODEL_HOST, port=NLI_MODEL_PORT) diff --git a/src/memos/extras/nli_model/types.py b/src/memos/extras/nli_model/types.py new file mode 100644 index 000000000..619f8f508 --- /dev/null +++ b/src/memos/extras/nli_model/types.py @@ -0,0 +1,18 @@ +from enum import Enum + +from pydantic import BaseModel + + +class NLIResult(Enum): + DUPLICATE = "Duplicate" + CONTRADICTION = "Contradiction" + UNRELATED = "Unrelated" + + +class CompareRequest(BaseModel): + source: str + targets: list[str] + + +class CompareResponse(BaseModel): + results: list[NLIResult] diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 40c0c9684..b9c8ca84b 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1130,30 +1130,27 @@ def get_nodes( - Assumes all provided IDs are valid and exist. - Returns empty list if input is empty. 
""" + logger.info(f"get_nodes ids:{ids},user_name:{user_name}") if not ids: return [] - # Build WHERE clause using agtype_access_operator like get_node method - where_conditions = [] - params = [] - - for id_val in ids: - where_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype" - ) - params.append(self.format_param_value(id_val)) - - where_clause = " OR ".join(where_conditions) + # Build WHERE clause using IN operator with agtype array + # Use ANY operator with array for better performance + placeholders = ",".join(["%s"] * len(ids)) + params = [self.format_param_value(id_val) for id_val in ids] query = f""" SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" - WHERE ({where_clause}) + WHERE ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = ANY(ARRAY[{placeholders}]::agtype[]) """ - user_name = user_name if user_name else self.config.user_name - query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" - params.append(self.format_param_value(user_name)) + # Only add user_name filter if provided + if user_name is not None: + query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" + params.append(self.format_param_value(user_name)) + + logger.info(f"get_nodes query:{query},params:{params}") conn = None try: @@ -4313,7 +4310,7 @@ def _build_user_name_and_kb_ids_conditions_sql( user_name_conditions = [] effective_user_name = user_name if user_name else default_user_name - if effective_user_name and default_user_name != "xxx": + if effective_user_name: user_name_conditions.append( f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{effective_user_name}\"'::agtype" ) diff --git a/src/memos/mem_reader/base.py b/src/memos/mem_reader/base.py index 87bf43b0f..b034c9367 100644 --- a/src/memos/mem_reader/base.py +++ b/src/memos/mem_reader/base.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher class BaseMemReader(ABC): @@ -33,6 +34,12 @@ def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: graph_db: The graph database instance, or None to disable recall operations. """ + @abstractmethod + def set_searcher(self, searcher: "Searcher | None") -> None: + """ + Set the searcher instance for recall operations. + """ + @abstractmethod def get_memory( self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast" diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 2749327bf..7bd551fb8 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher class MemReaderFactory(BaseMemReader): @@ -27,6 +28,7 @@ def from_config( cls, config_factory: MemReaderConfigFactory, graph_db: Optional["BaseGraphDB | None"] = None, + searcher: Optional["Searcher | None"] = None, ) -> BaseMemReader: """ Create a MemReader instance from configuration. 
@@ -50,4 +52,7 @@ def from_config( if graph_db is not None: reader.set_graph_db(graph_db) + if searcher is not None: + reader.set_searcher(searcher) + return reader diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 9edcd0a55..236a8f180 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -10,6 +10,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang from memos.mem_reader.read_multi_modal.base import _derive_key +from memos.mem_reader.read_skill_memory.process_skill_memory import process_skill_memory_fine from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata @@ -38,6 +39,12 @@ def __init__(self, config: MultiModalStructMemReaderConfig): # Extract direct_markdown_hostnames before converting to SimpleStructMemReaderConfig direct_markdown_hostnames = getattr(config, "direct_markdown_hostnames", None) + # oss + self.oss_config = getattr(config, "oss_config", None) + + # skills_dir + self.skills_dir_config = getattr(config, "skills_dir_config", None) + # Create config_dict excluding direct_markdown_hostnames for SimpleStructMemReaderConfig config_dict = config.model_dump(exclude_none=True) config_dict.pop("direct_markdown_hostnames", None) @@ -79,12 +86,11 @@ def _split_large_memory_item( chunks = self.chunker.chunk(item_text) split_items = [] - for chunk in chunks: + def _create_chunk_item(chunk): # Chunk objects have a 'text' attribute chunk_text = chunk.text if not chunk_text or not chunk_text.strip(): - continue - + return None # Create a new memory item for each chunk, preserving original metadata split_item = self._make_memory_item( value=chunk_text, @@ -98,8 +104,17 @@ def _split_large_memory_item( key=item.metadata.key, sources=item.metadata.sources or [], background=item.metadata.background or "", + need_embed=False, ) - split_items.append(split_item) + return split_item + + # Use thread pool to parallel process chunks + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks] + for future in concurrent.futures.as_completed(futures): + split_item = future.result() + if split_item is not None: + split_items.append(split_item) return split_items if split_items else [item] except Exception as e: @@ -127,15 +142,41 @@ def _concat_multi_modal_memories( # Split large memory items before processing processed_items = [] - for item in all_memory_items: - item_text = item.memory or "" - item_tokens = self._count_tokens(item_text) - if item_tokens > max_tokens: - # Split the large item into multiple chunks - split_items = self._split_large_memory_item(item, max_tokens) - processed_items.extend(split_items) - else: - processed_items.append(item) + # control whether to parallel chunk large memory items + parallel_chunking = True + + if parallel_chunking: + # parallel chunk large memory items + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + future_to_item = { + executor.submit(self._split_large_memory_item, item, max_tokens): item + for item in all_memory_items + if (item.memory or "") and self._count_tokens(item.memory) > max_tokens + } + processed_items.extend( + [ + item + for item in all_memory_items + if not ( + 
(item.memory or "") and self._count_tokens(item.memory) > max_tokens + ) + ] + ) + # collect split items from futures + for future in concurrent.futures.as_completed(future_to_item): + split_items = future.result() + processed_items.extend(split_items) + else: + # serial chunk large memory items + for item in all_memory_items: + item_text = item.memory or "" + item_tokens = self._count_tokens(item_text) + if item_tokens > max_tokens: + # Split the large item into multiple chunks + split_items = self._split_large_memory_item(item, max_tokens) + processed_items.extend(split_items) + else: + processed_items.append(item) # If only one item after processing, return as-is if len(processed_items) == 1: @@ -464,13 +505,6 @@ def _get_maybe_merged_memory( status="activated", threshold=merge_threshold, user_name=user_name, - filter={ - "or": [ - {"memory_type": "LongTermMemory"}, - {"memory_type": "UserMemory"}, - {"memory_type": "WorkingMemory"}, - ] - }, ) if not search_results: @@ -797,13 +831,29 @@ def _process_multi_modal_data( if isinstance(scene_data_info, list): # Parse each message in the list all_memory_items = [] - for msg in scene_data_info: - items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs) - all_memory_items.extend(items) + # Use thread pool to parse each message in parallel + with ContextThreadPoolExecutor(max_workers=30) as executor: + futures = [ + executor.submit( + self.multi_modal_parser.parse, + msg, + info, + mode="fast", + need_emb=False, + **kwargs, + ) + for msg in scene_data_info + ] + for future in concurrent.futures.as_completed(futures): + try: + items = future.result() + all_memory_items.extend(items) + except Exception as e: + logger.error(f"[MultiModalFine] Error in parallel parsing: {e}") else: # Parse as single message all_memory_items = self.multi_modal_parser.parse( - scene_data_info, info, mode="fast", **kwargs + scene_data_info, info, mode="fast", need_emb=False, **kwargs ) fast_memory_items = self._concat_multi_modal_memories(all_memory_items) if mode == "fast": @@ -819,13 +869,27 @@ def _process_multi_modal_data( future_tool = executor.submit( self._process_tool_trajectory_fine, fast_memory_items, info, **kwargs ) + future_skill = executor.submit( + process_skill_memory_fine, + fast_memory_items=fast_memory_items, + info=info, + searcher=self.searcher, + graph_db=self.graph_db, + llm=self.llm, + embedder=self.embedder, + oss_config=self.oss_config, + skills_dir_config=self.skills_dir_config, + **kwargs, + ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() + fine_memory_items_skill_memory_parser = future_skill.result() fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) + fine_memory_items.extend(fine_memory_items_skill_memory_parser) # Part B: get fine multimodal items for fast_item in fast_memory_items: @@ -844,7 +908,7 @@ def _process_multi_modal_data( @timed def _process_transfer_multi_modal_data( - self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None, **kwargs + self, raw_nodes: list[TextualMemoryItem], custom_tags: list[str] | None = None, **kwargs ) -> list[TextualMemoryItem]: """ Process transfer for multimodal data. @@ -852,42 +916,56 @@ def _process_transfer_multi_modal_data( Each source is processed independently by its corresponding parser, which knows how to rebuild the original message and parse it in fine mode. 
""" - sources = raw_node.metadata.sources or [] - if not sources: - logger.warning("[MultiModalStruct] No sources found in raw_node") + if not raw_nodes: + logger.warning("[MultiModalStruct] No raw nodes found.") return [] - # Extract info from raw_node (same as simple_struct.py) + # Extract info from raw_nodes (same as simple_struct.py) info = { - "user_id": raw_node.metadata.user_id, - "session_id": raw_node.metadata.session_id, - **(raw_node.metadata.info or {}), + "user_id": raw_nodes[0].metadata.user_id, + "session_id": raw_nodes[0].metadata.session_id, + **(raw_nodes[0].metadata.info or {}), } fine_memory_items = [] # Part A: call llm in parallel using thread pool with ContextThreadPoolExecutor(max_workers=2) as executor: future_string = executor.submit( - self._process_string_fine, [raw_node], info, custom_tags, **kwargs + self._process_string_fine, raw_nodes, info, custom_tags, **kwargs ) future_tool = executor.submit( - self._process_tool_trajectory_fine, [raw_node], info, **kwargs + self._process_tool_trajectory_fine, raw_nodes, info, **kwargs + ) + future_skill = executor.submit( + process_skill_memory_fine, + raw_nodes, + info, + searcher=self.searcher, + llm=self.llm, + embedder=self.embedder, + graph_db=self.graph_db, + oss_config=self.oss_config, + skills_dir_config=self.skills_dir_config, + **kwargs, ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() - + fine_memory_items_skill_memory_parser = future_skill.result() fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) + fine_memory_items.extend(fine_memory_items_skill_memory_parser) # Part B: get fine multimodal items - for source in sources: - lang = getattr(source, "lang", "en") - items = self.multi_modal_parser.process_transfer( - source, context_items=[raw_node], info=info, custom_tags=custom_tags, lang=lang - ) - fine_memory_items.extend(items) + for raw_node in raw_nodes: + sources = raw_node.metadata.sources + for source in sources: + lang = getattr(source, "lang", "en") + items = self.multi_modal_parser.process_transfer( + source, context_items=[raw_node], info=info, custom_tags=custom_tags, lang=lang + ) + fine_memory_items.extend(items) return fine_memory_items def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]: @@ -944,22 +1022,7 @@ def fine_transfer_simple_mem( if not input_memories: return [] - memory_list = [] - # Process Q&A pairs concurrently with context propagation - with ContextThreadPoolExecutor() as executor: - futures = [ - executor.submit( - self._process_transfer_multi_modal_data, scene_data_info, custom_tags, **kwargs - ) - for scene_data_info in input_memories - ] - for future in concurrent.futures.as_completed(futures): - try: - res_memory = future.result() - if res_memory is not None: - memory_list.append(res_memory) - except Exception as e: - logger.error(f"Task failed with exception: {e}") - logger.error(traceback.format_exc()) - return memory_list + memory_list = self._process_transfer_multi_modal_data(input_memories, custom_tags, **kwargs) + + return [memory_list] diff --git a/src/memos/mem_reader/read_multi_modal/assistant_parser.py b/src/memos/mem_reader/read_multi_modal/assistant_parser.py index 3519216d2..89d4fec7f 100644 --- a/src/memos/mem_reader/read_multi_modal/assistant_parser.py +++ b/src/memos/mem_reader/read_multi_modal/assistant_parser.py @@ -210,6 +210,7 @@ def parse_fast( info: dict[str, Any], 
**kwargs, ) -> list[TextualMemoryItem]: + need_emb = kwargs.get("need_emb", True) if not isinstance(message, dict): logger.warning(f"[AssistantParser] Expected dict, got {type(message)}") return [] @@ -290,7 +291,7 @@ def parse_fast( status="activated", tags=["mode:fast"], key=_derive_key(line), - embedding=self.embedder.embed([line])[0], + embedding=self.embedder.embed([line])[0] if need_emb else None, usage=[], sources=sources, background="", diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py index 1a756c5d0..95d427864 100644 --- a/src/memos/mem_reader/read_multi_modal/base.py +++ b/src/memos/mem_reader/read_multi_modal/base.py @@ -15,6 +15,7 @@ TextualMemoryItem, TreeNodeTextualMemoryMetadata, ) +from memos.utils import timed from .utils import detect_lang, get_text_splitter @@ -245,6 +246,7 @@ def parse( else: raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'") + @timed def _split_text(self, text: str, is_markdown: bool = False) -> list[str]: """ Split text into chunks using text splitter from utils. diff --git a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py index 2c8140419..808410e65 100644 --- a/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py +++ b/src/memos/mem_reader/read_multi_modal/multi_modal_parser.py @@ -149,6 +149,7 @@ def parse( logger.warning(f"[MultiModalParser] No parser found for message: {message}") return [] + logger.info(f"[{parser.__class__.__name__}] Parsing message in {mode} mode: {message}") # Parse using the appropriate parser try: return parser.parse(message, info, mode=mode, **kwargs) diff --git a/src/memos/mem_reader/read_multi_modal/user_parser.py b/src/memos/mem_reader/read_multi_modal/user_parser.py index 1c9afab65..1ab48c82e 100644 --- a/src/memos/mem_reader/read_multi_modal/user_parser.py +++ b/src/memos/mem_reader/read_multi_modal/user_parser.py @@ -151,6 +151,7 @@ def parse_fast( info: dict[str, Any], **kwargs, ) -> list[TextualMemoryItem]: + need_emb = kwargs.get("need_emb", True) if not isinstance(message, dict): logger.warning(f"[UserParser] Expected dict, got {type(message)}") return [] @@ -192,7 +193,7 @@ def parse_fast( status="activated", tags=["mode:fast"], key=_derive_key(line), - embedding=self.embedder.embed([line])[0], + embedding=self.embedder.embed([line])[0] if need_emb else None, usage=[], sources=sources, background="", diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py new file mode 100644 index 000000000..bb809e69d --- /dev/null +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -0,0 +1,721 @@ +import json +import os +import shutil +import uuid +import zipfile + +from concurrent.futures import as_completed +from datetime import datetime +from pathlib import Path +from typing import Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.dependency import require_python_package +from memos.embedders.base import BaseEmbedder +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_reader.read_multi_modal import detect_lang +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.templates.skill_mem_prompt import ( + SKILL_MEMORY_EXTRACTION_PROMPT, + 
SKILL_MEMORY_EXTRACTION_PROMPT_ZH, + TASK_CHUNKING_PROMPT, + TASK_CHUNKING_PROMPT_ZH, + TASK_QUERY_REWRITE_PROMPT, + TASK_QUERY_REWRITE_PROMPT_ZH, +) +from memos.types import MessageList + + +logger = get_logger(__name__) + + +def add_id_to_mysql(memory_id: str, mem_cube_id: str): + """Add id to mysql, will deprecate this function in the future""" + # TODO: tmp function, deprecate soon + import requests + + skill_mysql_url = os.getenv("SKILLS_MYSQL_URL", "") + skill_mysql_bearer = os.getenv("SKILLS_MYSQL_BEARER", "") + + if not skill_mysql_url or not skill_mysql_bearer: + logger.warning("[PROCESS_SKILLS] SKILLS_MYSQL_URL or SKILLS_MYSQL_BEARER is not set") + return None + headers = {"Authorization": skill_mysql_bearer, "Content-Type": "application/json"} + data = {"memCubeId": mem_cube_id, "skillId": memory_id} + try: + response = requests.post(skill_mysql_url, headers=headers, json=data) + + logger.info(f"[PROCESS_SKILLS] response: \n\n{response.json()}") + logger.info(f"[PROCESS_SKILLS] memory_id: \n\n{memory_id}") + logger.info(f"[PROCESS_SKILLS] mem_cube_id: \n\n{mem_cube_id}") + logger.info(f"[PROCESS_SKILLS] skill_mysql_url: \n\n{skill_mysql_url}") + logger.info(f"[PROCESS_SKILLS] skill_mysql_bearer: \n\n{skill_mysql_bearer}") + logger.info(f"[PROCESS_SKILLS] headers: \n\n{headers}") + logger.info(f"[PROCESS_SKILLS] data: \n\n{data}") + + return response.json() + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error adding id to mysql: {e}") + return None + + +@require_python_package( + import_name="alibabacloud_oss_v2", + install_command="pip install alibabacloud-oss-v2", +) +def create_oss_client(oss_config: dict[str, Any] | None = None) -> Any: + import alibabacloud_oss_v2 as oss + + credentials_provider = oss.credentials.EnvironmentVariableCredentialsProvider() + + # load SDK's default configuration, and set credential provider + cfg = oss.config.load_default() + cfg.credentials_provider = credentials_provider + cfg.region = oss_config.get("region", os.getenv("OSS_REGION")) + cfg.endpoint = oss_config.get("endpoint", os.getenv("OSS_ENDPOINT")) + client = oss.Client(cfg) + + return client + + +def _reconstruct_messages_from_memory_items(memory_items: list[TextualMemoryItem]) -> MessageList: + reconstructed_messages = [] + seen = set() # Track (role, content) tuples to detect duplicates + + for memory_item in memory_items: + for source_message in memory_item.metadata.sources: + try: + role = source_message.role + content = source_message.content + + # Create a tuple for deduplication + message_key = (role, content) + + # Only add if not seen before (keep first occurrence) + if message_key not in seen: + reconstructed_messages.append({"role": role, "content": content}) + seen.add(message_key) + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error reconstructing message: {e}") + continue + + return reconstructed_messages + + +def _add_index_to_message(messages: MessageList) -> MessageList: + for i, message in enumerate(messages): + message["idx"] = i + return messages + + +def _split_task_chunk_by_llm(llm: BaseLLM, messages: MessageList) -> dict[str, MessageList]: + """Split messages into task chunks by LLM.""" + messages_context = "\n".join( + [ + f"{message.get('idx', i)}: {message['role']}: {message['content']}" + for i, message in enumerate(messages) + ] + ) + lang = detect_lang(messages_context) + template = TASK_CHUNKING_PROMPT_ZH if lang == "zh" else TASK_CHUNKING_PROMPT + prompt = [{"role": "user", "content": template.replace("{{messages}}", 
messages_context)}] + for attempt in range(3): + try: + skills_llm = os.getenv("SKILLS_LLM", None) + llm_kwargs = {"model_name_or_path": skills_llm} if skills_llm else {} + response_text = llm.generate(prompt, **llm_kwargs) + response_json = json.loads(response_text.replace("```json", "").replace("```", "")) + break + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] LLM generate failed (attempt {attempt + 1}): {e}") + if attempt == 2: + logger.warning( + "[PROCESS_SKILLS] LLM generate failed after 3 retries, returning empty dict" + ) + response_json = [] + break + + task_chunks = {} + for item in response_json: + task_name = item["task_name"] + message_indices = item["message_indices"] + for indices in message_indices: + # Validate that indices is a list/tuple with exactly 2 elements + if not isinstance(indices, list | tuple) or len(indices) != 2: + logger.warning( + f"[PROCESS_SKILLS] Invalid message indices format for task '{task_name}': {indices}, skipping" + ) + continue + start, end = indices + task_chunks.setdefault(task_name, []).extend(messages[start : end + 1]) + return task_chunks + + +def _extract_skill_memory_by_llm( + messages: MessageList, old_memories: list[TextualMemoryItem], llm: BaseLLM +) -> dict[str, Any]: + old_memories_dict = [skill_memory.model_dump() for skill_memory in old_memories] + old_mem_references = [ + { + "id": mem["id"], + "name": mem["metadata"]["name"], + "description": mem["metadata"]["description"], + "procedure": mem["metadata"]["procedure"], + "experience": mem["metadata"]["experience"], + "preference": mem["metadata"]["preference"], + "examples": mem["metadata"]["examples"], + "tags": mem["metadata"]["tags"], + "scripts": mem["metadata"].get("scripts"), + "others": mem["metadata"]["others"], + } + for mem in old_memories_dict + ] + + # Prepare conversation context + messages_context = "\n".join( + [f"{message['role']}: {message['content']}" for message in messages] + ) + + # Prepare old memories context + old_memories_context = json.dumps(old_mem_references, ensure_ascii=False, indent=2) + + # Prepare prompt + lang = detect_lang(messages_context) + template = SKILL_MEMORY_EXTRACTION_PROMPT_ZH if lang == "zh" else SKILL_MEMORY_EXTRACTION_PROMPT + prompt_content = template.replace("{old_memories}", old_memories_context).replace( + "{messages}", messages_context + ) + + prompt = [{"role": "user", "content": prompt_content}] + + # Call LLM to extract skill memory with retry logic + for attempt in range(3): + try: + # Only pass model_name_or_path if SKILLS_LLM is set + skills_llm = os.getenv("SKILLS_LLM", None) + llm_kwargs = {"model_name_or_path": skills_llm} if skills_llm else {} + response_text = llm.generate(prompt, **llm_kwargs) + # Clean up response (remove markdown code blocks if present) + response_text = response_text.strip() + response_text = response_text.replace("```json", "").replace("```", "").strip() + + # Parse JSON response + skill_memory = json.loads(response_text) + + # If LLM returns null (parsed as None), log and return None + if skill_memory is None: + logger.info( + "[PROCESS_SKILLS] No skill memory extracted from conversation (LLM returned null)" + ) + return None + + return skill_memory + + except json.JSONDecodeError as e: + logger.warning(f"[PROCESS_SKILLS] JSON decode failed (attempt {attempt + 1}): {e}") + logger.debug(f"[PROCESS_SKILLS] Response text: {response_text}") + if attempt == 2: + logger.warning("[PROCESS_SKILLS] Failed to parse skill memory after 3 retries") + return None + except Exception as e: + 
logger.warning( + f"[PROCESS_SKILLS] LLM skill memory extraction failed (attempt {attempt + 1}): {e}" + ) + if attempt == 2: + logger.warning( + "[PROCESS_SKILLS] LLM skill memory extraction failed after 3 retries" + ) + return None + + return None + + +def _recall_related_skill_memories( + task_type: str, + messages: MessageList, + searcher: Searcher, + llm: BaseLLM, + rewrite_query: bool, + info: dict[str, Any], + mem_cube_id: str, +) -> list[TextualMemoryItem]: + query = _rewrite_query(task_type, messages, llm, rewrite_query) + related_skill_memories = searcher.search( + query, + top_k=5, + memory_type="SkillMemory", + info=info, + include_skill_memory=True, + user_name=mem_cube_id, + ) + + return related_skill_memories + + +def _rewrite_query(task_type: str, messages: MessageList, llm: BaseLLM, rewrite_query: bool) -> str: + if not rewrite_query: + # Return the first user message content if rewrite is disabled + return messages[0]["content"] if messages else "" + + # Construct messages context for LLM + messages_context = "\n".join( + [f"{message['role']}: {message['content']}" for message in messages] + ) + + # Prepare prompt with task type and messages + lang = detect_lang(messages_context) + template = TASK_QUERY_REWRITE_PROMPT_ZH if lang == "zh" else TASK_QUERY_REWRITE_PROMPT + prompt_content = template.replace("{task_type}", task_type).replace( + "{messages}", messages_context + ) + prompt = [{"role": "user", "content": prompt_content}] + + # Call LLM to rewrite the query with retry logic + for attempt in range(3): + try: + response_text = llm.generate(prompt) + # Clean up response (remove any markdown formatting if present) + response_text = response_text.strip() + logger.info(f"[PROCESS_SKILLS] Rewritten query for task '{task_type}': {response_text}") + return response_text + except Exception as e: + logger.warning( + f"[PROCESS_SKILLS] LLM query rewrite failed (attempt {attempt + 1}): {e}" + ) + if attempt == 2: + logger.warning( + "[PROCESS_SKILLS] LLM query rewrite failed after 3 retries, returning first message content" + ) + return messages[0]["content"] if messages else "" + + # Fallback (should not reach here due to return in exception handling) + return messages[0]["content"] if messages else "" + + +@require_python_package( + import_name="alibabacloud_oss_v2", + install_command="pip install alibabacloud-oss-v2", +) +def _upload_skills_to_oss(local_file_path: str, oss_file_path: str, client: Any) -> str: + import alibabacloud_oss_v2 as oss + + result = client.put_object_from_file( + request=oss.PutObjectRequest( + bucket=os.getenv("OSS_BUCKET_NAME"), + key=oss_file_path, + ), + filepath=local_file_path, + ) + + if result.status_code != 200: + logger.warning("[PROCESS_SKILLS] Failed to upload skill to OSS") + return "" + + # Construct and return the URL + bucket_name = os.getenv("OSS_BUCKET_NAME") + endpoint = os.getenv("OSS_ENDPOINT").replace("https://", "").replace("http://", "") + url = f"https://{bucket_name}.{endpoint}/{oss_file_path}" + return url + + +@require_python_package( + import_name="alibabacloud_oss_v2", + install_command="pip install alibabacloud-oss-v2", +) +def _delete_skills_from_oss(oss_file_path: str, client: Any) -> Any: + import alibabacloud_oss_v2 as oss + + result = client.delete_object( + oss.DeleteObjectRequest( + bucket=os.getenv("OSS_BUCKET_NAME"), + key=oss_file_path, + ) + ) + return result + + +def _write_skills_to_file( + skill_memory: dict[str, Any], info: dict[str, Any], skills_dir_config: dict[str, Any] +) -> str: + user_id = 
info.get("user_id", "unknown") + skill_name = skill_memory.get("name", "unnamed_skill").replace(" ", "_").lower() + + # Create tmp directory for user if it doesn't exist + tmp_dir = Path(skills_dir_config["skills_local_dir"]) / user_id + tmp_dir.mkdir(parents=True, exist_ok=True) + + # Create skill directory directly in tmp_dir + skill_dir = tmp_dir / skill_name + skill_dir.mkdir(parents=True, exist_ok=True) + + # Generate SKILL.md content with frontmatter + skill_md_content = f"""--- +name: {skill_name} +description: {skill_memory.get("description", "")} +--- +""" + + # Add Procedure section only if present + procedure = skill_memory.get("procedure", "") + if procedure and procedure.strip(): + skill_md_content += f"\n## Procedure\n{procedure}\n" + + # Add Experience section only if there are items + experiences = skill_memory.get("experience", []) + if experiences: + skill_md_content += "\n## Experience\n" + for idx, exp in enumerate(experiences, 1): + skill_md_content += f"{idx}. {exp}\n" + + # Add User Preferences section only if there are items + preferences = skill_memory.get("preference", []) + if preferences: + skill_md_content += "\n## User Preferences\n" + for pref in preferences: + skill_md_content += f"- {pref}\n" + + # Add Examples section only if there are items + examples = skill_memory.get("examples", []) + if examples: + skill_md_content += "\n## Examples\n" + for idx, example in enumerate(examples, 1): + skill_md_content += f"\n### Example {idx}\n```markdown\n{example}\n```\n" + + # Add scripts reference if present + scripts = skill_memory.get("scripts") + if scripts and isinstance(scripts, dict): + skill_md_content += "\n## Scripts\n" + skill_md_content += "This skill includes the following executable scripts:\n\n" + for script_name in scripts: + skill_md_content += f"- `./scripts/{script_name}`\n" + + # Add others - handle both inline content and separate markdown files + others = skill_memory.get("others") + if others and isinstance(others, dict): + # Separate markdown files from inline content + md_files = {} + inline_content = {} + + for key, value in others.items(): + if key.endswith(".md"): + md_files[key] = value + else: + inline_content[key] = value + + # Add inline content to SKILL.md + if inline_content: + skill_md_content += "\n## Additional Information\n" + for key, value in inline_content.items(): + skill_md_content += f"\n### {key}\n{value}\n" + + # Add references to separate markdown files + if md_files: + if not inline_content: + skill_md_content += "\n## Additional Information\n" + skill_md_content += "\nSee also:\n" + for md_filename in md_files: + skill_md_content += f"- [{md_filename}](./{md_filename})\n" + + # Write SKILL.md file + skill_md_path = skill_dir / "SKILL.md" + with open(skill_md_path, "w", encoding="utf-8") as f: + f.write(skill_md_content) + + # Write separate markdown files from others + if others and isinstance(others, dict): + for key, value in others.items(): + if key.endswith(".md"): + md_file_path = skill_dir / key + with open(md_file_path, "w", encoding="utf-8") as f: + f.write(value) + + # If there are scripts, create a scripts directory with individual script files + if scripts and isinstance(scripts, dict): + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(parents=True, exist_ok=True) + + # Write each script to its own file + for script_filename, script_content in scripts.items(): + # Ensure filename ends with .py + if not script_filename.endswith(".py"): + script_filename = f"{script_filename}.py" + + script_path = 
scripts_dir / script_filename + with open(script_path, "w", encoding="utf-8") as f: + f.write(script_content) + + # Create zip file in tmp_dir + zip_filename = f"{skill_name}.zip" + zip_path = tmp_dir / zip_filename + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: + # Walk through the skill directory and add all files + for file_path in skill_dir.rglob("*"): + if file_path.is_file(): + # Use relative path from skill_dir for archive + arcname = Path(skill_dir.name) / file_path.relative_to(skill_dir) + zipf.write(str(file_path), str(arcname)) + + logger.info(f"[PROCESS_SKILLS] Created skill zip file: {zip_path}") + return str(zip_path) + + +def create_skill_memory_item( + skill_memory: dict[str, Any], info: dict[str, Any], embedder: BaseEmbedder | None = None +) -> TextualMemoryItem: + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + + # Use description as the memory content + memory_content = skill_memory.get("description", "") + + # Create metadata with all skill-specific fields directly + metadata = TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="SkillMemory", + status="activated", + tags=skill_memory.get("tags", []), + key=skill_memory.get("name", ""), + sources=[], + usage=[], + background="", + confidence=0.99, + created_at=datetime.now().isoformat(), + updated_at=datetime.now().isoformat(), + type="skills", + info=info_, + embedding=embedder.embed([memory_content])[0] if embedder else None, + # Skill-specific fields + name=skill_memory.get("name", ""), + description=skill_memory.get("description", ""), + procedure=skill_memory.get("procedure", ""), + experience=skill_memory.get("experience", []), + preference=skill_memory.get("preference", []), + examples=skill_memory.get("examples", []), + scripts=skill_memory.get("scripts"), + others=skill_memory.get("others"), + url=skill_memory.get("url", ""), + ) + + # If this is an update, use the old memory ID + item_id = ( + skill_memory.get("old_memory_id", "") + if skill_memory.get("update", False) + else str(uuid.uuid4()) + ) + if not item_id: + item_id = str(uuid.uuid4()) + + return TextualMemoryItem(id=item_id, memory=memory_content, metadata=metadata) + + +def process_skill_memory_fine( + fast_memory_items: list[TextualMemoryItem], + info: dict[str, Any], + searcher: Searcher | None = None, + graph_db: BaseGraphDB | None = None, + llm: BaseLLM | None = None, + embedder: BaseEmbedder | None = None, + rewrite_query: bool = True, + oss_config: dict[str, Any] | None = None, + skills_dir_config: dict[str, Any] | None = None, + **kwargs, +) -> list[TextualMemoryItem]: + # Validate required configurations + if not oss_config: + logger.warning("[PROCESS_SKILLS] OSS configuration is required for skill memory processing") + return [] + + if not skills_dir_config: + logger.warning( + "[PROCESS_SKILLS] Skills directory configuration is required for skill memory processing" + ) + return [] + + # Validate skills_dir has required keys + required_keys = ["skills_local_dir", "skills_oss_dir"] + missing_keys = [key for key in required_keys if key not in skills_dir_config] + if missing_keys: + logger.warning( + f"[PROCESS_SKILLS] Skills directory configuration missing required keys: {', '.join(missing_keys)}" + ) + return [] + + oss_client = create_oss_client(oss_config) + if not oss_client: + logger.warning("[PROCESS_SKILLS] Failed to create OSS client") + return [] + + messages = _reconstruct_messages_from_memory_items(fast_memory_items) + 
messages = _add_index_to_message(messages) + + task_chunks = _split_task_chunk_by_llm(llm, messages) + if not task_chunks: + logger.warning("[PROCESS_SKILLS] No task chunks found") + return [] + + # recall - get related skill memories for each task separately (parallel) + related_skill_memories_by_task = {} + with ContextThreadPoolExecutor(max_workers=min(len(task_chunks), 5)) as executor: + recall_futures = { + executor.submit( + _recall_related_skill_memories, + task_type=task, + messages=msg, + searcher=searcher, + llm=llm, + rewrite_query=rewrite_query, + info=info, + mem_cube_id=kwargs.get("user_name", info.get("user_id", "")), + ): task + for task, msg in task_chunks.items() + } + for future in as_completed(recall_futures): + task_name = recall_futures[future] + try: + related_memories = future.result() + related_skill_memories_by_task[task_name] = related_memories + except Exception as e: + logger.warning( + f"[PROCESS_SKILLS] Error recalling skill memories for task '{task_name}': {e}" + ) + related_skill_memories_by_task[task_name] = [] + + skill_memories = [] + with ContextThreadPoolExecutor(max_workers=min(len(task_chunks), 5)) as executor: + futures = { + executor.submit( + _extract_skill_memory_by_llm, + messages, + related_skill_memories_by_task.get(task_type, []), + llm, + ): task_type + for task_type, messages in task_chunks.items() + } + for future in as_completed(futures): + try: + skill_memory = future.result() + if skill_memory: # Only add non-None results + skill_memories.append(skill_memory) + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error extracting skill memory: {e}") + continue + + # write skills to file and get zip paths + skill_memory_with_paths = [] + with ContextThreadPoolExecutor(max_workers=min(len(skill_memories), 5)) as executor: + futures = { + executor.submit( + _write_skills_to_file, skill_memory, info, skills_dir_config + ): skill_memory + for skill_memory in skill_memories + } + for future in as_completed(futures): + try: + zip_path = future.result() + skill_memory = futures[future] + skill_memory_with_paths.append((skill_memory, zip_path)) + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error writing skills to file: {e}") + continue + + # Create a mapping from old_memory_id to old memory for easy lookup + # Collect all related memories from all tasks + all_related_memories = [] + for memories in related_skill_memories_by_task.values(): + all_related_memories.extend(memories) + old_memories_map = {mem.id: mem for mem in all_related_memories} + + # upload skills to oss and set urls directly to skill_memory + user_id = info.get("user_id", "unknown") + + for skill_memory, zip_path in skill_memory_with_paths: + try: + # Delete old skill from OSS if this is an update + if skill_memory.get("update", False) and skill_memory.get("old_memory_id"): + old_memory_id = skill_memory["old_memory_id"] + old_memory = old_memories_map.get(old_memory_id) + + if old_memory: + # Get old OSS path from the old memory's metadata + old_oss_path = getattr(old_memory.metadata, "url", None) + + if old_oss_path: + try: + # delete old skill from OSS + zip_filename = Path(old_oss_path).name + old_oss_path = ( + Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename + ).as_posix() + _delete_skills_from_oss(old_oss_path, oss_client) + logger.info( + f"[PROCESS_SKILLS] Deleted old skill from OSS: {old_oss_path}" + ) + except Exception as e: + logger.warning( + f"[PROCESS_SKILLS] Failed to delete old skill from OSS: {e}" + ) + + # delete old 
skill from graph db + if graph_db: + graph_db.delete_node_by_prams(memory_ids=[old_memory_id]) + logger.info( + f"[PROCESS_SKILLS] Deleted old skill from graph db: {old_memory_id}" + ) + + # Upload new skill to OSS + # Use the same filename as the local zip file + zip_filename = Path(zip_path).name + oss_path = ( + Path(skills_dir_config["skills_oss_dir"]) / user_id / zip_filename + ).as_posix() + + # _upload_skills_to_oss returns the URL + url = _upload_skills_to_oss( + local_file_path=str(zip_path), oss_file_path=oss_path, client=oss_client + ) + + # Set URL directly to skill_memory + skill_memory["url"] = url + + logger.info(f"[PROCESS_SKILLS] Uploaded skill to OSS: {url}") + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error uploading skill to OSS: {e}") + skill_memory["url"] = "" # Set to empty string if upload fails + finally: + # Clean up local files after upload + try: + zip_file = Path(zip_path) + skill_dir = zip_file.parent / zip_file.stem + # Delete zip file + if zip_file.exists(): + zip_file.unlink() + # Delete skill directory + if skill_dir.exists(): + shutil.rmtree(skill_dir) + logger.info(f"[PROCESS_SKILLS] Cleaned up local files: {zip_path} and {skill_dir}") + except Exception as cleanup_error: + logger.warning(f"[PROCESS_SKILLS] Error cleaning up local files: {cleanup_error}") + + # Create TextualMemoryItem objects + skill_memory_items = [] + for skill_memory in skill_memories: + try: + memory_item = create_skill_memory_item(skill_memory, info, embedder) + skill_memory_items.append(memory_item) + except Exception as e: + logger.warning(f"[PROCESS_SKILLS] Error creating skill memory item: {e}") + continue + + # TODO: deprecate this funtion and call + for skill_memory in skill_memory_items: + add_id_to_mysql( + memory_id=skill_memory.id, mem_cube_id=kwargs.get("user_name", info.get("user_id", "")) + ) + + return skill_memory_items diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 783da763e..f3ae98ccb 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from memos.graph_dbs.base import BaseGraphDB + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang from memos.mem_reader.utils import ( count_tokens_text, @@ -187,6 +188,9 @@ def __init__(self, config: SimpleStructMemReaderConfig): def set_graph_db(self, graph_db: "BaseGraphDB | None") -> None: self.graph_db = graph_db + def set_searcher(self, searcher: "Searcher | None") -> None: + self.searcher = searcher + def _make_memory_item( self, value: str, @@ -198,6 +202,7 @@ def _make_memory_item( background: str = "", type_: str = "fact", confidence: float = 0.99, + need_embed: bool = True, **kwargs, ) -> TextualMemoryItem: """construct memory item""" @@ -213,7 +218,7 @@ def _make_memory_item( status="activated", tags=tags or [], key=key if key is not None else derive_key(value), - embedding=self.embedder.embed([value])[0], + embedding=self.embedder.embed([value])[0] if need_embed else None, usage=[], sources=sources or [], background=background, diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 6b0b7e8a6..cbf1a97b3 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -59,7 +59,9 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem """ @abstractmethod - def get_by_ids(self, 
memory_ids: list[str]) -> list[TextualMemoryItem]: + def get_by_ids( + self, memory_ids: list[str], user_name: str | None = None + ) -> list[TextualMemoryItem]: """Get memories by their IDs. Args: memory_ids (list[str]): List of memory IDs to retrieve. diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index a1c85033b..46770758d 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -112,6 +112,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): "OuterMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "SkillMemory", ] = Field(default="WorkingMemory", description="Memory lifecycle type.") sources: list[SourceMessage] | None = Field( default=None, description="Multiple origins of the memory (e.g., URLs, notes)." diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index b963cfa9b..b556db5d7 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -161,6 +161,8 @@ def search( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + include_skill_memory: bool = False, + skill_mem_top_k: int = 3, dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: @@ -208,6 +210,8 @@ def search( user_name=user_name, search_tool_memory=search_tool_memory, tool_mem_top_k=tool_mem_top_k, + include_skill_memory=include_skill_memory, + skill_mem_top_k=skill_mem_top_k, dedup=dedup, **kwargs, ) @@ -319,7 +323,8 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem def get_by_ids( self, memory_ids: list[str], user_name: str | None = None ) -> list[TextualMemoryItem]: - raise NotImplementedError + graph_output = self.graph_store.get_nodes(ids=memory_ids, user_name=user_name) + return graph_output def get_all( self, diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index c96d5a12a..5e9c74f61 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -159,7 +159,12 @@ def _add_memories_batch( for memory in memories: working_id = str(uuid.uuid4()) - if memory.metadata.memory_type not in ("ToolSchemaMemory", "ToolTrajectoryMemory"): + if memory.metadata.memory_type in ( + "WorkingMemory", + "LongTermMemory", + "UserMemory", + "OuterMemory", + ): working_metadata = memory.metadata.model_copy( update={"memory_type": "WorkingMemory"} ).model_dump(exclude_none=True) @@ -176,8 +181,11 @@ def _add_memories_batch( "UserMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "SkillMemory", ): - graph_node_id = str(uuid.uuid4()) + if not memory.id: + logger.error("Memory ID is not set, generating a new one") + graph_node_id = memory.id or str(uuid.uuid4()) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() @@ -310,7 +318,12 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non working_id = str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: - if memory.metadata.memory_type not in ("ToolSchemaMemory", "ToolTrajectoryMemory"): + if memory.metadata.memory_type in ( + "WorkingMemory", + "LongTermMemory", + "UserMemory", + "OuterMemory", + ): f_working = ex.submit( self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id ) @@ -321,6 +334,7 @@ def _process_memory(self, memory: 
TextualMemoryItem, user_name: str | None = Non "UserMemory", "ToolSchemaMemory", "ToolTrajectoryMemory", + "SkillMemory", ): f_graph = ex.submit( self._add_to_graph_memory, @@ -372,7 +386,9 @@ def _add_to_graph_memory( """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). """ - node_id = str(uuid.uuid4()) + if not memory.id: + logger.error("Memory ID is not set, generating a new one") + node_id = memory.id or str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) tags = metadata_dict.get("tags") or [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py new file mode 100644 index 000000000..a5fc7e049 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py @@ -0,0 +1,264 @@ +import concurrent.futures +import re + +from typing import Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_reader.read_multi_modal.utils import detect_lang +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer + + +logger = get_logger(__name__) + + +class PreUpdateRetriever: + def __init__(self, graph_db, embedder): + """ + The PreUpdateRetriever is designed for the /add phase . + It serves to recall potentially duplicate/conflict memories against the new content that's being added. + + Args: + graph_db: The graph database instance (Neo4j, PolarDB, etc.) + embedder: The embedder instance for vector search + """ + self.graph_db = graph_db + self.embedder = embedder + # Use existing tokenizer for keyword extraction + self.tokenizer = FastTokenizer(use_jieba=True, use_stopwords=True) + + def _adjust_perspective(self, text: str, role: str, lang: str) -> str: + """ + For better search result, we adjust the perspective + from 1st person to 3rd person based on role and language. + "I" -> "User" (if role is user) + "I" -> "Assistant" (if role is assistant) + """ + if not role: + return text + + role = role.lower() + replacements = [] + + # Determine replacements based on language and role + if lang == "zh": + if role == "user": + replacements = [("我", "用户")] + elif role == "assistant": + replacements = [("我", "助手")] + else: # default to en + if role == "user": + replacements = [ + (r"\bI\b", "User"), + (r"\bme\b", "User"), + (r"\bmy\b", "User's"), + (r"\bmine\b", "User's"), + (r"\bmyself\b", "User himself"), + ] + elif role == "assistant": + replacements = [ + (r"\bI\b", "Assistant"), + (r"\bme\b", "Assistant"), + (r"\bmy\b", "Assistant's"), + (r"\bmine\b", "Assistant's"), + (r"\bmyself\b", "Assistant himself"), + ] + + adjusted_text = text + for pattern, repl in replacements: + if lang == "zh": + adjusted_text = adjusted_text.replace(pattern, repl) + else: + adjusted_text = re.sub(pattern, repl, adjusted_text, flags=re.IGNORECASE) + + return adjusted_text + + def _preprocess_query(self, item: TextualMemoryItem) -> str: + """ + Preprocess the query item: + 1. Extract language and role from metadata/sources + 2. 
Adjust perspective (I -> User/Assistant) based on role/lang + """ + raw_text = item.memory or "" + if not raw_text.strip(): + return "" + + # Extract lang/role + lang = None + role = None + sources = item.metadata.sources + + if sources: + source_list = sources if isinstance(sources, list) else [sources] + for source in source_list: + if hasattr(source, "lang") and source.lang: + lang = source.lang + elif isinstance(source, dict) and source.get("lang"): + lang = source.get("lang") + + if hasattr(source, "role") and source.role: + role = source.role + elif isinstance(source, dict) and source.get("role"): + role = source.get("role") + + if lang and role: + break + + if lang is None: + lang = detect_lang(raw_text) + + # Adjust perspective + return self._adjust_perspective(raw_text, role, lang) + + def _get_full_memories( + self, candidate_ids: list[str], user_name: str + ) -> list[TextualMemoryItem]: + """ + Retrieve full memories for given candidate ids. + """ + full_recalled_memories = self.graph_db.get_nodes(candidate_ids, user_name=user_name) + return [TextualMemoryItem.from_dict(item) for item in full_recalled_memories] + + def vector_search( + self, + query_text: str, + query_embedding: list[float] | None, + user_name: str, + top_k: int, + search_filter: dict[str, Any] | None = None, + threshold: float = 0.5, + ) -> list[dict]: + try: + # Use pre-computed embedding if available (matches raw/clean query) + # Otherwise embed the switched query for better semantic match + q_embed = query_embedding if query_embedding else self.embedder.embed([query_text])[0] + + # Assuming graph_db.search_by_embedding returns list of dicts or items + results = self.graph_db.search_by_embedding( + vector=q_embed, + top_k=top_k, + status=None, + threshold=threshold, + user_name=user_name, + filter=search_filter, + ) + return results + except Exception as e: + logger.error(f"[PreUpdateRetriever] Vector search failed: {e}") + return [] + + def keyword_search( + self, + query_text: str, + user_name: str, + top_k: int, + search_filter: dict[str, Any] | None = None, + ) -> list[dict]: + try: + # 1. Tokenize using existing tokenizer + keywords = self.tokenizer.tokenize_mixed(query_text) + if not keywords: + return [] + + results = [] + + # 2. Try seach_by_keywords_tfidf (PolarDB specific) + if hasattr(self.graph_db, "seach_by_keywords_tfidf"): + try: + results = self.graph_db.seach_by_keywords_tfidf( + query_words=keywords, user_name=user_name, filter=search_filter + ) + except Exception as e: + logger.warning(f"[PreUpdateRetriever] seach_by_keywords_tfidf failed: {e}") + + # 3. Fallback to search_by_fulltext + if not results and hasattr(self.graph_db, "search_by_fulltext"): + try: + results = self.graph_db.search_by_fulltext( + query_words=keywords, top_k=top_k, user_name=user_name, filter=search_filter + ) + except Exception as e: + logger.warning(f"[PreUpdateRetriever] search_by_fulltext failed: {e}") + + return results[:top_k] + + except Exception as e: + logger.error(f"[PreUpdateRetriever] Keyword search failed: {e}") + return [] + + def retrieve( + self, item: TextualMemoryItem, user_name: str, top_k: int = 10, sim_threshold: float = 0.5 + ) -> list[TextualMemoryItem]: + """ + Recall related memories for a TextualMemoryItem using hybrid search (Vector + Keyword). + Might actually return top_k ~ 2top_k items. + Designed for low latency. 
+
+        Args:
+            item: The memory item to find related memories for
+            user_name: User identifier for scoping search
+            top_k: Max number of results to return
+            sim_threshold: minimum similarity threshold for vector search
+
+        Returns:
+            List of TextualMemoryItem
+        """
+        # 1. Preprocess
+        switched_query = self._preprocess_query(item)
+
+        # 2. Recall
+        futures = []
+        common_filter = {
+            "status": {"in": ["activated", "resolving"]},
+            "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]},
+        }
+
+        with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor:
+            # Task A: Vector Search (Semantic)
+            query_embedding = (
+                item.metadata.embedding if hasattr(item.metadata, "embedding") else None
+            )
+            futures.append(
+                executor.submit(
+                    self.vector_search,
+                    switched_query,
+                    query_embedding,
+                    user_name,
+                    top_k,
+                    common_filter,
+                    sim_threshold,
+                )
+            )
+
+            # Task B: Keyword Search
+            futures.append(
+                executor.submit(
+                    self.keyword_search, switched_query, user_name, top_k, common_filter
+                )
+            )
+
+        # 3. Collect Results
+        retrieved_ids = set()  # for deduplicating ids
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                res = future.result()
+                if not res:
+                    continue
+
+                for r in res:
+                    retrieved_ids.add(r["id"])
+
+            except Exception as e:
+                logger.error(f"[PreUpdateRetriever] Search future task failed: {e}")
+
+        retrieved_ids = list(retrieved_ids)
+
+        if not retrieved_ids:
+            return []
+
+        # 4. Retrieve full memories from the collected ids
+        # TODO: We should modify the db functions to support returning arbitrary fields, instead of searching twice.
+        final_memories = self._get_full_memories(retrieved_ids, user_name)
+
+        return final_memories
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 4541b118b..c9f2ec156 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -67,6 +67,7 @@ def retrieve(
             "UserMemory",
             "ToolSchemaMemory",
             "ToolTrajectoryMemory",
+            "SkillMemory",
         ]:
             raise ValueError(f"Unsupported memory scope: {memory_scope}")
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
index 5a82883c8..1c887355c 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
@@ -466,7 +466,12 @@ def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar:
 def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]:
-    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-    x_normalized = embeddings / norms
+    embeddings_array = np.asarray(embeddings)
+    norms = np.linalg.norm(embeddings_array, axis=1, keepdims=True)
+    # Handle zero vectors to avoid division by zero
+    norms[norms == 0] = 1.0
+    x_normalized = embeddings_array / norms
     similarity_matrix = np.dot(x_normalized, x_normalized.T)
+    # Handle any NaN or Inf values
+    similarity_matrix = np.nan_to_num(similarity_matrix, nan=0.0, posinf=0.0, neginf=0.0)
     return similarity_matrix
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 8c30d74f3..dcd4e1fba 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++
b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -81,6 +81,8 @@ def retrieve( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + include_skill_memory: bool = False, + skill_mem_top_k: int = 3, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: logger.info( @@ -108,6 +110,8 @@ def retrieve( user_name, search_tool_memory, tool_mem_top_k, + include_skill_memory, + skill_mem_top_k, ) return results @@ -119,6 +123,8 @@ def post_retrieve( info=None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + include_skill_memory: bool = False, + skill_mem_top_k: int = 3, dedup: str | None = None, plugin=False, ): @@ -127,7 +133,13 @@ def post_retrieve( else: deduped = self._deduplicate_results(retrieved_results) final_results = self._sort_and_trim( - deduped, top_k, plugin, search_tool_memory, tool_mem_top_k + deduped, + top_k, + plugin, + search_tool_memory, + tool_mem_top_k, + include_skill_memory, + skill_mem_top_k, ) self._update_usage_history(final_results, info, user_name) return final_results @@ -145,6 +157,8 @@ def search( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + include_skill_memory: bool = False, + skill_mem_top_k: int = 3, dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: @@ -192,6 +206,8 @@ def search( user_name=user_name, search_tool_memory=search_tool_memory, tool_mem_top_k=tool_mem_top_k, + include_skill_memory=include_skill_memory, + skill_mem_top_k=skill_mem_top_k, **kwargs, ) @@ -207,6 +223,8 @@ def search( plugin=kwargs.get("plugin", False), search_tool_memory=search_tool_memory, tool_mem_top_k=tool_mem_top_k, + include_skill_memory=include_skill_memory, + skill_mem_top_k=skill_mem_top_k, dedup=dedup, ) @@ -305,8 +323,10 @@ def _retrieve_paths( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + include_skill_memory: bool = False, + skill_mem_top_k: int = 3, ): - """Run A/B/C retrieval paths in parallel""" + """Run A/B/C/D/E retrieval paths in parallel""" tasks = [] id_filter = { "user_id": info.get("user_id", None), @@ -314,7 +334,7 @@ def _retrieve_paths( } id_filter = {k: v for k, v in id_filter.items() if v is not None} - with ContextThreadPoolExecutor(max_workers=3) as executor: + with ContextThreadPoolExecutor(max_workers=5) as executor: tasks.append( executor.submit( self._retrieve_from_working_memory, @@ -373,6 +393,22 @@ def _retrieve_paths( mode=mode, ) ) + if include_skill_memory: + tasks.append( + executor.submit( + self._retrieve_from_skill_memory, + query, + parsed_goal, + query_embedding, + skill_mem_top_k, + memory_type, + search_filter, + search_priority, + user_name, + id_filter, + mode=mode, + ) + ) results = [] for t in tasks: results.extend(t.result()) @@ -642,6 +678,58 @@ def _retrieve_from_tool_memory( ) return schema_reranked + trajectory_reranked + # --- Path E + @timed + def _retrieve_from_skill_memory( + self, + query, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter: dict | None = None, + search_priority: dict | None = None, + user_name: str | None = None, + id_filter: dict | None = None, + mode: str = "fast", + ): + """Retrieve and rerank from SkillMemory""" + if memory_type not in ["All", "SkillMemory"]: + logger.info(f"[PATH-E] '{query}' Skipped (memory_type does not match)") + return [] + + # chain of thinking + cot_embeddings = [] + if self.vec_cot: + queries = self._cot_query(query, mode=mode, context=parsed_goal.context) + if len(queries) > 1: + 
cot_embeddings = self.embedder.embed(queries) + cot_embeddings.extend(query_embedding) + else: + cot_embeddings = query_embedding + + items = self.graph_retriever.retrieve( + query=query, + parsed_goal=parsed_goal, + query_embedding=cot_embeddings, + top_k=top_k * 2, + memory_scope="SkillMemory", + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + id_filter=id_filter, + use_fast_graph=self.use_fast_graph, + ) + + return self.reranker.rerank( + query=query, + query_embedding=query_embedding[0], + graph_results=items, + top_k=top_k, + parsed_goal=parsed_goal, + search_filter=search_filter, + ) + @timed def _retrieve_simple( self, @@ -704,7 +792,14 @@ def _deduplicate_results(self, results): @timed def _sort_and_trim( - self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6 + self, + results, + top_k, + plugin=False, + search_tool_memory=False, + tool_mem_top_k=6, + include_skill_memory=False, + skill_mem_top_k=3, ): """Sort results by score and trim to top_k""" final_items = [] @@ -749,11 +844,35 @@ def _sort_and_trim( metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data), ) ) + + if include_skill_memory: + skill_results = [ + (item, score) + for item, score in results + if item.metadata.memory_type == "SkillMemory" + ] + sorted_skill_results = sorted(skill_results, key=lambda pair: pair[1], reverse=True)[ + :skill_mem_top_k + ] + for item, score in sorted_skill_results: + if plugin and round(score, 2) == 0.00: + continue + meta_data = item.metadata.model_dump() + meta_data["relativity"] = score + final_items.append( + TextualMemoryItem( + id=item.id, + memory=item.memory, + metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data), + ) + ) + # separate textual results results = [ (item, score) for item, score in results - if item.metadata.memory_type not in ["ToolSchemaMemory", "ToolTrajectoryMemory"] + if item.metadata.memory_type + in ["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] ] sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py index c1017bfae..0d2d460e9 100644 --- a/src/memos/multi_mem_cube/composite_cube.py +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -46,6 +46,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: "pref_mem": [], "pref_note": "", "tool_mem": [], + "skill_mem": [], } def _search_single_cube(view: SingleCubeView) -> dict[str, Any]: @@ -65,7 +66,7 @@ def _search_single_cube(view: SingleCubeView) -> dict[str, Any]: merged_results["para_mem"].extend(cube_result.get("para_mem", [])) merged_results["pref_mem"].extend(cube_result.get("pref_mem", [])) merged_results["tool_mem"].extend(cube_result.get("tool_mem", [])) - + merged_results["skill_mem"].extend(cube_result.get("skill_mem", [])) note = cube_result.get("pref_note") if note: if merged_results["pref_note"]: diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 426cf32be..2f7883548 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -2,6 +2,7 @@ import json import os +import time import traceback from dataclasses import dataclass @@ -121,6 +122,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: "pref_mem": [], "pref_note": "", "tool_mem": [], + "skill_mem": [], } # Determine search mode @@ -265,7 +267,7 @@ def _deep_search( info=info, ) 
formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in enhanced_memories ] return formatted_memories @@ -277,7 +279,7 @@ def _agentic_search( search_req.query, user_id=user_context.mem_cube_id ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in deepsearch_results ] return formatted_memories @@ -389,7 +391,7 @@ def _dedup_by_content(memories: list) -> list: enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories) ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in deduped_memories ] @@ -475,11 +477,13 @@ def _fast_search( plugin=plugin, search_tool_memory=search_req.search_tool_memory, tool_mem_top_k=search_req.tool_mem_top_k, + include_skill_memory=search_req.include_skill_memory, + skill_mem_top_k=search_req.skill_mem_top_k, dedup=search_req.dedup, ) formatted_memories = [ - format_memory_item(data, include_embedding=search_req.dedup == "sim") + format_memory_item(data, include_embedding=search_req.dedup in ("sim", "mmr")) for data in search_results ] @@ -790,7 +794,7 @@ def _process_text_mem( extract_mode, add_req.mode, ) - + init_time = time.time() # Extract memories memories_local = self.mem_reader.get_memory( [add_req.messages], @@ -804,6 +808,9 @@ def _process_text_mem( mode=extract_mode, user_name=user_context.mem_cube_id, ) + self.logger.info( + f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}" + ) flattened_local = [mm for m in memories_local for mm in m] # Explicitly set source_doc_id to metadata if present in info diff --git a/src/memos/templates/skill_mem_prompt.py b/src/memos/templates/skill_mem_prompt.py new file mode 100644 index 000000000..0bc0c1809 --- /dev/null +++ b/src/memos/templates/skill_mem_prompt.py @@ -0,0 +1,251 @@ +TASK_CHUNKING_PROMPT = """ +# Context (Conversation Records) +{{messages}} + +# Role +You are an expert in natural language processing (NLP) and dialogue logic analysis. You excel at organizing logical threads from complex long conversations and accurately extracting users' core intentions. + +# Task +Please analyze the provided conversation records, identify all independent "tasks" that the user has asked the AI to perform, and assign the corresponding dialogue message numbers to each task. + +**Note**: Tasks should be high-level and general, typically divided by theme or topic. For example: "Travel Planning", "PDF Operations", "Code Review", "Data Analysis", etc. Avoid being too specific or granular. + +# Rules & Constraints +1. **Task Independence**: If multiple unrelated topics are discussed in the conversation, identify them as different tasks. +2. **Non-continuous Processing**: Pay attention to identifying "jumping" conversations. For example, if the user made travel plans in messages 8-11, switched to consulting about weather in messages 12-22, and then returned to making travel plans in messages 23-24, be sure to assign both 8-11 and 23-24 to the task "Making travel plans". However, if messages are continuous and belong to the same task, do not split them apart. +3. **Filter Chit-chat**: Only extract tasks with clear goals, instructions, or knowledge-based discussions. 
Ignore meaningless greetings (such as "Hello", "Are you there?") or closing remarks unless they are part of the task context. +4. **Main Task and Subtasks**: Carefully identify whether subtasks serve a main task. If a subtask supports the main task (e.g., "checking weather" serves "travel planning"), do NOT separate it as an independent task. Instead, include all related conversations in the main task. Only split tasks when they are truly independent and unrelated. +5. **Output Format**: Please strictly follow the JSON format for output to facilitate my subsequent processing. +6. **Language Consistency**: The language used in the task_name field must match the language used in the conversation records. +7. **Generic Task Names**: Use generic, reusable task names, not specific descriptions. For example, use "Travel Planning" instead of "Planning a 5-day trip to Chengdu". + +```json +[ + { + "task_id": 1, + "task_name": "Generic task name (e.g., Travel Planning, Code Review, Data Analysis)", + "message_indices": [[0, 5],[16, 17]], # 0-5 and 16-17 are the message indices for this task + "reasoning": "Briefly explain why these messages are grouped together" + }, + ... +] +``` +""" + + +TASK_CHUNKING_PROMPT_ZH = """ +# 上下文(对话记录) +{{messages}} + +# 角色 +你是自然语言处理(NLP)和对话逻辑分析的专家。你擅长从复杂的长对话中整理逻辑线索,准确提取用户的核心意图。 + +# 任务 +请分析提供的对话记录,识别所有用户要求 AI 执行的独立"任务",并为每个任务分配相应的对话消息编号。 + +**注意**:任务应该是高层次和通用的,通常按主题或话题划分。例如:"旅行计划"、"PDF操作"、"代码审查"、"数据分析"等。避免过于具体或细化。 + +# 规则与约束 +1. **任务独立性**:如果对话中讨论了多个不相关的话题,请将它们识别为不同的任务。 +2. **非连续处理**:注意识别"跳跃式"对话。例如,如果用户在消息 8-11 中制定旅行计划,在消息 12-22 中切换到咨询天气,然后在消息 23-24 中返回到制定旅行计划,请务必将 8-11 和 23-24 都分配给"制定旅行计划"任务。但是,如果消息是连续的且属于同一任务,不能将其分开。 +3. **过滤闲聊**:仅提取具有明确目标、指令或基于知识的讨论的任务。忽略无意义的问候(例如"你好"、"在吗?")或结束语,除非它们是任务上下文的一部分。 +4. **主任务与子任务识别**:仔细识别子任务是否服务于主任务。如果子任务是为主任务服务的(例如"查天气"服务于"旅行规划"),不要将其作为独立任务分离出来,而是将所有相关对话都划分到主任务中。只有真正独立且无关联的任务才需要分开。 +5. **输出格式**:请严格遵循 JSON 格式输出,以便我后续处理。 +6. **语言一致性**:task_name 字段使用的语言必须与对话记录中使用的语言相匹配。 +7. **通用任务名称**:使用通用的、可复用的任务名称,而不是具体的描述。例如,使用"旅行规划"而不是"规划成都5日游"。 + +```json +[ + { + "task_id": 1, + "task_name": "通用任务名称(例如:旅行规划、代码审查、数据分析)", + "message_indices": [[0, 5],[16, 17]], # 0-5 和 16-17 是此任务的消息索引 + "reasoning": "简要解释为什么这些消息被分组在一起" + }, + ... +] +``` +""" + + +SKILL_MEMORY_EXTRACTION_PROMPT = """ +# Role +You are an expert in skill abstraction and knowledge extraction. You excel at distilling general, reusable methodologies from specific conversations. + +# Task +Extract a universal skill template from the conversation that can be applied to similar scenarios. Compare with existing skills to determine if this is new or an update. + +# Existing Skill Memories +{old_memories} + +# Conversation Messages +{messages} + +# Core Principles +1. **Generalization**: Extract abstract methodologies applicable across scenarios. Avoid specific details (e.g., "Travel Planning" not "Beijing Travel Planning"). +2. **Universality**: All fields except "example" must remain general and scenario-independent. +3. **Similarity Check**: If similar skill exists, set "update": true with "old_memory_id". Otherwise, set "update": false and leave "old_memory_id" empty. +4. **Language Consistency**: Match the conversation language. + +# Output Format +```json +{ + "name": "General skill name (e.g., 'Travel Itinerary Planning', 'Code Review Workflow')", + "description": "Universal description of what this skill accomplishes", + "procedure": "Generic step-by-step process: 1. Step one 2. 
Step two...", + "experience": ["General principle or lesson learned", "Best practice applicable to similar cases..."], + "preference": ["User's general preference pattern", "Preferred approach or constraint..."], + "examples": ["Complete formatted output example in markdown format showing the final deliverable structure, content can be abbreviated with '...' but should demonstrate the format and structure", "Another complete output template..."], + "tags": ["keyword1", "keyword2"], + "scripts": {"script_name.py": "# Python code here\nprint('Hello')", "another_script.py": "# More code\nimport os"}, + "others": {"Section Title": "Content here", "reference.md": "# Reference content for this skill"}, + "update": false, + "old_memory_id": "" +} +``` + +# Field Specifications +- **name**: Generic skill identifier without specific instances +- **description**: Universal purpose and applicability +- **procedure**: Abstract, reusable process steps without specific details. Should be generalizable to similar tasks +- **experience**: General lessons, principles, or insights +- **preference**: User's overarching preference patterns +- **tags**: Generic keywords for categorization +- **scripts**: Dictionary of scripts where key is the .py filename and value is the executable code snippet. Only applicable for code-related tasks (e.g., data processing, automation). Use null for non-coding tasks +- **others**: Supplementary information beyond standard fields or lengthy content unsuitable for other fields. Can be either: + - Simple key-value pairs where key is a title and value is content + - Separate markdown files where key is .md filename and value is the markdown content + - Use null if not applicable +- **examples**: Complete output templates showing the final deliverable format and structure. Should demonstrate how the task result looks when this skill is applied, including format, sections, and content organization. Content can be abbreviated but must show the complete structure. Use markdown format for better readability +- **update**: true if updating existing skill, false if new +- **old_memory_id**: ID of skill being updated, or empty string if new + +# Critical Guidelines +- Keep all fields general except "examples" +- "examples" should demonstrate complete final output format and structure with all necessary sections +- "others" contains supplementary context or extended information +- Return null if no extractable skill exists + +# Output Format +Output the JSON object only. +""" + + +SKILL_MEMORY_EXTRACTION_PROMPT_ZH = """ +# 角色 +你是技能抽象和知识提取的专家。你擅长从具体对话中提炼通用的、可复用的方法论。 + +# 任务 +从对话中提取可应用于类似场景的通用技能模板。对比现有技能判断是新建还是更新。 + +# 现有技能记忆 +{old_memories} + +# 对话消息 +{messages} + +# 核心原则 +1. **通用化**:提取可跨场景应用的抽象方法论。避免具体细节(如"旅行规划"而非"北京旅行规划")。 +2. **普适性**:除"examples"外,所有字段必须保持通用,与具体场景无关。 +3. **相似性检查**:如存在相似技能,设置"update": true 及"old_memory_id"。否则设置"update": false 并将"old_memory_id"留空。 +4. **语言一致性**:与对话语言保持一致。 + +# 输出格式 +```json +{ + "name": "通用技能名称(如:'旅行行程规划'、'代码审查流程')", + "description": "技能作用的通用描述", + "procedure": "通用的分步流程:1. 步骤一 2. 
步骤二...", + "experience": ["通用原则或经验教训", "可应用于类似场景的最佳实践..."], + "preference": ["用户的通用偏好模式", "偏好的方法或约束..."], + "examples": ["展示最终交付成果的完整格式范本(使用 markdown 格式), 内容可用'...'省略,但需展示完整格式和结构", "另一个完整输出模板..."], + "tags": ["关键词1", "关键词2"], + "scripts": {"script_name.py": "# Python 代码\nprint('Hello')", "another_script.py": "# 更多代码\nimport os"}, + "others": {"章节标题": "这里的内容", "reference.md": "# 此技能的参考内容"}, + "update": false, + "old_memory_id": "" +} +``` + +# 字段规范 +- **name**:通用技能标识符,不含具体实例 +- **description**:通用用途和适用范围 +- **procedure**:抽象的、可复用的流程步骤,不含具体细节。应当能够推广到类似任务 +- **experience**:通用经验、原则或见解 +- **preference**:用户的整体偏好模式 +- **tags**:通用分类关键词 +- **scripts**:脚本字典,其中 key 是 .py 文件名,value 是可执行代码片段。仅适用于代码相关任务(如数据处理、自动化脚本等)。非编程任务直接使用 null +- **others**:标准字段之外的补充信息或不适合放在其他字段的较长内容。可以是: + - 简单的键值对,其中 key 是标题,value 是内容 + - 独立的 markdown 文件,其中 key 是 .md 文件名,value 是 markdown 内容 + - 如果不适用则使用 null +- **examples**:展示最终任务成果的输出模板,包括格式、章节和内容组织结构。应展示应用此技能后任务结果的样子,包含所有必要的部分。内容可以省略但必须展示完整结构。使用 markdown 格式以提高可读性 +- **update**:更新现有技能为true,新建为false +- **old_memory_id**:被更新技能的ID,新建则为空字符串 + +# 关键指导 +- 除"examples"外保持所有字段通用 +- "examples"应展示完整的最终输出格式和结构,包含所有必要章节 +- "others"包含补充说明或扩展信息 +- 无法提取技能时返回null + +# 输出格式 +仅输出JSON对象。 +""" + + +TASK_QUERY_REWRITE_PROMPT = """ +# Role +You are an expert in understanding user intentions and task requirements. You excel at analyzing conversations and extracting the core task description. + +# Task +Based on the provided task type and conversation messages, analyze and determine what specific task the user wants to complete, then rewrite it into a clear, concise task query string. + +# Task Type +{task_type} + +# Conversation Messages +{messages} + +# Requirements +1. Analyze the conversation content to understand the user's core intention +2. Consider the task type as context +3. Extract and summarize the key task objective +4. Output a clear, concise task description string (one sentence) +5. Use the same language as the conversation +6. Focus on WHAT needs to be done, not HOW to do it +7. Do not include any explanations, just output the rewritten task string directly + +# Output +Please output only the rewritten task query string, without any additional formatting or explanation. +""" + + +TASK_QUERY_REWRITE_PROMPT_ZH = """ +# 角色 +你是理解用户意图和任务需求的专家。你擅长分析对话并提取核心任务描述。 + +# 任务 +基于提供的任务类型和对话消息,分析并确定用户想要完成的具体任务,然后将其重写为清晰、简洁的任务查询字符串。 + +# 任务类型 +{task_type} + +# 对话消息 +{messages} + +# 要求 +1. 分析对话内容以理解用户的核心意图 +2. 将任务类型作为上下文考虑 +3. 提取并总结关键任务目标 +4. 输出清晰、简洁的任务描述字符串(一句话) +5. 使用与对话相同的语言 +6. 关注需要做什么(WHAT),而不是如何做(HOW) +7. 不要包含任何解释,直接输出重写后的任务字符串 + +# 输出 +请仅输出重写后的任务查询字符串,不要添加任何额外的格式或解释。 +""" + +SKILLS_AUTHORING_PROMPT = """ +""" diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/extras/nli_model/__init__.py b/tests/extras/nli_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/extras/nli_model/test_client_integration.py b/tests/extras/nli_model/test_client_integration.py new file mode 100644 index 000000000..5beff14a0 --- /dev/null +++ b/tests/extras/nli_model/test_client_integration.py @@ -0,0 +1,129 @@ +import threading +import time +import unittest + +from unittest.mock import MagicMock, patch + +import requests +import uvicorn + +from memos.extras.nli_model.client import NLIClient +from memos.extras.nli_model.server.serve import app +from memos.extras.nli_model.types import NLIResult + + +# We need to mock the NLIHandler to avoid loading the heavy model +# but we want to run the real FastAPI server. 
+class TestNLIClientIntegration(unittest.TestCase): + server_thread = None + stop_server = False + port = 32533 # Use a different port for testing + + @classmethod + def setUpClass(cls): + # Patch the lifespan to inject a mock handler instead of real NLIHandler + cls.mock_handler = MagicMock() + cls.mock_handler.compare_one_to_many.return_value = [ + NLIResult.DUPLICATE, + NLIResult.CONTRADICTION, + ] + + # We need to patch the module where lifespan is defined/used or modify the global variable + # Since 'app' is already imported, we can patch the global nli_handler in serve.py + # But lifespan sets it on startup. + + # Let's patch NLIHandler class in serve.py so when lifespan instantiates it, it gets our mock + cls.handler_patcher = patch("memos.extras.nli_model.server.serve.NLIHandler") + cls.MockHandlerClass = cls.handler_patcher.start() + cls.MockHandlerClass.return_value = cls.mock_handler + + # Start server in a thread + def run_server(): + # Disable logs for uvicorn to keep test output clean + config = uvicorn.Config(app, host="127.0.0.1", port=cls.port, log_level="error") + cls.server = uvicorn.Server(config) + cls.server.run() + + cls.server_thread = threading.Thread(target=run_server, daemon=True) + cls.server_thread.start() + + # Wait for server to be ready + cls._wait_for_server() + + @classmethod + def tearDownClass(cls): + # Stop the server + if hasattr(cls, "server"): + cls.server.should_exit = True + if cls.server_thread: + cls.server_thread.join(timeout=5) + + cls.handler_patcher.stop() + + @classmethod + def _wait_for_server(cls): + url = f"http://127.0.0.1:{cls.port}/docs" + retries = 20 + for _ in range(retries): + try: + response = requests.get(url) + if response.status_code == 200: + return + except requests.ConnectionError: + pass + time.sleep(0.1) + raise RuntimeError("Server failed to start") + + def setUp(self): + self.client = NLIClient(base_url=f"http://127.0.0.1:{self.port}") + # Reset mock calls before each test + self.mock_handler.reset_mock() + # Ensure default behavior + self.mock_handler.compare_one_to_many.return_value = [ + NLIResult.DUPLICATE, + NLIResult.CONTRADICTION, + ] + + def test_real_server_compare_one_to_many(self): + source = "I like apples." + targets = ["I love fruit.", "I hate apples."] + + results = self.client.compare_one_to_many(source, targets) + + # Verify result + self.assertEqual(len(results), 2) + self.assertEqual(results[0], NLIResult.DUPLICATE) + self.assertEqual(results[1], NLIResult.CONTRADICTION) + + # Verify server received the request + self.mock_handler.compare_one_to_many.assert_called_once() + args, _ = self.mock_handler.compare_one_to_many.call_args + self.assertEqual(args[0], source) + self.assertEqual(args[1], targets) + + def test_real_server_empty_targets(self): + source = "I like apples." + targets = [] + + results = self.client.compare_one_to_many(source, targets) + + self.assertEqual(results, []) + # Should not call handler because client handles empty list + self.mock_handler.compare_one_to_many.assert_not_called() + + def test_real_server_handler_error(self): + # Simulate handler error + self.mock_handler.compare_one_to_many.side_effect = ValueError("Something went wrong") + + source = "I like apples." 
+ targets = ["I love fruit."] + + # Client should catch 500 and return UNRELATED + results = self.client.compare_one_to_many(source, targets) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0], NLIResult.UNRELATED) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/memories/textual/test_pre_update_retriever.py b/tests/memories/textual/test_pre_update_retriever.py new file mode 100644 index 000000000..6bed90abb --- /dev/null +++ b/tests/memories/textual/test_pre_update_retriever.py @@ -0,0 +1,150 @@ +import unittest +import uuid + +from dotenv import load_dotenv + +from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever + + +# Load environment variables +load_dotenv() + + +class TestPreUpdateRecaller(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Initialize graph_db and embedder using factories + # We assume environment variables are set for these to work + try: + cls.graph_db_config = build_graph_db_config() + cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config) + + cls.embedder_config = build_embedder_config() + cls.embedder = EmbedderFactory.from_config(cls.embedder_config) + except Exception as e: + raise unittest.SkipTest( + f"Skipping test because initialization failed (likely missing env vars): {e}" + ) from e + + cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder) + + # Use a unique user name to isolate tests + cls.user_name = "test_pre_update_recaller_user_" + str(uuid.uuid4())[:8] + + def setUp(self): + # Add some data to the db + self.added_ids = [] + + # Create a memory item to add + self.memory_text = "The user likes to eat apples." + self.embedding = self.embedder.embed([self.memory_text])[0] + + # We use dictionary for metadata to simulate what might be passed or stored + # But wait, add_node expects metadata as a dict usually. + metadata = { + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": self.embedding, + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", + "tags": ["food", "fruit"], + "key": "user_preference", + "sources": [], + } + + node_id = str(uuid.uuid4()) + self.graph_db.add_node(node_id, self.memory_text, metadata, user_name=self.user_name) + self.added_ids.append(node_id) + + # Add another one + self.memory_text_2 = "The user has a dog named Rex." 
+ self.embedding_2 = self.embedder.embed([self.memory_text_2])[0] + metadata_2 = { + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": self.embedding_2, + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", + "tags": ["pet", "dog"], + "key": "user_pet", + "sources": [], + } + node_id_2 = str(uuid.uuid4()) + self.graph_db.add_node(node_id_2, self.memory_text_2, metadata_2, user_name=self.user_name) + self.added_ids.append(node_id_2) + + def tearDown(self): + """Clean up test data.""" + for node_id in self.added_ids: + try: + self.graph_db.delete_node(node_id, user_name=self.user_name) + except Exception as e: + print(f"Error deleting node {node_id}: {e}") + + def test_recall_vector_search(self): + """Test recalling using vector search (implicit in recall method).""" + # "I like apples" -> perspective adjustment should match "The user likes to eat apples" + query_text = "I like apples" + + # Create metadata with source to trigger perspective adjustment + # role="user" means "I" -> "User" + source = SourceMessage(role="user", lang="en") + metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory") + + item = TextualMemoryItem(memory=query_text, metadata=metadata) + + # The recall method does both vector and keyword search + results = self.recaller.retrieve(item, self.user_name, top_k=5) + + # Verify we got results + self.assertTrue(len(results) > 0, "Should return at least one result") + found_texts = [r.memory for r in results] + + # Check if the relevant memory is found + # "The user likes to eat apples." should be found. + # We check for "apples" to be safe + self.assertTrue( + any("apples" in t for t in found_texts), + f"Expected 'apples' in results, got: {found_texts}", + ) + + def test_recall_keyword_search(self): + """Test recalling where keyword search might be more relevant.""" + # "Rex" is a specific name + query_text = "What is the name of my dog?" + source = SourceMessage(role="user", lang="en") + metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory") + + item = TextualMemoryItem(memory=query_text, metadata=metadata) + + results = self.recaller.retrieve(item, self.user_name, top_k=5) + + found_texts = [r.memory for r in results] + self.assertTrue( + any("Rex" in t for t in found_texts), f"Expected 'Rex' in results, got: {found_texts}" + ) + + def test_perspective_adjustment(self): + """Unit test for the _adjust_perspective method specifically.""" + text = "I went to the store myself." 
+ adjusted = self.recaller._adjust_perspective(text, "user", "en") + # I -> User, myself -> User himself + self.assertIn("User", adjusted) + self.assertIn("User himself", adjusted) + + text_zh = "我喜欢吃苹果" + adjusted_zh = self.recaller._adjust_perspective(text_zh, "user", "zh") + # 我 -> 用户 + self.assertIn("用户", adjusted_zh) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/memories/textual/test_pre_update_retriever_latency.py b/tests/memories/textual/test_pre_update_retriever_latency.py new file mode 100644 index 000000000..f4a359de9 --- /dev/null +++ b/tests/memories/textual/test_pre_update_retriever_latency.py @@ -0,0 +1,183 @@ +import time +import unittest +import uuid + +import numpy as np + +from dotenv import load_dotenv + +from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever + + +# Load environment variables +load_dotenv() + + +class TestPreUpdateRecallerLatency(unittest.TestCase): + """ + Performance and latency tests for PreUpdateRetriever. + These tests are designed to measure latency and might take longer to run. + """ + + @classmethod + def setUpClass(cls): + # Initialize graph_db and embedder using factories + try: + cls.graph_db_config = build_graph_db_config() + cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config) + + cls.embedder_config = build_embedder_config() + cls.embedder = EmbedderFactory.from_config(cls.embedder_config) + except Exception as e: + raise unittest.SkipTest( + f"Skipping test because initialization failed (likely missing env vars): {e}" + ) from e + + cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder) + + # Use a unique user name to isolate tests + cls.user_name = "test_pre_update_recaller_latency_user_" + str(uuid.uuid4())[:8] + + def setUp(self): + # Add a substantial amount of data for latency testing + self.added_ids = [] + self.num_items = 20 + + print(f"\nPopulating database with {self.num_items} items for latency test...") + for i in range(self.num_items): + text = f"This is memory item number {i}. The user might enjoy topic {i % 5}." 
+ embedding = self.embedder.embed([text])[0] + metadata = { + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": embedding, + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", + "tags": [f"tag_{i}"], + "key": f"key_{i}", + "sources": [], + } + node_id = str(uuid.uuid4()) + self.graph_db.add_node(node_id, text, metadata, user_name=self.user_name) + self.added_ids.append(node_id) + + def tearDown(self): + """Clean up test data.""" + print("Cleaning up test data...") + for node_id in self.added_ids: + try: + self.graph_db.delete_node(node_id, user_name=self.user_name) + except Exception as e: + print(f"Error deleting node {node_id}: {e}") + + def measure_network_rtt(self, trials=10): + """Measure average network round-trip time.""" + print(f"Measuring Network RTT (using {trials} probes)...") + latencies = [] + + # Try to use raw driver for minimal overhead if available (Neo4j specific) + if hasattr(self.graph_db, "driver") and hasattr(self.graph_db, "db_name"): + print("Using Neo4j driver for direct ping...") + try: + with self.graph_db.driver.session(database=self.graph_db.db_name) as session: + # Warmup + session.run("RETURN 1").single() + + for _ in range(trials): + start = time.time() + session.run("RETURN 1").single() + latencies.append((time.time() - start) * 1000) + except Exception as e: + print(f"Direct driver ping failed: {e}. Falling back to get_node.") + latencies = [] + + if not latencies: + # Fallback to get_node with non-existent ID + print("Using get_node for ping...") + for _ in range(trials): + probe_id = str(uuid.uuid4()) + start = time.time() + self.graph_db.get_node(probe_id, user_name=self.user_name) + latencies.append((time.time() - start) * 1000) + + avg_rtt = np.mean(latencies) + print(f"Average Network RTT: {avg_rtt:.2f} ms") + return avg_rtt + + def test_recall_latency(self): + """Test and report recall latency statistics.""" + avg_rtt = self.measure_network_rtt() + + queries = [ + "I enjoy topic 1", + "What about topic 3?", + "Do I have any preferences?", + "Tell me about memory item 5", + ] + + latencies = [] + + # Warmup + print("Warming up...") + warmup_item = TextualMemoryItem( + memory="warmup query", + metadata=TreeNodeTextualMemoryMetadata( + sources=[SourceMessage(role="user", lang="en")], memory_type="WorkingMemory" + ), + ) + self.recaller.retrieve(warmup_item, self.user_name, top_k=5) + + print(f"Running {len(queries)} queries...") + for q in queries: + # Pre-calculate embedding to exclude from latency measurement + q_embedding = self.embedder.embed([q])[0] + + item = TextualMemoryItem( + memory=q, + metadata=TreeNodeTextualMemoryMetadata( + sources=[SourceMessage(role="user", lang="en")], + memory_type="WorkingMemory", + embedding=q_embedding, + ), + ) + + start_time = time.time() + results = self.recaller.retrieve(item, self.user_name, top_k=5) + end_time = time.time() + + duration_ms = (end_time - start_time) * 1000 + latencies.append(duration_ms) + print(f"Query: '{q}' -> Found {len(results)} results in {duration_ms:.2f} ms") + + # Assert that we actually found results (sanity check) + if "preferences" not in q: # The preferences query might return 0 + self.assertTrue(len(results) > 0, f"Expected results for query: {q}") + + # Report Results + avg_latency = np.mean(latencies) + p95_latency = np.percentile(latencies, 95) + min_latency = np.min(latencies) + max_latency = np.max(latencies) + internal_processing = avg_latency - avg_rtt + + print("\n--- Latency Results ---") + print(f"Average Network RTT: 
{avg_rtt:.2f} ms") + print(f"Average Total Latency: {avg_latency:.2f} ms") + print(f"Estimated Internal Processing: {internal_processing:.2f} ms") + print(f"95th Percentile: {p95_latency:.2f} ms") + print(f"Min Latency: {min_latency:.2f} ms") + print(f"Max Latency: {max_latency:.2f} ms") + + self.assertLess(internal_processing, 200, "Internal processing should be under 200ms") + + +if __name__ == "__main__": + unittest.main()