diff --git a/README.md b/README.md index 8569fc7..2e62cfb 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,24 @@ Inspired by papers like [MemGPT](https://memgpt.ai/) and distilled from our own extracts memories from chat interactions and persists them to a database. This information can later be read or queried semantically to provide personalized context when your bot is responding to a particular user. +The memory graph handles thread process deduplication and supports continuous updates to a single "memory schema" as well as "event-based" memories that can be queried semantically. + +![Memory Diagram](./img/memory_graph.png) + +#### Project Structure + +```bash +├── langgraph.json # LangGraph Cloud Configuration +├── memory_service +│   ├── __init__.py +│   └── graph.py # Define the memory service +├── poetry.lock +├── pyproject.toml # Project dependencies +└── tests # Add testing + evaluation logic + └── evals + └── test_memories.py +``` + ## Quickstart This quick start will get your memory service deployed on [LangGraph Cloud](https://langchain-ai.github.io/langgraph/cloud/). Once created, you can interact with it from any API. @@ -52,12 +70,21 @@ Assuming you've followed the steps above, in just a couple of minutes, you shoul Now let's try it out. -#### How to connect to the memory service +## How to connect to the memory service Check out the [example notebook](./example.ipynb) to show how to connect your chat bot (in this case a second graph) to your new memory service. This chat bot reads from the same memory DB as your memory service to easily query from "recall memory". +Connecting to this type of memory service typically follows an interaction pattern similar to the one outlined below: + +![Interaction Pattern](./img/memory_interactions.png) + +A typical user-facing application you'd build to connect with this service would have 3 or more nodes. The first node queries the DB for useful memories. 
The second node, which contains the LLM, generates the response. The third node posts the new messages to the service. + +The service waits for a predetermined interval before it considers the thread "complete". If the user queries a second time within that interval, the memory run is [rolled back](https://langchain-ai.github.io/langgraph/cloud/how-tos/cloud_examples/rollback_concurrent/?h=roll) to avoid duplicate processing of a thread. + + ## How to evaluate Memory management can be challenging to get right. To make sure your schemas suit your applications' needs, we recommend starting from an evaluation set, diff --git a/img/memory_graph.png b/img/memory_graph.png new file mode 100644 index 0000000..2fbd8aa Binary files /dev/null and b/img/memory_graph.png differ diff --git a/img/memory_interactions.png b/img/memory_interactions.png new file mode 100644 index 0000000..f46c643 Binary files /dev/null and b/img/memory_interactions.png differ diff --git a/memory_service/__init__.py b/memory_service/__init__.py index dc1fbc3..77659f7 100644 --- a/memory_service/__init__.py +++ b/memory_service/__init__.py @@ -1,5 +1,5 @@ """Simple example memory extraction service.""" -from memory_service.graph import extraction_graph +from memory_service.graph import memgraph -__all__ = ["extraction_graph"] +__all__ = ["memgraph"] diff --git a/memory_service/_schemas.py b/memory_service/_schemas.py index f046109..972098b 100644 --- a/memory_service/_schemas.py +++ b/memory_service/_schemas.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional +from langchain_core.messages import AnyMessage from langchain_core.pydantic_v1 import BaseModel from langgraph.graph import add_messages from typing_extensions import Annotated, Literal, TypedDict @@ -55,7 +56,7 @@ class GraphConfig(TypedDict, total=False): class State(TypedDict): - messages: 
Annotated[List[AnyMessage], add_messages] """The messages in the conversation.""" eager: bool diff --git a/memory_service/graph.py b/memory_service/graph.py index 7bda132..9a55ab9 100644 --- a/memory_service/graph.py +++ b/memory_service/graph.py @@ -20,6 +20,9 @@ from memory_service import _settings as settings from memory_service import _utils as utils logger = logging.getLogger("memory") +# Handle patch memory, where we update a single document in the database. +# If the document doesn't exist, the LLM will generate a new one. +# Otherwise, it will generate JSON patches to update the existing document. async def fetch_patched_state( @@ -99,7 +102,29 @@ async def upsert_patched_state( return {"user_state": {}} -async def extract_insertion_memories( +patch_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig) +patch_builder.add_node(fetch_patched_state) +patch_builder.add_node(extract_patch_memories) +patch_builder.add_node(upsert_patched_state) +patch_builder.add_edge(START, "fetch_patched_state") +patch_builder.add_edge("fetch_patched_state", "extract_patch_memories") + + +def should_commit_patch( + state: schemas.SingleExtractorState, config: RunnableConfig +) -> Literal["upsert_patched_state", "__end__"]: + """Whether there are things extracted to commit to the DB.""" + return "upsert_patched_state" if state["responses"] else END + + +patch_builder.add_conditional_edges("extract_patch_memories", should_commit_patch) +patch_graph = patch_builder.compile() + +# Handle semantic memory, where we insert each memory event +# as a new document in the database. 
+ + +async def extract_semantic_memories( state: schemas.SingleExtractorState, config: RunnableConfig ) -> dict: """Extract embeddable "events".""" @@ -155,50 +180,24 @@ async def insert_memories( return {"user_state": {}} -def route_inbound( +semantic_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig) +# Lots of quality improvements can be made here, such as: +# - Fetch similar memories and prompt model to combine or extrapolate +# - Adding advanced indexing by the memory schema (like importance, relevance, etc.) +semantic_builder.add_node(extract_semantic_memories) +semantic_builder.add_node(insert_memories) +semantic_builder.add_edge(START, "extract_semantic_memories") + + +def should_insert( state: schemas.SingleExtractorState, config: RunnableConfig -) -> Literal["fetch_patched_state", "extract_insertion_memories"]: - configurable = utils.ensure_configurable(config["configurable"]) - update_mode = configurable["schemas"][state["function_name"]]["update_mode"] - match update_mode: - case "patch": - return "fetch_patched_state" - case "insert": - return "extract_insertion_memories" - case _: - raise ValueError(f"Unknown update mode: {update_mode}") - - -extraction_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig) -extraction_builder.add_node(fetch_patched_state) -extraction_builder.add_node(extract_patch_memories) -extraction_builder.add_node(upsert_patched_state) -extraction_builder.add_node(extract_insertion_memories) -extraction_builder.add_node(insert_memories) - -extraction_builder.add_conditional_edges(START, route_inbound) -extraction_builder.add_edge("fetch_patched_state", "extract_patch_memories") - - -def should_commit( - state: schemas.SingleExtractorState, config: RunnableConfig -) -> Literal["insert_memories", "upsert_patched_state", "__end__"]: +) -> Literal["insert_memories", "__end__"]: """Whether there are things extracted to commit to the DB.""" - if not state["responses"]: - return END - configurable 
= utils.ensure_configurable(config["configurable"]) - function = configurable["schemas"][state["function_name"]] - commit_node_name = ( - "insert_memories" - if function["update_mode"] == "insert" - else "upsert_patched_state" - ) - return commit_node_name + return "insert_memories" if state["responses"] else END -extraction_builder.add_conditional_edges("extract_patch_memories", should_commit) -extraction_builder.add_conditional_edges("extract_insertion_memories", should_commit) -extraction_graph = extraction_builder.compile() +semantic_builder.add_conditional_edges("extract_semantic_memories", should_insert) +semantic_graph = semantic_builder.compile() # This graph is public facing. It receives conversations and distibutes them to the @@ -229,7 +228,8 @@ async def schedule(state: schemas.State, config: RunnableConfig) -> dict: # Create the graph + all nodes builder = StateGraph(schemas.State, schemas.GraphConfig) builder.add_node(schedule) -builder.add_node("extract", extraction_graph) +builder.add_node("handle_patch_memory", patch_graph) +builder.add_node("handle_semantic_memory", semantic_graph) # Add edges builder.add_edge(START, "schedule") @@ -241,9 +241,18 @@ def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]: These will be executed in parallel. 
""" configuration = utils.ensure_configurable(config["configurable"]) - return [ - Send("extract", {**state, "function_name": k}) for k in configuration["schemas"] - ] + sends = [] + for k, v in configuration["schemas"].items(): + update_mode = v["update_mode"] + match update_mode: + case "patch": + target = "handle_patch_memory" + case "insert": + target = "handle_semantic_memory" + case _: + raise ValueError(f"Unknown update mode: {update_mode}") + sends.append(Send(target, {**state, "function_name": k})) + return sends builder.add_conditional_edges("schedule", scatter_schemas) diff --git a/poetry.lock b/poetry.lock index b4411e3..ee4d2d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -507,13 +507,13 @@ files = [ [[package]] name = "fsspec" -version = "2024.6.0" +version = "2024.6.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"}, - {file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"}, + {file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"}, + {file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"}, ] [package.extras] @@ -862,13 +862,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-anthropic" -version = "0.1.16" +version = "0.1.17" description = "An integration package connecting AnthropicMessages and LangChain" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_anthropic-0.1.16-py3-none-any.whl", hash = "sha256:0d3f66b7ffb2d4ef739ef87c4b096f0dbbf0ce12f988205d1fcaa1da2c9d09fa"}, - {file = "langchain_anthropic-0.1.16.tar.gz", hash = "sha256:28187dfb19389772e0abba98eeb0210ed3d99ab0a73f2e6707aa46c9bbbbe407"}, + {file = 
"langchain_anthropic-0.1.17-py3-none-any.whl", hash = "sha256:e139602594deebd5deb1531434823889390a58c3b99df911f9ced3ec7cc0746e"}, + {file = "langchain_anthropic-0.1.17.tar.gz", hash = "sha256:bd8b16d07c6b78228eaf76ca6aae28e29a186e0048aab79effd35ccf819d1851"}, ] [package.dependencies] @@ -900,36 +900,36 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-fireworks" -version = "0.1.3" +version = "0.1.4" description = "An integration package connecting Fireworks and LangChain" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_fireworks-0.1.3-py3-none-any.whl", hash = "sha256:44e83491a3a9d0a9bd6b761e72b5553d85321cc878e36b68d99d432fba39db33"}, - {file = "langchain_fireworks-0.1.3.tar.gz", hash = "sha256:667daea544512a0c2598d8f0f77b4176476bdc51bdf03e91f865922d51d39a85"}, + {file = "langchain_fireworks-0.1.4-py3-none-any.whl", hash = "sha256:aab2edddea490a5931c556bfc0f2b8a07a18f57a25bf8c77b961b5e37001e75a"}, + {file = "langchain_fireworks-0.1.4.tar.gz", hash = "sha256:550bec787b2660014d421e16ce6c4f51851067db14eb299f5757231517eb9d77"}, ] [package.dependencies] aiohttp = ">=3.9.1,<4.0.0" fireworks-ai = ">=0.13.0" -langchain-core = ">=0.1.52,<0.3" +langchain-core = ">=0.2.2,<0.3" openai = ">=1.10.0,<2.0.0" requests = ">=2,<3" [[package]] name = "langchain-openai" -version = "0.1.10" +version = "0.1.13" description = "An integration package connecting OpenAI and LangChain" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_openai-0.1.10-py3-none-any.whl", hash = "sha256:62eb000980eb45e4f16c88acdbaeccf3d59266554b0dd3ce6bebea1bbe8143dd"}, - {file = "langchain_openai-0.1.10.tar.gz", hash = "sha256:30f881f8ccaec28c054759837c41fd2a2264fcc5564728ce12e1715891a9ce3c"}, + {file = "langchain_openai-0.1.13-py3-none-any.whl", hash = "sha256:4344b6c5c67088a28eed80ba763157fdd1d690cee679966a021b42f305dbf7b5"}, + {file = "langchain_openai-0.1.13.tar.gz", hash = 
"sha256:03318669bcb3238f7d1bb043329f91d150ca09246f1faf569ef299f535405c71"}, ] [package.dependencies] langchain-core = ">=0.2.2,<0.3" -openai = ">=1.26.0,<2.0.0" +openai = ">=1.32.0,<2.0.0" tiktoken = ">=0.7,<1" [[package]] @@ -964,13 +964,13 @@ langchain-core = ">=0.2.10,<0.3.0" [[package]] name = "langgraph" -version = "0.1.1" +version = "0.1.4" description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = "<4.0,>=3.9.0" files = [ - {file = "langgraph-0.1.1-py3-none-any.whl", hash = "sha256:6d798072625fd23ff155d40ee823b4e5eb00731ad8a64e2551fd6ae0cb53aec4"}, - {file = "langgraph-0.1.1.tar.gz", hash = "sha256:cea9831c334f36bae9caa66637422e37f76f069b47233ceb253be6cfb3ecb35e"}, + {file = "langgraph-0.1.4-py3-none-any.whl", hash = "sha256:7473fd87fc4e9f796a300be810cf6fbdcf91ca5355080c78fc502153d1b2752f"}, + {file = "langgraph-0.1.4.tar.gz", hash = "sha256:cd5c301b70d49fe4de88c4f11c9636070ff943af887615e7662df6678c86eb06"}, ] [package.dependencies] @@ -978,13 +978,13 @@ langchain-core = ">=0.2,<0.3" [[package]] name = "langgraph-cli" -version = "0.1.44" +version = "0.1.47" description = "CLI for interacting with LangGraph API" optional = false python-versions = "<4.0.0,>=3.9.0" files = [ - {file = "langgraph_cli-0.1.44-py3-none-any.whl", hash = "sha256:721968ab9d9d74ba824e93f4bb08de094795888b7ccea3792962f051a9cdb419"}, - {file = "langgraph_cli-0.1.44.tar.gz", hash = "sha256:a31f4a71abd4a3c39f886811c69ac433a20bbd7eeca6bd47d8f8ec49afbfcebb"}, + {file = "langgraph_cli-0.1.47-py3-none-any.whl", hash = "sha256:3767941dbac4128b1d87fa2dc24afde4015904890ab9682e6b956aae51c0baea"}, + {file = "langgraph_cli-0.1.47.tar.gz", hash = "sha256:292eb99334efd53c29f40cef3e0c55d198da6e99ee14d0987210997be6e56d32"}, ] [package.dependencies] @@ -1229,13 +1229,13 @@ files = [ [[package]] name = "openai" -version = "1.35.5" +version = "1.35.7" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" 
files = [ - {file = "openai-1.35.5-py3-none-any.whl", hash = "sha256:28d92503c6e4b6a32a89277b36693023ef41f60922a4b5c8c621e8c5697ae3a6"}, - {file = "openai-1.35.5.tar.gz", hash = "sha256:67ef289ae22d350cbf9381d83ae82c4e3596d71b7ad1cc886143554ee12fe0c9"}, + {file = "openai-1.35.7-py3-none-any.whl", hash = "sha256:3d1e0b0aac9b0db69a972d36dc7efa7563f8e8d65550b27a48f2a0c2ec207e80"}, + {file = "openai-1.35.7.tar.gz", hash = "sha256:009bfa1504c9c7ef64d87be55936d142325656bbc6d98c68b669d6472e4beb09"}, ] [package.dependencies] diff --git a/tests/evals/test_memories.py b/tests/evals/test_memories.py index dd7d50b..27cca64 100644 --- a/tests/evals/test_memories.py +++ b/tests/evals/test_memories.py @@ -1,9 +1,10 @@ +import json from typing import List from unittest.mock import MagicMock, patch import pytest from langchain_core.pydantic_v1 import BaseModel, Field -from langsmith import expect, test +from langsmith import expect, get_current_run_tree, test from memory_service._constants import PATCH_PATH from memory_service._schemas import GraphConfig, MemoryConfig @@ -31,7 +32,7 @@ def core_memory_func() -> MemoryConfig: } -@test(output_keys="num_mems_expected") +@test(output_keys=["num_mems_expected"]) @pytest.mark.parametrize( "messages, existing, num_mems_expected", [ @@ -88,9 +89,12 @@ async def test_patch_memory( index.upsert.assert_called_once() # Get named call args vectors = index.upsert.call_args.kwargs["vectors"] + rt = get_current_run_tree() + rt.outputs = {"upserted": [v["metadata"]["content"] for v in vectors]} assert len(vectors) == 1 # Check if the memory was added - memories = vectors[0]["metadata"]["content"]["memories"] + mem = vectors[0]["metadata"]["content"] + memories = json.loads(mem)["memories"] expect(len(memories)).to_equal(num_mems_expected) @@ -119,7 +123,7 @@ def memorable_event_func() -> MemoryConfig: } -@test(output_keys="num_events_expected") +@test(output_keys=["num_events_expected"]) @pytest.mark.parametrize( "messages, num_events_expected", [