diff --git a/README.md b/README.md index 838b97f..93cb008 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Setup instruction auto-generated by `langgraph template lock`. DO NOT EDIT MANUA The defaults values for `model` are shown below: ```yaml -model: anthropic/claude-3-5-sonnet-20240620 +model: anthropic:claude-3-5-sonnet-20240620 ``` Follow the instructions below to get set up, or pick one of the additional options. @@ -399,7 +399,7 @@ Since you've made a newly named memory schema, the memory service will save it w You can modify schemas with an insertion update_mode in the same way as schemas with a patch update_mode. Define the structure, name it descriptively, set "update_mode" to "insert", and include a concise description. Parameters should have appropriate data types and descriptions. Consider adding constraints for data quality. -3. Select a different model: We default to anthropic/claude-3-5-sonnet-20240620. You can select a compatible chat model using provider/model-name via configuration. Example: openai/gpt-4. +3. Select a different model: We default to anthropic:claude-3-5-sonnet-20240620. You can select a compatible chat model using provider/model-name via configuration. Example: openai:gpt-4. 4. Customize the prompts: We provide default prompts in the graph definition. You can easily update these via configuration. We'd also encourage you to extend this template by adding additional memory types! "Patch" and "insert" are incredibly powerful already, but you could also extend the logic to add more reflection over related memories to build stronger associations between the saved content. Make the code your own! @@ -417,119 +417,119 @@ Configuration auto-generated by `langgraph template lock`. DO NOT EDIT MANUALLY. "properties": { "model": { "type": "string", - "default": "anthropic/claude-3-5-sonnet-20240620", + "default": "anthropic:claude-3-5-sonnet-20240620", "description": "The name of the language model to use for the agent. Should be in the form: provider/model-name.", "environment": [ { - "value": "anthropic/claude-1.2", + "value": "anthropic:claude-1.2", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-2.0", + "value": "anthropic:claude-2.0", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-2.1", + "value": "anthropic:claude-2.1", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-3-5-sonnet-20240620", + "value": "anthropic:claude-3-5-sonnet-20240620", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-3-haiku-20240307", + "value": "anthropic:claude-3-haiku-20240307", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-3-opus-20240229", + "value": "anthropic:claude-3-opus-20240229", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-3-sonnet-20240229", + "value": "anthropic:claude-3-sonnet-20240229", "variables": "ANTHROPIC_API_KEY" }, { - "value": "anthropic/claude-instant-1.2", + "value": "anthropic:claude-instant-1.2", "variables": "ANTHROPIC_API_KEY" }, { - "value": "openai/gpt-3.5-turbo", + "value": "openai:gpt-3.5-turbo", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-0125", + "value": "openai:gpt-3.5-turbo-0125", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-0301", + "value": "openai:gpt-3.5-turbo-0301", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-0613", + "value": "openai:gpt-3.5-turbo-0613", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-1106", + "value": "openai:gpt-3.5-turbo-1106", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-16k", + "value": "openai:gpt-3.5-turbo-16k", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-3.5-turbo-16k-0613", + "value": "openai:gpt-3.5-turbo-16k-0613", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4", + "value": "openai:gpt-4", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-0125-preview", + "value": "openai:gpt-4-0125-preview", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-0314", + "value": "openai:gpt-4-0314", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-0613", + "value": "openai:gpt-4-0613", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-1106-preview", + "value": "openai:gpt-4-1106-preview", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-32k", + "value": "openai:gpt-4-32k", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-32k-0314", + "value": "openai:gpt-4-32k-0314", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-32k-0613", + "value": "openai:gpt-4-32k-0613", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-turbo", + "value": "openai:gpt-4-turbo", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-turbo-preview", + "value": "openai:gpt-4-turbo-preview", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4-vision-preview", + "value": "openai:gpt-4-vision-preview", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4o", + "value": "openai:gpt-4o", "variables": "OPENAI_API_KEY" }, { - "value": "openai/gpt-4o-mini", + "value": "openai:gpt-4o-mini", "variables": "OPENAI_API_KEY" } ] diff --git a/pyproject.toml b/pyproject.toml index 02c3bab..a1aa641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,21 +9,17 @@ readme = "README.md" license = { text = "MIT" } requires-python = ">=3.11" dependencies = [ - "langgraph>=0.2.53,<0.3.0", + "langgraph>=0.3", "langgraph-checkpoint>=2.0.8", # Optional (for selecting different models) - "langchain-openai>=0.2.1", - "langchain-anthropic>=0.2.1", + "langchain-openai>=0.3", + "langchain-anthropic>=0.3", "langchain>=0.3.8", "python-dotenv>=1.0.1", "langgraph-sdk>=0.1.40", - "trustcall>=0.0.21", - "langgraph-api", + "langmem>=0.0.25", ] -[project.optional-dependencies] -dev = ["mypy>=1.11.1", "ruff>=0.6.1", "pytest-asyncio"] - [build-system] requires = ["setuptools>=73.0.0", "wheel"] build-backend = "setuptools.build_meta" @@ -62,11 +58,11 @@ convention = "google" [tool.mypy] ignore_errors = true -[tool.uv.sources] -langgraph-api = { path = "../../../langgraph-api/api" } - [dependency-groups] dev = [ - "langgraph-cli[inmem]>=0.1.80", + "langgraph-cli[inmem]>=0.2.10", + "mypy>=1.15.0", + "pytest-asyncio>=0.26.0", + "ruff>=0.11.2", ] diff --git a/src/chatbot/configuration.py b/src/chatbot/configuration.py index a18abdf..a8548b4 100644 --- a/src/chatbot/configuration.py +++ b/src/chatbot/configuration.py @@ -4,7 +4,7 @@ import os from dataclasses import dataclass, fields from typing import Any, Optional -from langchain_core.runnables import RunnableConfig +from langgraph.config import get_config from chatbot.prompts import SYSTEM_PROMPT @@ -17,23 +17,27 @@ class ChatConfigurable: mem_assistant_id: str = ( "memory_graph" # update to the UUID if you configure a custom assistant ) - model: str = "anthropic/claude-3-5-sonnet-20240620" - delay_seconds: int = 10 # For debouncing memory creation + model: str = "anthropic:claude-3-5-sonnet-20240620" + delay_seconds: int = 3 # For debouncing memory creation system_prompt: str = SYSTEM_PROMPT memory_types: Optional[list[dict]] = None """The memory_types for the memory assistant.""" @classmethod - def from_runnable_config( - cls, config: Optional[RunnableConfig] = None - ) -> "ChatConfigurable": - """Load configuration.""" - configurable = ( - config["configurable"] if config and "configurable" in config else {} - ) + def from_context(cls) -> "ChatConfigurable": + """Create a ChatConfigurable instance from a RunnableConfig object.""" + try: + config = get_config() + configurable = ( + config["configurable"] if config and "configurable" in config else {} + ) + except RuntimeError: + configurable = {} + values: dict[str, Any] = { f.name: os.environ.get(f.name.upper(), configurable.get(f.name)) for f in fields(cls) if f.init } + _fields = {f.name for f in fields(cls) if f.init} return cls(**{k: v for k, v in values.items() if v}) diff --git a/src/chatbot/graph.py b/src/chatbot/graph.py index 9612faf..b0bc843 100644 --- a/src/chatbot/graph.py +++ b/src/chatbot/graph.py @@ -1,17 +1,18 @@ """Example chatbot that incorporates user memories.""" +import datetime from dataclasses import dataclass -from datetime import datetime, timezone +from langchain.chat_models import init_chat_model from langchain_core.runnables import RunnableConfig +from langgraph.config import get_store from langgraph.graph import StateGraph from langgraph.graph.message import Messages, add_messages -from langgraph.store.base import BaseStore from langgraph_sdk import get_client from typing_extensions import Annotated from chatbot.configuration import ChatConfigurable -from chatbot.utils import format_memories, init_model +from chatbot.utils import format_memories @dataclass @@ -21,24 +22,26 @@ class ChatState: messages: Annotated[list[Messages], add_messages] -async def bot( - state: ChatState, config: RunnableConfig, store: BaseStore -) -> dict[str, list[Messages]]: +llm = init_chat_model() + + +async def bot(state: ChatState) -> dict[str, list[Messages]]: """Prompt the bot to resopnd to the user, incorporating memories (if provided).""" - configurable = ChatConfigurable.from_runnable_config(config) + configurable = ChatConfigurable.from_context() namespace = (configurable.user_id,) + store = get_store() # This lists ALL user memories in the provided namespace (up to the `limit`) # you can also filter by content. query = "\n".join(str(message.content) for message in state.messages) items = await store.asearch(namespace, query=query, limit=10) - model = init_model(configurable.model) prompt = configurable.system_prompt.format( user_info=format_memories(items), - time=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + time=datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), ) - m = await model.ainvoke( + m = await llm.ainvoke( [{"role": "system", "content": prompt}, *state.messages], + config={"configurable": {"model": configurable.model}}, ) return {"messages": [m]} @@ -46,7 +49,7 @@ async def bot( async def schedule_memories(state: ChatState, config: RunnableConfig) -> None: """Prompt the bot to respond to the user, incorporating memories (if provided).""" - configurable = ChatConfigurable.from_runnable_config(config) + configurable = ChatConfigurable.from_context() memory_client = get_client() await memory_client.runs.create( # We enqueue the memory formation process on the same thread. @@ -66,9 +69,7 @@ async def schedule_memories(state: ChatState, config: RunnableConfig) -> None: after_seconds=configurable.delay_seconds, # Specify the graph and/or graph configuration to handle the memory processing assistant_id=configurable.mem_assistant_id, - # the memory service is running in the same deployment & thread, meaning - # it shares state with this chat bot. No content needs to be sent - input={"messages": []}, + input={"messages": state.messages}, config={ "configurable": { # Ensure the memory service knows where to save the extracted memories diff --git a/src/chatbot/utils.py b/src/chatbot/utils.py index c54308b..98bb063 100644 --- a/src/chatbot/utils.py +++ b/src/chatbot/utils.py @@ -2,8 +2,6 @@ from typing import Optional -from langchain.chat_models import init_chat_model -from langchain_core.language_models import BaseChatModel from langgraph.store.base import Item @@ -24,13 +22,3 @@ You have noted the following memorable events from previous interactions with th {formatted_memories} """ - - -def init_model(fully_specified_name: str) -> BaseChatModel: - """Initialize the configured chat model.""" - if "/" in fully_specified_name: - provider, model = fully_specified_name.split("/", maxsplit=1) - else: - provider = None - model = fully_specified_name - return init_chat_model(model, model_provider=provider) diff --git a/src/memory_graph/configuration.py b/src/memory_graph/configuration.py index b21c0c3..e40a7c4 100644 --- a/src/memory_graph/configuration.py +++ b/src/memory_graph/configuration.py @@ -2,9 +2,9 @@ import os from dataclasses import dataclass, field, fields -from typing import Any, Literal, Optional +from typing import Any, Literal -from langchain_core.runnables import RunnableConfig +from langgraph.config import get_config from typing_extensions import Annotated @@ -44,7 +44,7 @@ class Configuration: user_id: str = "default" """The ID of the user to remember in the conversation.""" model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( - default="anthropic/claude-3-5-sonnet-20240620", + default="anthropic:claude-3-5-sonnet-latest", metadata={ "description": "The name of the language model to use for the agent. " "Should be in the form: provider/model-name." @@ -55,14 +55,20 @@ class Configuration: memory_types: list[MemoryConfig] = field(default_factory=list) """The memory_types for the memory assistant.""" + max_extraction_steps: int = 1 + """The maximum number of steps to take when extracting memories.""" + @classmethod - def from_runnable_config( - cls, config: Optional[RunnableConfig] = None - ) -> "Configuration": + def from_context(cls) -> "Configuration": """Create a Configuration instance from a RunnableConfig.""" - configurable = ( - config["configurable"] if config and "configurable" in config else {} - ) + try: + config = get_config() + configurable = ( + config["configurable"] if config and "configurable" in config else {} + ) + except RuntimeError: + configurable = {} + values: dict[str, Any] = { f.name: os.environ.get(f.name.upper(), configurable.get(f.name)) for f in fields(cls) diff --git a/src/memory_graph/graph.py b/src/memory_graph/graph.py index 76b6f43..19dabba 100644 --- a/src/memory_graph/graph.py +++ b/src/memory_graph/graph.py @@ -3,192 +3,90 @@ from __future__ import annotations import asyncio +import functools import logging -import uuid -from dataclasses import asdict +from typing import Any -from langchain_core.runnables import RunnableConfig -from langgraph.graph import StateGraph -from langgraph.store.base import BaseStore -from langgraph.types import Send -from trustcall import create_extractor +from langchain_core.messages import AnyMessage +from langgraph.func import entrypoint, task +from langgraph.graph import add_messages +from langmem import create_memory_store_manager +from typing_extensions import Annotated, TypedDict + +from memory_graph import configuration + + +class State(TypedDict): + """Main graph state.""" + + messages: Annotated[list[AnyMessage], add_messages] + """The messages in the conversation.""" + + +class ProcessorState(State): + """Extractor state.""" + + function_name: str -from memory_graph import configuration, utils -from memory_graph.state import ProcessorState, State logger = logging.getLogger("memory") -async def handle_patch_memory( - state: ProcessorState, config: RunnableConfig, *, store: BaseStore -) -> dict: +@functools.lru_cache(maxsize=100) +def get_store_manager(function_name: str): + configurable = configuration.Configuration.from_context() + memory_config = next( + conf for conf in configurable.memory_types if conf.name == function_name + ) + + kwargs: dict[str, Any] = { + "enable_inserts": memory_config.update_mode == "insert", + } + if memory_config.system_prompt: + kwargs["instructions"] = memory_config.system_prompt + + return create_memory_store_manager( + configurable.model, + namespace=("memories", "{user_id}", function_name), + **kwargs, + ) + + +@task() +async def process_memory_type(state: ProcessorState) -> None: """Extract the user's state from the conversation and update the memory.""" - # Get the overall configuration - configurable = configuration.Configuration.from_runnable_config(config) - - # Namespace for memory events, where function_name is the name of the memory schema - namespace = (configurable.user_id, "user_states") - - # Fetch existing memories from the store for this (patch) memory schema - existing_item = await store.aget(namespace, state.function_name) - existing = {state.function_name: existing_item.value} if existing_item else None - - # Get the configuration for this memory schema (identified by function_name) - memory_config = next( - conf for conf in configurable.memory_types if conf.name == state.function_name - ) - - # This is what we use to generate new memories - extractor = create_extractor( - utils.init_model(configurable.model), - # We pass the specified (patch) memory schema as a tool - tools=[ - { - # Tool name - "name": memory_config.name, - # Tool description - "description": memory_config.description, - # Schema for patch memory - "parameters": memory_config.parameters, + configurable = configuration.Configuration.from_context() + store_manager = get_store_manager(state["function_name"]) + await store_manager.ainvoke( + {"messages": state["messages"], "max_steps": configurable.max_extraction_steps}, + config={ + "configurable": { + "model": configurable.model, + "user_id": configurable.user_id, } - ], - tool_choice="any", - ) - - # Prepare the messages - prepared_messages = utils.prepare_messages( - state.messages, memory_config.system_prompt - ) - - # Pass messages and existing patch to the extractor - inputs = {"messages": prepared_messages, "existing": existing} - # Update the patch memory - result = await extractor.ainvoke(inputs, config) - extracted = result["responses"][0].model_dump(mode="json") - # Save to storage - await store.aput(namespace, state.function_name, extracted) - - -async def handle_insertion_memory( - state: ProcessorState, config: RunnableConfig, *, store: BaseStore -) -> dict[str, list]: - """Handle insertion memory events.""" - # Get the overall configuration - configurable = configuration.Configuration.from_runnable_config(config) - - # Namespace for memory events, where function_name is the name of the memory schema - namespace = (configurable.user_id, "events", state.function_name) - - # Fetch existing memories from the store (5 most recent ones) for the this (insert) memory schema - query = "\n".join(str(message.content) for message in state.messages)[-3000:] - existing_items = await store.asearch(namespace, query=query, limit=5) - - # Get the configuration for this memory schema (identified by function_name) - memory_config = next( - conf for conf in configurable.memory_types if conf.name == state.function_name - ) - - # This is what we use to generate new memories - extractor = create_extractor( - utils.init_model(configurable.model), - # We pass the specified (insert) memory schema as a tool - tools=[ - { - # Tool name - "name": memory_config.name, - # Tool description - "description": memory_config.description, - # Schema for insert memory - "parameters": memory_config.parameters, - } - ], - tool_choice="any", - # This allows the extractor to insert new memories - enable_inserts=True, - ) - - # Generate new memories or update existing memories - extracted = await extractor.ainvoke( - { - # Prepare the messages - "messages": utils.prepare_messages( - state.messages, memory_config.system_prompt - ), - # Prepare the existing memories - "existing": ( - [ - (existing_item.key, state.function_name, existing_item.value) - for existing_item in existing_items - ] - if existing_items - else None - ), }, - config, - ) - - # Add the memories to storage - await asyncio.gather( - *( - store.aput( - namespace, - rmeta.get("json_doc_id", str(uuid.uuid4())), - r.model_dump(mode="json"), - ) - for r, rmeta in zip(extracted["responses"], extracted["response_metadata"]) - ) ) -# Create the graph and all nodes -builder = StateGraph(State, config_schema=configuration.Configuration) -builder.add_node(handle_patch_memory, input=ProcessorState) -builder.add_node(handle_insertion_memory, input=ProcessorState) - - -def scatter_schemas(state: State, config: RunnableConfig) -> list[Send]: +@entrypoint(config_schema=configuration.Configuration) +async def graph(state: State) -> None: """Iterate over all memory types in the configuration. It will route each memory type from configuration to the corresponding memory update node. The memory update nodes will be executed in parallel. """ - # Get the configuration - configurable = configuration.Configuration.from_runnable_config(config) - sends = [] - current_state = asdict(state) - - # Loop over all memory types specified in the configuration - for v in configurable.memory_types: - update_mode = v.update_mode - - # This specifies the type of memory update to perform from the configuration - match update_mode: - case "patch": - # This is the corresponding node in the graph for the patch-based memory update - target = "handle_patch_memory" - case "insert": - # This is the corresponding node in the graph for the insert-based memory update - target = "handle_insertion_memory" - case _: - raise ValueError(f"Unknown update mode: {update_mode}") - - # Use Send API to route to the target node and pass the name of the memory schema as function_name - # Send API allows each memory node to be executed in parallel - sends.append( - Send( - target, - ProcessorState(**{**current_state, "function_name": v.name}), + if not state["messages"]: + raise ValueError("No messages provided") + configurable = configuration.Configuration.from_context() + await asyncio.gather( + *[ + process_memory_type( + ProcessorState(messages=state["messages"], function_name=v.name), ) - ) - return sends + for v in configurable.memory_types + ] + ) -# Add conditional edges to the graph -builder.add_conditional_edges( - "__start__", scatter_schemas, ["handle_patch_memory", "handle_insertion_memory"] -) - -# Compile the graph -graph = builder.compile() __all__ = ["graph"] diff --git a/src/memory_graph/state.py b/src/memory_graph/state.py deleted file mode 100644 index a204c1b..0000000 --- a/src/memory_graph/state.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Define the shared values.""" - -from __future__ import annotations - -from dataclasses import dataclass - -from langchain_core.messages import AnyMessage -from langgraph.graph import add_messages -from typing_extensions import Annotated - - -@dataclass(kw_only=True) -class State: - """Main graph state.""" - - messages: Annotated[list[AnyMessage], add_messages] - """The messages in the conversation.""" - - -@dataclass(kw_only=True) -class ProcessorState(State): - """Extractor state.""" - - function_name: str - - -__all__ = [ - "State", - "ProcessorState", -] diff --git a/src/memory_graph/utils.py b/src/memory_graph/utils.py index 7b547a2..813edf2 100644 --- a/src/memory_graph/utils.py +++ b/src/memory_graph/utils.py @@ -2,14 +2,12 @@ from typing import Sequence -from langchain.chat_models import init_chat_model -from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage, merge_message_runs +from langchain_core.messages import AnyMessage, merge_message_runs def prepare_messages( - messages: Sequence[BaseMessage], system_prompt: str -) -> list[BaseMessage]: + messages: Sequence[AnyMessage], system_prompt: str +) -> list[AnyMessage]: """Merge message runs and add instructions before and after to stay on task.""" sys = { "role": "system", @@ -26,13 +24,3 @@ def prepare_messages( " What memories ought to be retained or updated?", } return list(merge_message_runs(messages=[sys] + list(messages) + [m])) - - -def init_model(fully_specified_name: str) -> BaseChatModel: - """Initialize the configured chat model.""" - if "/" in fully_specified_name: - provider, model = fully_specified_name.split("/", maxsplit=1) - else: - provider = None - model = fully_specified_name - return init_chat_model(model, model_provider=provider) diff --git a/tests/unit_tests/test_configuration.py b/tests/unit_tests/test_configuration.py index 038a4be..9fc859c 100644 --- a/tests/unit_tests/test_configuration.py +++ b/tests/unit_tests/test_configuration.py @@ -2,4 +2,4 @@ from memory_graph.configuration import Configuration def test_configuration_from_none() -> None: - Configuration.from_runnable_config() + Configuration.from_context()