Update to use langmem

This commit is contained in:
William Fu-Hinthorn
2025-05-19 15:24:41 -07:00
parent 62714b625b
commit 5af31c3efc
10 changed files with 151 additions and 300 deletions
+31 -31
View File
@@ -29,7 +29,7 @@ Setup instruction auto-generated by `langgraph template lock`. DO NOT EDIT MANUA
The default values for `model` are shown below:
```yaml
model: anthropic/claude-3-5-sonnet-20240620
model: anthropic:claude-3-5-sonnet-20240620
```
Follow the instructions below to get set up, or pick one of the additional options.
@@ -399,7 +399,7 @@ Since you've made a newly named memory schema, the memory service will save it w
You can modify schemas with an insertion update_mode in the same way as schemas with a patch update_mode. Define the structure, name it descriptively, set "update_mode" to "insert", and include a concise description. Parameters should have appropriate data types and descriptions. Consider adding constraints for data quality.
3. Select a different model: We default to anthropic/claude-3-5-sonnet-20240620. You can select a compatible chat model using provider/model-name via configuration. Example: openai/gpt-4.
3. Select a different model: We default to anthropic:claude-3-5-sonnet-20240620. You can select a compatible chat model using provider:model-name via configuration. Example: openai:gpt-4.
4. Customize the prompts: We provide default prompts in the graph definition. You can easily update these via configuration.
We'd also encourage you to extend this template by adding additional memory types! "Patch" and "insert" are incredibly powerful already, but you could also extend the logic to add more reflection over related memories to build stronger associations between the saved content. Make the code your own!
@@ -417,119 +417,119 @@ Configuration auto-generated by `langgraph template lock`. DO NOT EDIT MANUALLY.
"properties": {
"model": {
"type": "string",
"default": "anthropic/claude-3-5-sonnet-20240620",
"default": "anthropic:claude-3-5-sonnet-20240620",
"description": "The name of the language model to use for the agent. Should be in the form: provider/model-name.",
"environment": [
{
"value": "anthropic/claude-1.2",
"value": "anthropic:claude-1.2",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-2.0",
"value": "anthropic:claude-2.0",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-2.1",
"value": "anthropic:claude-2.1",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-3-5-sonnet-20240620",
"value": "anthropic:claude-3-5-sonnet-20240620",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-3-haiku-20240307",
"value": "anthropic:claude-3-haiku-20240307",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-3-opus-20240229",
"value": "anthropic:claude-3-opus-20240229",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-3-sonnet-20240229",
"value": "anthropic:claude-3-sonnet-20240229",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "anthropic/claude-instant-1.2",
"value": "anthropic:claude-instant-1.2",
"variables": "ANTHROPIC_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo",
"value": "openai:gpt-3.5-turbo",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-0125",
"value": "openai:gpt-3.5-turbo-0125",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-0301",
"value": "openai:gpt-3.5-turbo-0301",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-0613",
"value": "openai:gpt-3.5-turbo-0613",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-1106",
"value": "openai:gpt-3.5-turbo-1106",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-16k",
"value": "openai:gpt-3.5-turbo-16k",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-3.5-turbo-16k-0613",
"value": "openai:gpt-3.5-turbo-16k-0613",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4",
"value": "openai:gpt-4",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-0125-preview",
"value": "openai:gpt-4-0125-preview",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-0314",
"value": "openai:gpt-4-0314",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-0613",
"value": "openai:gpt-4-0613",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-1106-preview",
"value": "openai:gpt-4-1106-preview",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-32k",
"value": "openai:gpt-4-32k",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-32k-0314",
"value": "openai:gpt-4-32k-0314",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-32k-0613",
"value": "openai:gpt-4-32k-0613",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-turbo",
"value": "openai:gpt-4-turbo",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-turbo-preview",
"value": "openai:gpt-4-turbo-preview",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4-vision-preview",
"value": "openai:gpt-4-vision-preview",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4o",
"value": "openai:gpt-4o",
"variables": "OPENAI_API_KEY"
},
{
"value": "openai/gpt-4o-mini",
"value": "openai:gpt-4o-mini",
"variables": "OPENAI_API_KEY"
}
]
+8 -12
View File
@@ -9,21 +9,17 @@ readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.11"
dependencies = [
"langgraph>=0.2.53,<0.3.0",
"langgraph>=0.3",
"langgraph-checkpoint>=2.0.8",
# Optional (for selecting different models)
"langchain-openai>=0.2.1",
"langchain-anthropic>=0.2.1",
"langchain-openai>=0.3",
"langchain-anthropic>=0.3",
"langchain>=0.3.8",
"python-dotenv>=1.0.1",
"langgraph-sdk>=0.1.40",
"trustcall>=0.0.21",
"langgraph-api",
"langmem>=0.0.25",
]
[project.optional-dependencies]
dev = ["mypy>=1.11.1", "ruff>=0.6.1", "pytest-asyncio"]
[build-system]
requires = ["setuptools>=73.0.0", "wheel"]
build-backend = "setuptools.build_meta"
@@ -62,11 +58,11 @@ convention = "google"
[tool.mypy]
ignore_errors = true
[tool.uv.sources]
langgraph-api = { path = "../../../langgraph-api/api" }
[dependency-groups]
dev = [
"langgraph-cli[inmem]>=0.1.80",
"langgraph-cli[inmem]>=0.2.10",
"mypy>=1.15.0",
"pytest-asyncio>=0.26.0",
"ruff>=0.11.2",
]
+14 -10
View File
@@ -4,7 +4,7 @@ import os
from dataclasses import dataclass, fields
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_config
from chatbot.prompts import SYSTEM_PROMPT
@@ -17,23 +17,27 @@ class ChatConfigurable:
mem_assistant_id: str = (
"memory_graph" # update to the UUID if you configure a custom assistant
)
model: str = "anthropic/claude-3-5-sonnet-20240620"
delay_seconds: int = 10 # For debouncing memory creation
model: str = "anthropic:claude-3-5-sonnet-20240620"
delay_seconds: int = 3 # For debouncing memory creation
system_prompt: str = SYSTEM_PROMPT
memory_types: Optional[list[dict]] = None
"""The memory_types for the memory assistant."""
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
) -> "ChatConfigurable":
"""Load configuration."""
configurable = (
config["configurable"] if config and "configurable" in config else {}
)
def from_context(cls) -> "ChatConfigurable":
"""Create a ChatConfigurable instance from the current run's config, using an empty configurable mapping when no run context is active."""
try:
config = get_config()
configurable = (
config["configurable"] if config and "configurable" in config else {}
)
except RuntimeError:
configurable = {}
values: dict[str, Any] = {
f.name: os.environ.get(f.name.upper(), configurable.get(f.name))
for f in fields(cls)
if f.init
}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in values.items() if v})
+15 -14
View File
@@ -1,17 +1,18 @@
"""Example chatbot that incorporates user memories."""
import datetime
from dataclasses import dataclass
from datetime import datetime, timezone
from langchain.chat_models import init_chat_model
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_store
from langgraph.graph import StateGraph
from langgraph.graph.message import Messages, add_messages
from langgraph.store.base import BaseStore
from langgraph_sdk import get_client
from typing_extensions import Annotated
from chatbot.configuration import ChatConfigurable
from chatbot.utils import format_memories, init_model
from chatbot.utils import format_memories
@dataclass
@@ -21,24 +22,26 @@ class ChatState:
messages: Annotated[list[Messages], add_messages]
async def bot(
state: ChatState, config: RunnableConfig, store: BaseStore
) -> dict[str, list[Messages]]:
llm = init_chat_model()
async def bot(state: ChatState) -> dict[str, list[Messages]]:
"""Prompt the bot to respond to the user, incorporating memories (if provided)."""
configurable = ChatConfigurable.from_runnable_config(config)
configurable = ChatConfigurable.from_context()
namespace = (configurable.user_id,)
store = get_store()
# This lists ALL user memories in the provided namespace (up to the `limit`)
# you can also filter by content.
query = "\n".join(str(message.content) for message in state.messages)
items = await store.asearch(namespace, query=query, limit=10)
model = init_model(configurable.model)
prompt = configurable.system_prompt.format(
user_info=format_memories(items),
time=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
time=datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
)
m = await model.ainvoke(
m = await llm.ainvoke(
[{"role": "system", "content": prompt}, *state.messages],
config={"configurable": {"model": configurable.model}},
)
return {"messages": [m]}
@@ -46,7 +49,7 @@ async def bot(
async def schedule_memories(state: ChatState, config: RunnableConfig) -> None:
"""Schedule a delayed run of the memory assistant to process this conversation."""
configurable = ChatConfigurable.from_runnable_config(config)
configurable = ChatConfigurable.from_context()
memory_client = get_client()
await memory_client.runs.create(
# We enqueue the memory formation process on the same thread.
@@ -66,9 +69,7 @@ async def schedule_memories(state: ChatState, config: RunnableConfig) -> None:
after_seconds=configurable.delay_seconds,
# Specify the graph and/or graph configuration to handle the memory processing
assistant_id=configurable.mem_assistant_id,
# the memory service is running in the same deployment & thread, meaning
# it shares state with this chat bot. No content needs to be sent
input={"messages": []},
input={"messages": state.messages},
config={
"configurable": {
# Ensure the memory service knows where to save the extracted memories
-12
View File
@@ -2,8 +2,6 @@
from typing import Optional
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langgraph.store.base import Item
@@ -24,13 +22,3 @@ You have noted the following memorable events from previous interactions with th
{formatted_memories}
</memories>
"""
def init_model(fully_specified_name: str) -> BaseChatModel:
    """Build the chat model identified by ``fully_specified_name``.

    Args:
        fully_specified_name: Either ``"provider/model"`` or a bare model
            name. Only the first ``/`` separates the provider from the model.

    Returns:
        The result of ``init_chat_model`` for the parsed model and provider.
    """
    provider, separator, model = fully_specified_name.partition("/")
    if not separator:
        # No provider prefix: the whole string is the model name.
        provider, model = None, fully_specified_name
    return init_chat_model(model, model_provider=provider)
+15 -9
View File
@@ -2,9 +2,9 @@
import os
from dataclasses import dataclass, field, fields
from typing import Any, Literal, Optional
from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_config
from typing_extensions import Annotated
@@ -44,7 +44,7 @@ class Configuration:
user_id: str = "default"
"""The ID of the user to remember in the conversation."""
model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
default="anthropic/claude-3-5-sonnet-20240620",
default="anthropic:claude-3-5-sonnet-latest",
metadata={
"description": "The name of the language model to use for the agent. "
"Should be in the form: provider/model-name."
@@ -55,14 +55,20 @@ class Configuration:
memory_types: list[MemoryConfig] = field(default_factory=list)
"""The memory_types for the memory assistant."""
max_extraction_steps: int = 1
"""The maximum number of steps to take when extracting memories."""
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
) -> "Configuration":
def from_context(cls) -> "Configuration":
"""Create a Configuration instance from the current run's config, using an empty configurable mapping when no run context is active."""
configurable = (
config["configurable"] if config and "configurable" in config else {}
)
try:
config = get_config()
configurable = (
config["configurable"] if config and "configurable" in config else {}
)
except RuntimeError:
configurable = {}
values: dict[str, Any] = {
f.name: os.environ.get(f.name.upper(), configurable.get(f.name))
for f in fields(cls)
+64 -166
View File
@@ -3,192 +3,90 @@
from __future__ import annotations
import asyncio
import functools
import logging
import uuid
from dataclasses import asdict
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.store.base import BaseStore
from langgraph.types import Send
from trustcall import create_extractor
from langchain_core.messages import AnyMessage
from langgraph.func import entrypoint, task
from langgraph.graph import add_messages
from langmem import create_memory_store_manager
from typing_extensions import Annotated, TypedDict
from memory_graph import configuration
class State(TypedDict):
    """Main graph state: the input mapping accepted by the ``graph`` entrypoint."""

    # Annotated with `add_messages`; presumably this makes LangGraph merge
    # message updates instead of overwriting them (confirm against LangGraph docs).
    messages: Annotated[list[AnyMessage], add_messages]
    """The messages in the conversation."""
class ProcessorState(State):
    """Extractor state: the conversation plus the memory schema to process."""

    # Name of the configured memory type; it is matched against `conf.name`
    # when looking up that memory type's configuration.
    function_name: str
from memory_graph import configuration, utils
from memory_graph.state import ProcessorState, State
logger = logging.getLogger("memory")
async def handle_patch_memory(
state: ProcessorState, config: RunnableConfig, *, store: BaseStore
) -> dict:
def get_store_manager(function_name: str):
    """Return the memory store manager for the memory type ``function_name``.

    The configuration is read from the current run context on every call and
    only the manager construction is cached. Every setting that shapes the
    manager is part of the cache key, so a run configured with a different
    model, update mode, or system prompt never receives a manager that was
    built for another configuration.

    Args:
        function_name: Name of a memory type listed in the configuration's
            ``memory_types``.

    Returns:
        The manager produced by ``create_memory_store_manager``.

    Raises:
        ValueError: If no configured memory type is named ``function_name``.
    """
    configurable = configuration.Configuration.from_context()
    memory_config = next(
        (conf for conf in configurable.memory_types if conf.name == function_name),
        None,
    )
    if memory_config is None:
        # A bare next() would raise StopIteration, which surfaces as an
        # opaque RuntimeError when this is called from inside a coroutine.
        raise ValueError(f"No memory type named {function_name!r} is configured.")
    return _build_store_manager(
        configurable.model,
        function_name,
        memory_config.update_mode == "insert",
        memory_config.system_prompt or None,
    )


@functools.lru_cache(maxsize=100)
def _build_store_manager(
    model: str,
    function_name: str,
    enable_inserts: bool,
    instructions: str | None,
):
    """Build (and memoize) a store manager for one exact combination of settings."""
    # lru_cache needs hashable arguments; assumes `instructions` is a str or
    # None -- TODO confirm against the memory-type configuration class.
    kwargs: dict[str, Any] = {"enable_inserts": enable_inserts}
    if instructions:
        kwargs["instructions"] = instructions
    return create_memory_store_manager(
        model,
        # "{user_id}" is passed through verbatim as a template placeholder.
        namespace=("memories", "{user_id}", function_name),
        **kwargs,
    )
@task()
async def process_memory_type(state: ProcessorState) -> None:
"""Extract the user's state from the conversation and update the memory."""
# Get the overall configuration
configurable = configuration.Configuration.from_runnable_config(config)
# Namespace for memory events, where function_name is the name of the memory schema
namespace = (configurable.user_id, "user_states")
# Fetch existing memories from the store for this (patch) memory schema
existing_item = await store.aget(namespace, state.function_name)
existing = {state.function_name: existing_item.value} if existing_item else None
# Get the configuration for this memory schema (identified by function_name)
memory_config = next(
conf for conf in configurable.memory_types if conf.name == state.function_name
)
# This is what we use to generate new memories
extractor = create_extractor(
utils.init_model(configurable.model),
# We pass the specified (patch) memory schema as a tool
tools=[
{
# Tool name
"name": memory_config.name,
# Tool description
"description": memory_config.description,
# Schema for patch memory
"parameters": memory_config.parameters,
configurable = configuration.Configuration.from_context()
store_manager = get_store_manager(state["function_name"])
await store_manager.ainvoke(
{"messages": state["messages"], "max_steps": configurable.max_extraction_steps},
config={
"configurable": {
"model": configurable.model,
"user_id": configurable.user_id,
}
],
tool_choice="any",
)
# Prepare the messages
prepared_messages = utils.prepare_messages(
state.messages, memory_config.system_prompt
)
# Pass messages and existing patch to the extractor
inputs = {"messages": prepared_messages, "existing": existing}
# Update the patch memory
result = await extractor.ainvoke(inputs, config)
extracted = result["responses"][0].model_dump(mode="json")
# Save to storage
await store.aput(namespace, state.function_name, extracted)
async def handle_insertion_memory(
state: ProcessorState, config: RunnableConfig, *, store: BaseStore
) -> dict[str, list]:
"""Handle insertion memory events."""
# Get the overall configuration
configurable = configuration.Configuration.from_runnable_config(config)
# Namespace for memory events, where function_name is the name of the memory schema
namespace = (configurable.user_id, "events", state.function_name)
# Fetch existing memories from the store (5 most recent ones) for the this (insert) memory schema
query = "\n".join(str(message.content) for message in state.messages)[-3000:]
existing_items = await store.asearch(namespace, query=query, limit=5)
# Get the configuration for this memory schema (identified by function_name)
memory_config = next(
conf for conf in configurable.memory_types if conf.name == state.function_name
)
# This is what we use to generate new memories
extractor = create_extractor(
utils.init_model(configurable.model),
# We pass the specified (insert) memory schema as a tool
tools=[
{
# Tool name
"name": memory_config.name,
# Tool description
"description": memory_config.description,
# Schema for insert memory
"parameters": memory_config.parameters,
}
],
tool_choice="any",
# This allows the extractor to insert new memories
enable_inserts=True,
)
# Generate new memories or update existing memories
extracted = await extractor.ainvoke(
{
# Prepare the messages
"messages": utils.prepare_messages(
state.messages, memory_config.system_prompt
),
# Prepare the existing memories
"existing": (
[
(existing_item.key, state.function_name, existing_item.value)
for existing_item in existing_items
]
if existing_items
else None
),
},
config,
)
# Add the memories to storage
await asyncio.gather(
*(
store.aput(
namespace,
rmeta.get("json_doc_id", str(uuid.uuid4())),
r.model_dump(mode="json"),
)
for r, rmeta in zip(extracted["responses"], extracted["response_metadata"])
)
)
# Create the graph and all nodes
builder = StateGraph(State, config_schema=configuration.Configuration)
builder.add_node(handle_patch_memory, input=ProcessorState)
builder.add_node(handle_insertion_memory, input=ProcessorState)
def scatter_schemas(state: State, config: RunnableConfig) -> list[Send]:
@entrypoint(config_schema=configuration.Configuration)
async def graph(state: State) -> None:
"""Iterate over all memory types in the configuration.
It will route each memory type from configuration to the corresponding memory update node.
The memory update nodes will be executed in parallel.
"""
# Get the configuration
configurable = configuration.Configuration.from_runnable_config(config)
sends = []
current_state = asdict(state)
# Loop over all memory types specified in the configuration
for v in configurable.memory_types:
update_mode = v.update_mode
# This specifies the type of memory update to perform from the configuration
match update_mode:
case "patch":
# This is the corresponding node in the graph for the patch-based memory update
target = "handle_patch_memory"
case "insert":
# This is the corresponding node in the graph for the insert-based memory update
target = "handle_insertion_memory"
case _:
raise ValueError(f"Unknown update mode: {update_mode}")
# Use Send API to route to the target node and pass the name of the memory schema as function_name
# Send API allows each memory node to be executed in parallel
sends.append(
Send(
target,
ProcessorState(**{**current_state, "function_name": v.name}),
if not state["messages"]:
raise ValueError("No messages provided")
configurable = configuration.Configuration.from_context()
await asyncio.gather(
*[
process_memory_type(
ProcessorState(messages=state["messages"], function_name=v.name),
)
)
return sends
for v in configurable.memory_types
]
)
# Add conditional edges to the graph
builder.add_conditional_edges(
"__start__", scatter_schemas, ["handle_patch_memory", "handle_insertion_memory"]
)
# Compile the graph
graph = builder.compile()
__all__ = ["graph"]
-30
View File
@@ -1,30 +0,0 @@
"""Define the shared values."""
from __future__ import annotations
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from typing_extensions import Annotated
@dataclass(kw_only=True)
class State:
    """Main graph state (keyword-only dataclass)."""

    # Annotated with `add_messages`; presumably this makes LangGraph merge
    # message updates instead of overwriting them (confirm against LangGraph docs).
    messages: Annotated[list[AnyMessage], add_messages]
    """The messages in the conversation."""
@dataclass(kw_only=True)
class ProcessorState(State):
    """Extractor state: the fields of ``State`` plus the memory schema name."""

    # Name of the memory schema this processor run is responsible for.
    function_name: str
__all__ = [
"State",
"ProcessorState",
]
+3 -15
View File
@@ -2,14 +2,12 @@
from typing import Sequence
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.messages import AnyMessage, merge_message_runs
def prepare_messages(
messages: Sequence[BaseMessage], system_prompt: str
) -> list[BaseMessage]:
messages: Sequence[AnyMessage], system_prompt: str
) -> list[AnyMessage]:
"""Merge message runs and add instructions before and after to stay on task."""
sys = {
"role": "system",
@@ -26,13 +24,3 @@ def prepare_messages(
" What memories ought to be retained or updated?</memory-system>",
}
return list(merge_message_runs(messages=[sys] + list(messages) + [m]))
def init_model(fully_specified_name: str) -> BaseChatModel:
    """Initialize the chat model named by ``fully_specified_name``.

    A name of the form ``"provider/model"`` is split at its first ``/``.
    A name without ``/`` is used as the model name with no explicit provider.
    """
    head, slash, tail = fully_specified_name.partition("/")
    if slash:
        provider, model = head, tail
    else:
        # No separator present, so no provider is passed explicitly.
        provider, model = None, fully_specified_name
    return init_chat_model(model, model_provider=provider)
+1 -1
View File
@@ -2,4 +2,4 @@ from memory_graph.configuration import Configuration
def test_configuration_from_none() -> None:
Configuration.from_runnable_config()
Configuration.from_context()