Update indexing for evals

This commit is contained in:
William Fu-Hinthorn
2024-06-30 17:28:16 -07:00
parent 74f1edda7a
commit eac65c79a8
8 changed files with 118 additions and 77 deletions
+28 -1
View File
@@ -6,6 +6,24 @@ Inspired by papers like [MemGPT](https://memgpt.ai/) and distilled from our own
extracts memories from chat interactions and persists them to a database. This information can later be read or queried semantically
to provide personalized context when your bot is responding to a particular user.
The memory graph handles thread process deduplication and supports continuous updates to a single "memory schema" as well as "event-based" memories that can be queried semantically.
![Memory Diagram](./img/memory_graph.png)
#### Project Structure
```bash
├── langgraph.json # LangGraph Cloud Configuration
├── memory_service
│   ├── __init__.py
│   └── graph.py # Define the memory service
├── poetry.lock
├── pyproject.toml # Project dependencies
└── tests # Add testing + evaluation logic
└── evals
└── test_memories.py
```
## Quickstart
This quick start will get your memory service deployed on [LangGraph Cloud](https://langchain-ai.github.io/langgraph/cloud/). Once created, you can interact with it from any API.
@@ -52,12 +70,21 @@ Assuming you've followed the steps above, in just a couple of minutes, you shoul
Now let's try it out.
#### How to connect to the memory service
## How to connect to the memory service
Check out the [example notebook](./example.ipynb) to show how to connect your chat bot (in this case a second graph) to your new memory service.
This chat bot reads from the same memory DB as your memory service to easily query from "recall memory".
Connecting to this type of memory service typically follows an interaction pattern similar to the one outlined below:
![Interaction Pattern](./img/memory_interactions.png)
A typical user-facing application you'd build to connect with this service would have 3 or more nodes. The first node queries the DB for useful memories. The second node, which contains the LLM, generates the response. The third node posts the new messages to the service.
The service waits for a pre-determined interval before it considers the thread "complete". If the user queries a second time within that interval, the memory run is [rolled-back](https://langchain-ai.github.io/langgraph/cloud/how-tos/cloud_examples/rollback_concurrent/?h=roll) to avoid duplicate processing of a thread.
## How to evaluate
Memory management can be challenging to get right. To make sure your schemas suit your applications' needs, we recommend starting from an evaluation set,
Binary file not shown.

After

Width:  |  Height:  |  Size: 245 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 757 KiB

+2 -2
View File
@@ -1,5 +1,5 @@
"""Simple example memory extraction service."""
from memory_service.graph import extraction_graph
from memory_service.graph import memgraph
__all__ = ["extraction_graph"]
__all__ = ["memgraph"]
+3 -2
View File
@@ -1,7 +1,8 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from langchain_core.messages import AnyMessage
from langchain_core.pydantic_v1 import BaseModel
from langgraph.graph import add_messages
from typing_extensions import Annotated, Literal, TypedDict
@@ -55,7 +56,7 @@ class GraphConfig(TypedDict, total=False):
class State(TypedDict):
messages: Annotated[list, add_messages]
messages: Annotated[List[AnyMessage], add_messages]
"""The messages in the conversation."""
eager: bool
+54 -45
View File
@@ -20,6 +20,9 @@ from memory_service import _settings as settings
from memory_service import _utils as utils
logger = logging.getLogger("memory")
# Handle patch memory, where we update a single document in the database.
# If the document doesn't exist, the LLM will generate a new one.
# Otherwise, it will generate JSON patches to update the existing document.
async def fetch_patched_state(
@@ -99,7 +102,29 @@ async def upsert_patched_state(
return {"user_state": {}}
async def extract_insertion_memories(
patch_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
patch_builder.add_node(fetch_patched_state)
patch_builder.add_node(extract_patch_memories)
patch_builder.add_node(upsert_patched_state)
patch_builder.add_edge(START, "fetch_patched_state")
patch_builder.add_edge("fetch_patched_state", "extract_patch_memories")
def should_commit_patch(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["upsert_patched_state", "__end__"]:
"""Whether there are things extracted to commit to the DB."""
return "upsert_patched_state" if state["responses"] else END
patch_builder.add_conditional_edges("extract_patch_memories", should_commit_patch)
patch_graph = patch_builder.compile()
# Handle semantic memory, where we insert each memory event
# as a new document in the database.
async def extract_semantic_memories(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
"""Extract embeddable "events"."""
@@ -155,50 +180,24 @@ async def insert_memories(
return {"user_state": {}}
def route_inbound(
semantic_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
# Lots of quality improvements can be made here, such as:
# - Fetch similar memories and prompt model to combine or extrapolate
# - Adding advanced indexing by the memory schema (like importance, relevance, etc.)
semantic_builder.add_node(extract_semantic_memories)
semantic_builder.add_node(insert_memories)
semantic_builder.add_edge(START, "extract_semantic_memories")
def should_insert(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["fetch_patched_state", "extract_insertion_memories"]:
configurable = utils.ensure_configurable(config["configurable"])
update_mode = configurable["schemas"][state["function_name"]]["update_mode"]
match update_mode:
case "patch":
return "fetch_patched_state"
case "insert":
return "extract_insertion_memories"
case _:
raise ValueError(f"Unknown update mode: {update_mode}")
extraction_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
extraction_builder.add_node(fetch_patched_state)
extraction_builder.add_node(extract_patch_memories)
extraction_builder.add_node(upsert_patched_state)
extraction_builder.add_node(extract_insertion_memories)
extraction_builder.add_node(insert_memories)
extraction_builder.add_conditional_edges(START, route_inbound)
extraction_builder.add_edge("fetch_patched_state", "extract_patch_memories")
def should_commit(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["insert_memories", "upsert_patched_state", "__end__"]:
) -> Literal["insert_memories", "__end__"]:
"""Whether there are things extracted to commit to the DB."""
if not state["responses"]:
return END
configurable = utils.ensure_configurable(config["configurable"])
function = configurable["schemas"][state["function_name"]]
commit_node_name = (
"insert_memories"
if function["update_mode"] == "insert"
else "upsert_patched_state"
)
return commit_node_name
return "insert_memories" if state["responses"] else END
extraction_builder.add_conditional_edges("extract_patch_memories", should_commit)
extraction_builder.add_conditional_edges("extract_insertion_memories", should_commit)
extraction_graph = extraction_builder.compile()
semantic_builder.add_conditional_edges("extract_semantic_memories", should_insert)
semantic_graph = semantic_builder.compile()
# This graph is public facing. It receives conversations and distibutes them to the
@@ -229,7 +228,8 @@ async def schedule(state: schemas.State, config: RunnableConfig) -> dict:
# Create the graph + all nodes
builder = StateGraph(schemas.State, schemas.GraphConfig)
builder.add_node(schedule)
builder.add_node("extract", extraction_graph)
builder.add_node("handle_patch_memory", patch_graph)
builder.add_node("handle_semantic_memory", semantic_graph)
# Add edges
builder.add_edge(START, "schedule")
@@ -241,9 +241,18 @@ def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]:
These will be executed in parallel.
"""
configuration = utils.ensure_configurable(config["configurable"])
return [
Send("extract", {**state, "function_name": k}) for k in configuration["schemas"]
]
sends = []
for k, v in configuration["schemas"].items():
update_mode = v["update_mode"]
match update_mode:
case "patch":
target = "handle_patch_memory"
case "insert":
target = "handle_semantic_memory"
case _:
raise ValueError(f"Unknown update mode: {update_mode}")
sends.append(Send(target, {**state, "function_name": k}))
return sends
builder.add_conditional_edges("schedule", scatter_schemas)
Generated
+23 -23
View File
@@ -507,13 +507,13 @@ files = [
[[package]]
name = "fsspec"
version = "2024.6.0"
version = "2024.6.1"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
{file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"},
{file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"},
{file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"},
{file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"},
]
[package.extras]
@@ -862,13 +862,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
[[package]]
name = "langchain-anthropic"
version = "0.1.16"
version = "0.1.17"
description = "An integration package connecting AnthropicMessages and LangChain"
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langchain_anthropic-0.1.16-py3-none-any.whl", hash = "sha256:0d3f66b7ffb2d4ef739ef87c4b096f0dbbf0ce12f988205d1fcaa1da2c9d09fa"},
{file = "langchain_anthropic-0.1.16.tar.gz", hash = "sha256:28187dfb19389772e0abba98eeb0210ed3d99ab0a73f2e6707aa46c9bbbbe407"},
{file = "langchain_anthropic-0.1.17-py3-none-any.whl", hash = "sha256:e139602594deebd5deb1531434823889390a58c3b99df911f9ced3ec7cc0746e"},
{file = "langchain_anthropic-0.1.17.tar.gz", hash = "sha256:bd8b16d07c6b78228eaf76ca6aae28e29a186e0048aab79effd35ccf819d1851"},
]
[package.dependencies]
@@ -900,36 +900,36 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
[[package]]
name = "langchain-fireworks"
version = "0.1.3"
version = "0.1.4"
description = "An integration package connecting Fireworks and LangChain"
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langchain_fireworks-0.1.3-py3-none-any.whl", hash = "sha256:44e83491a3a9d0a9bd6b761e72b5553d85321cc878e36b68d99d432fba39db33"},
{file = "langchain_fireworks-0.1.3.tar.gz", hash = "sha256:667daea544512a0c2598d8f0f77b4176476bdc51bdf03e91f865922d51d39a85"},
{file = "langchain_fireworks-0.1.4-py3-none-any.whl", hash = "sha256:aab2edddea490a5931c556bfc0f2b8a07a18f57a25bf8c77b961b5e37001e75a"},
{file = "langchain_fireworks-0.1.4.tar.gz", hash = "sha256:550bec787b2660014d421e16ce6c4f51851067db14eb299f5757231517eb9d77"},
]
[package.dependencies]
aiohttp = ">=3.9.1,<4.0.0"
fireworks-ai = ">=0.13.0"
langchain-core = ">=0.1.52,<0.3"
langchain-core = ">=0.2.2,<0.3"
openai = ">=1.10.0,<2.0.0"
requests = ">=2,<3"
[[package]]
name = "langchain-openai"
version = "0.1.10"
version = "0.1.13"
description = "An integration package connecting OpenAI and LangChain"
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langchain_openai-0.1.10-py3-none-any.whl", hash = "sha256:62eb000980eb45e4f16c88acdbaeccf3d59266554b0dd3ce6bebea1bbe8143dd"},
{file = "langchain_openai-0.1.10.tar.gz", hash = "sha256:30f881f8ccaec28c054759837c41fd2a2264fcc5564728ce12e1715891a9ce3c"},
{file = "langchain_openai-0.1.13-py3-none-any.whl", hash = "sha256:4344b6c5c67088a28eed80ba763157fdd1d690cee679966a021b42f305dbf7b5"},
{file = "langchain_openai-0.1.13.tar.gz", hash = "sha256:03318669bcb3238f7d1bb043329f91d150ca09246f1faf569ef299f535405c71"},
]
[package.dependencies]
langchain-core = ">=0.2.2,<0.3"
openai = ">=1.26.0,<2.0.0"
openai = ">=1.32.0,<2.0.0"
tiktoken = ">=0.7,<1"
[[package]]
@@ -964,13 +964,13 @@ langchain-core = ">=0.2.10,<0.3.0"
[[package]]
name = "langgraph"
version = "0.1.1"
version = "0.1.4"
description = "Building stateful, multi-actor applications with LLMs"
optional = false
python-versions = "<4.0,>=3.9.0"
files = [
{file = "langgraph-0.1.1-py3-none-any.whl", hash = "sha256:6d798072625fd23ff155d40ee823b4e5eb00731ad8a64e2551fd6ae0cb53aec4"},
{file = "langgraph-0.1.1.tar.gz", hash = "sha256:cea9831c334f36bae9caa66637422e37f76f069b47233ceb253be6cfb3ecb35e"},
{file = "langgraph-0.1.4-py3-none-any.whl", hash = "sha256:7473fd87fc4e9f796a300be810cf6fbdcf91ca5355080c78fc502153d1b2752f"},
{file = "langgraph-0.1.4.tar.gz", hash = "sha256:cd5c301b70d49fe4de88c4f11c9636070ff943af887615e7662df6678c86eb06"},
]
[package.dependencies]
@@ -978,13 +978,13 @@ langchain-core = ">=0.2,<0.3"
[[package]]
name = "langgraph-cli"
version = "0.1.44"
version = "0.1.47"
description = "CLI for interacting with LangGraph API"
optional = false
python-versions = "<4.0.0,>=3.9.0"
files = [
{file = "langgraph_cli-0.1.44-py3-none-any.whl", hash = "sha256:721968ab9d9d74ba824e93f4bb08de094795888b7ccea3792962f051a9cdb419"},
{file = "langgraph_cli-0.1.44.tar.gz", hash = "sha256:a31f4a71abd4a3c39f886811c69ac433a20bbd7eeca6bd47d8f8ec49afbfcebb"},
{file = "langgraph_cli-0.1.47-py3-none-any.whl", hash = "sha256:3767941dbac4128b1d87fa2dc24afde4015904890ab9682e6b956aae51c0baea"},
{file = "langgraph_cli-0.1.47.tar.gz", hash = "sha256:292eb99334efd53c29f40cef3e0c55d198da6e99ee14d0987210997be6e56d32"},
]
[package.dependencies]
@@ -1229,13 +1229,13 @@ files = [
[[package]]
name = "openai"
version = "1.35.5"
version = "1.35.7"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
{file = "openai-1.35.5-py3-none-any.whl", hash = "sha256:28d92503c6e4b6a32a89277b36693023ef41f60922a4b5c8c621e8c5697ae3a6"},
{file = "openai-1.35.5.tar.gz", hash = "sha256:67ef289ae22d350cbf9381d83ae82c4e3596d71b7ad1cc886143554ee12fe0c9"},
{file = "openai-1.35.7-py3-none-any.whl", hash = "sha256:3d1e0b0aac9b0db69a972d36dc7efa7563f8e8d65550b27a48f2a0c2ec207e80"},
{file = "openai-1.35.7.tar.gz", hash = "sha256:009bfa1504c9c7ef64d87be55936d142325656bbc6d98c68b669d6472e4beb09"},
]
[package.dependencies]
+8 -4
View File
@@ -1,9 +1,10 @@
import json
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith import expect, test
from langsmith import expect, get_current_run_tree, test
from memory_service._constants import PATCH_PATH
from memory_service._schemas import GraphConfig, MemoryConfig
@@ -31,7 +32,7 @@ def core_memory_func() -> MemoryConfig:
}
@test(output_keys="num_mems_expected")
@test(output_keys=["num_mems_expected"])
@pytest.mark.parametrize(
"messages, existing, num_mems_expected",
[
@@ -88,9 +89,12 @@ async def test_patch_memory(
index.upsert.assert_called_once()
# Get named call args
vectors = index.upsert.call_args.kwargs["vectors"]
rt = get_current_run_tree()
rt.outputs = {"upserted": [v["metadata"]["content"] for v in vectors]}
assert len(vectors) == 1
# Check if the memory was added
memories = vectors[0]["metadata"]["content"]["memories"]
mem = vectors[0]["metadata"]["content"]
memories = json.loads(mem)["memories"]
expect(len(memories)).to_equal(num_mems_expected)
@@ -119,7 +123,7 @@ def memorable_event_func() -> MemoryConfig:
}
@test(output_keys="num_events_expected")
@test(output_keys=["num_events_expected"])
@pytest.mark.parametrize(
"messages, num_events_expected",
[