mirror of
https://github.com/langchain-ai/langgraph-memory.git
synced 2026-07-01 09:25:02 -04:00
Merge pull request #1 from langchain-ai/wfh/first_commit
Draft MemGPT Service
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
PINECONE_API_KEY=...
|
||||
PINECONE_INDEX_NAME=...
|
||||
PINECONE_NAMESPACE=...
|
||||
FIREWORKS_API_KEY=...
|
||||
|
||||
# You can add other keys as appropriate, depending on
|
||||
# the services you are using.
|
||||
@@ -0,0 +1,20 @@
|
||||
.PHONY: tests lint format evals
|
||||
|
||||
|
||||
evals:
|
||||
LANGCHAIN_TEST_CACHE=tests/evals/cassettes poetry run python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
|
||||
|
||||
lint:
|
||||
poetry run ruff check .
|
||||
poetry run mypy .
|
||||
|
||||
format:
|
||||
ruff check --select I --fix
|
||||
poetry run ruff format .
|
||||
poetry run ruff check . --fix
|
||||
|
||||
build:
|
||||
poetry build
|
||||
|
||||
publish:
|
||||
poetry publish --dry-run
|
||||
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"dependencies": ["."],
|
||||
"graphs": {
|
||||
"agent": "./memory_service/graph.py:memgraph",
|
||||
"chat": "./memory_service/chatbot.py:chat_graph"
|
||||
},
|
||||
"env": ".env"
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Simple example memory extraction service."""
|
||||
|
||||
from memory_service.graph import extraction_graph
|
||||
|
||||
__all__ = ["extraction_graph"]
|
||||
@@ -0,0 +1,5 @@
|
||||
PAYLOAD_KEY = "content"
|
||||
PATH_KEY = "path"
|
||||
PATCH_PATH = "user/{user_id}/patches/{function_name}"
|
||||
INSERT_PATH = "user/{user_id}/inserts/{function_name}/{event_id}"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langgraph.graph import add_messages
|
||||
from typing_extensions import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class FunctionSchema(TypedDict):
|
||||
name: str
|
||||
"""Name of the function."""
|
||||
description: str
|
||||
"""A description of the function."""
|
||||
parameters: dict
|
||||
"""The JSON Schema for the memory."""
|
||||
|
||||
|
||||
class MemoryConfig(TypedDict, total=False):
|
||||
function: FunctionSchema
|
||||
"""The function to use for the memory assistant."""
|
||||
system_prompt: Optional[str]
|
||||
"""The system prompt to use for the memory assistant."""
|
||||
update_mode: Literal["patch", "insert"]
|
||||
"""Whether to continuously patch the memory, or treat each new
|
||||
|
||||
generation as a new memory.
|
||||
|
||||
Patching is useful for maintaining a structured profile or core list
|
||||
of memories. Inserting is useful for maintaining all interactions and
|
||||
not losing any information.
|
||||
|
||||
For patched memories, you can GET the current state at any given time.
|
||||
For inserted memories, you can query the full history of interactions.
|
||||
"""
|
||||
|
||||
|
||||
class GraphConfig(TypedDict, total=False):
|
||||
delay: float
|
||||
"""The delay in seconds to wait before considering a conversation complete.
|
||||
|
||||
Default is 60 seconds.
|
||||
"""
|
||||
model: str
|
||||
"""The model to use for generating memories.
|
||||
|
||||
Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
|
||||
"""
|
||||
schemas: dict[str, MemoryConfig]
|
||||
"""The schemas for the memory assistant."""
|
||||
thread_id: str
|
||||
"""The thread ID of the conversation."""
|
||||
user_id: str
|
||||
"""The ID of the user to remember in the conversation."""
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[list, add_messages]
|
||||
"""The messages in the conversation."""
|
||||
eager: bool
|
||||
|
||||
|
||||
class SingleExtractorState(State):
|
||||
function_name: str
|
||||
responses: list[BaseModel]
|
||||
user_state: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"State",
|
||||
"GraphConfig",
|
||||
"SingleExtractorState",
|
||||
"FunctionSchema",
|
||||
"MemoryConfig",
|
||||
]
|
||||
@@ -0,0 +1,10 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
pinecone_api_key: str = ""
|
||||
pinecone_index_name: str = ""
|
||||
pinecone_namespace: str = "ns1"
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
merge_message_runs,
|
||||
)
|
||||
from pinecone import Pinecone
|
||||
|
||||
from memory_service import _schemas as schemas
|
||||
from memory_service import _settings as settings
|
||||
|
||||
_DEFAULT_DELAY = 60 # seconds
|
||||
|
||||
|
||||
def get_index():
|
||||
pc = Pinecone(api_key=settings.SETTINGS.pinecone_api_key)
|
||||
return pc.Index(settings.SETTINGS.pinecone_index_name)
|
||||
|
||||
|
||||
def ensure_configurable(config: dict) -> schemas.GraphConfig:
|
||||
"""Merge the user-provided config with default values."""
|
||||
return {
|
||||
**config,
|
||||
**schemas.GraphConfig(
|
||||
delay=config.get("delay", _DEFAULT_DELAY),
|
||||
model=config.get("model", "accounts/fireworks/models/firefunction-v2"),
|
||||
schemas=config.get("schemas", {}),
|
||||
thread_id=config["thread_id"],
|
||||
user_id=config["user_id"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def prepare_messages(
|
||||
messages: Sequence[AnyMessage], system_prompt: str
|
||||
) -> list[AnyMessage]:
|
||||
"""Merge message runs and add instructions before and after to stay on task."""
|
||||
sys = SystemMessage(
|
||||
content=system_prompt
|
||||
+ """
|
||||
|
||||
<memory-system>Reflect on following interaction. Use the provided tools to \
|
||||
retain any necessary memories about the user.</memory-system>
|
||||
"""
|
||||
)
|
||||
m = HumanMessage(
|
||||
content="## End of conversation\n\n"
|
||||
"<memory-system>Reflect on the interaction above."
|
||||
" What memories ought to be retained or updated?</memory-system>",
|
||||
)
|
||||
return merge_message_runs([sys] + list(messages) + [m])
|
||||
|
||||
|
||||
__all__ = ["ensure_configurable", "prepare_messages"]
|
||||
@@ -0,0 +1,205 @@
|
||||
"""Example chatbot that incorporates user memories."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
import langsmith
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_nomic.embeddings import NomicEmbeddings
|
||||
from langgraph.graph import START, StateGraph, add_messages
|
||||
from langgraph_sdk import get_client
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from memory_service import _constants as constants
|
||||
from memory_service import _settings as settings
|
||||
from memory_service import _utils as utils
|
||||
|
||||
|
||||
class ChatState(TypedDict):
|
||||
"""The state of the chatbot."""
|
||||
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
user_memories: List[dict]
|
||||
|
||||
|
||||
class ChatConfigurable(TypedDict):
|
||||
"""The configurable fields for the chatbot."""
|
||||
|
||||
user_id: str
|
||||
thread_id: str
|
||||
memory_service_url: str = ""
|
||||
|
||||
|
||||
def _ensure_configurable(config: RunnableConfig) -> ChatConfigurable:
|
||||
"""Ensure the configuration is valid."""
|
||||
return ChatConfigurable(
|
||||
user_id=config["configurable"]["user_id"],
|
||||
thread_id=config["configurable"]["thread_id"],
|
||||
memory_service_url=config["configurable"].get(
|
||||
"memory_service_url", os.environ.get("MEMORY_SERVICE_URL", "")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful and friendly chatbot.{user_info}\n\nSystem Time: {time}",
|
||||
)
|
||||
]
|
||||
).partial(
|
||||
time=lambda: datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def format_query(messages: List[AnyMessage]) -> str:
|
||||
"""Format the query for the user's memories."""
|
||||
# This is quite naive :)
|
||||
return " ".join([str(m.content) for m in messages if m.type == "human"][-5:])
|
||||
|
||||
|
||||
async def query_memories(state: ChatState, config: RunnableConfig) -> ChatState:
|
||||
"""Query the user's memories."""
|
||||
configurable: ChatConfigurable = config["configurable"]
|
||||
user_id = configurable["user_id"]
|
||||
index = utils.get_index()
|
||||
embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
|
||||
|
||||
query = format_query(state["messages"])
|
||||
vec = await embeddings.embed_query(query)
|
||||
# You can also filter by memory type, etc. here.
|
||||
response = index.query(
|
||||
vector=vec,
|
||||
filter={"user": {"$eq": user_id}},
|
||||
include_metadata=True,
|
||||
top_k=10,
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
)
|
||||
memories = []
|
||||
if matches := response.get("matches"):
|
||||
memories = [m["metadata"]["memory"][constants.PAYLOAD_KEY] for m in matches]
|
||||
return {
|
||||
"user_memories": memories,
|
||||
}
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def format_memories(memories: List[dict]) -> str:
|
||||
"""Format the user's memories."""
|
||||
if not memories:
|
||||
return ""
|
||||
# Note Bene: You can format better than this....
|
||||
memories = "\n".join(str(m) for m in memories)
|
||||
return f"""
|
||||
|
||||
## Memories
|
||||
|
||||
You have noted the following memorable events from previous interactions with the user.
|
||||
<memories>
|
||||
{memories}
|
||||
</memories>
|
||||
"""
|
||||
|
||||
|
||||
async def bot(state: ChatState, config: RunnableConfig) -> ChatState:
|
||||
"""Prompt the bot to resopnd to the user, incorporating memories (if provided)."""
|
||||
model = init_chat_model("claude-3-5-sonnet-20240620")
|
||||
chain = PROMPT | model
|
||||
memories = format_memories(state["user_memories"])
|
||||
m = await chain.ainvoke(
|
||||
{
|
||||
"messages": state["messages"],
|
||||
"user_info": memories,
|
||||
},
|
||||
config,
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": [m],
|
||||
}
|
||||
|
||||
|
||||
async def post_messages(state: ChatState, config: RunnableConfig) -> ChatState:
|
||||
"""Query the user's memories."""
|
||||
configurable = _ensure_configurable(config)
|
||||
langgraph_client = get_client(url=configurable["memory_service_url"])
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
try:
|
||||
thread = await langgraph_client.threads.create(thread_id=thread_id)
|
||||
except Exception:
|
||||
thread = await langgraph_client.threads.get(thread_id=thread_id)
|
||||
|
||||
await langgraph_client.runs.create(
|
||||
thread["thread_id"],
|
||||
input={
|
||||
"messages": state["messages"], # the service dedupes messages
|
||||
},
|
||||
config={
|
||||
"system_prompt": "You are a helpful and friendly chatbot.",
|
||||
"configurable": {
|
||||
"user_id": thread["user_id"],
|
||||
"thread_id": thread["thread_id"],
|
||||
"schemas": {
|
||||
"system_prompt": "Extract any memorable events from the user's"
|
||||
" messages that you would like to remember.",
|
||||
"MemorableEvent": {
|
||||
"function": {
|
||||
"name": "memorable_event",
|
||||
"description": "",
|
||||
"parameters": {
|
||||
"name": "memorable_event",
|
||||
"description": "Any event, observation, insight, or "
|
||||
"other detail that you may want to recall in "
|
||||
"later interactions with the user.",
|
||||
"parameters": {
|
||||
"description": "Any event, observation, insight, or"
|
||||
" other detail that you may want to recall in"
|
||||
" later interactions with the user.",
|
||||
"properties": {
|
||||
"description": {
|
||||
"title": "Description",
|
||||
"type": "string",
|
||||
},
|
||||
"participants": {
|
||||
"description": "Names of participants in"
|
||||
" the event and their relationship to the "
|
||||
"user.",
|
||||
"items": {"type": "string"},
|
||||
"title": "Participants",
|
||||
"type": "array",
|
||||
},
|
||||
},
|
||||
"required": ["description", "participants"],
|
||||
"title": "memorable_event",
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
multitask_strategy="interrupt",
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": state["messages"],
|
||||
"user_memories": [],
|
||||
}
|
||||
|
||||
|
||||
builder = StateGraph(ChatState, ChatConfigurable)
|
||||
builder.add_node(query_memories)
|
||||
builder.add_node(bot)
|
||||
builder.add_node(post_messages)
|
||||
builder.add_edge(START, "query_memories")
|
||||
builder.add_edge("query_memories", "bot")
|
||||
builder.add_edge("bot", "post_messages")
|
||||
|
||||
chat_graph = builder.compile()
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Graphs that extract memories on a schedule."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_nomic.embeddings import NomicEmbeddings
|
||||
from langgraph.constants import Send
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from trustcall import create_extractor
|
||||
from typing_extensions import Literal
|
||||
|
||||
from memory_service import _constants as constants
|
||||
from memory_service import _schemas as schemas
|
||||
from memory_service import _settings as settings
|
||||
from memory_service import _utils as utils
|
||||
|
||||
logger = logging.getLogger("memory")
|
||||
|
||||
|
||||
async def fetch_patched_state(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Fetch the user's state from the database.
|
||||
|
||||
This is a placeholder function. You should replace this with a function
|
||||
that fetches the user's state from the database.
|
||||
"""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
path = constants.PATCH_PATH.format(
|
||||
user_id=configurable["user_id"], function_name=state["function_name"]
|
||||
)
|
||||
# TODO: does pinecone have an async api in their SDK...?
|
||||
response = utils.get_index().fetch(
|
||||
ids=[path], namespace=settings.SETTINGS.pinecone_namespace
|
||||
)
|
||||
if vectors := response.get("vectors"):
|
||||
document = vectors[path]
|
||||
payload = document["metadata"][constants.PAYLOAD_KEY]
|
||||
return {"user_state": payload}
|
||||
return {"user_state": None}
|
||||
|
||||
|
||||
async def extract_patch_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Extract the user's state from the conversation."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
schemas = configurable["schemas"]
|
||||
memory_config = schemas[state["function_name"]]
|
||||
llm = init_chat_model(model=configurable["model"])
|
||||
messages = utils.prepare_messages(
|
||||
state["messages"], memory_config.get("system_prompt") or ""
|
||||
)
|
||||
extractor = create_extractor(
|
||||
llm,
|
||||
tools=[memory_config["function"]],
|
||||
tool_choice=memory_config["function"]["name"],
|
||||
)
|
||||
existing = state["user_state"]
|
||||
result = await extractor.ainvoke(
|
||||
{
|
||||
"messages": messages,
|
||||
"existing": {memory_config["function"]["name"]: existing},
|
||||
}
|
||||
)
|
||||
return {"responses": result["responses"]}
|
||||
|
||||
|
||||
async def upsert_patched_state(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Upsert the user's state to the database."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
path = constants.PATCH_PATH.format(
|
||||
user_id=configurable["user_id"], function_name=state["function_name"]
|
||||
)
|
||||
serialized = state["responses"][0].model_dump_json()
|
||||
embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
|
||||
vector = await embeddings.aembed_query(serialized)
|
||||
utils.get_index().upsert(
|
||||
vectors=[
|
||||
{
|
||||
"id": path,
|
||||
"values": vector,
|
||||
"metadata": {
|
||||
constants.PAYLOAD_KEY: json.loads(serialized),
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: datetime.now(tz=timezone.utc),
|
||||
},
|
||||
}
|
||||
],
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
)
|
||||
return {"user_state": {}}
|
||||
|
||||
|
||||
async def extract_insertion_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Extract embeddable "events"."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
llm = init_chat_model(model=configurable["model"])
|
||||
memory_config = configurable["schemas"][state["function_name"]]
|
||||
messages = utils.prepare_messages(
|
||||
state["messages"], memory_config.get("system_prompt") or ""
|
||||
)
|
||||
|
||||
extractor = create_extractor(llm, tools=[memory_config["function"]])
|
||||
# We don't have an "existing" value here since we are continuously inserting
|
||||
# new memories.
|
||||
result = await extractor.ainvoke({"messages": messages})
|
||||
return {"responses": result["responses"]}
|
||||
|
||||
|
||||
async def insert_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Insert the user's state to the database."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
|
||||
serialized = [r.model_dump_json() for r in state["responses"]]
|
||||
# You could alternatively do multi-vector lookup based on the schema.
|
||||
vectors = await embeddings.aembed_documents(serialized)
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
paths = [
|
||||
constants.INSERT_PATH.format(
|
||||
user_id=configurable["user_id"],
|
||||
function_name=state["function_name"],
|
||||
event_id=str(uuid.uuid4()),
|
||||
)
|
||||
for _ in range(len(vectors))
|
||||
]
|
||||
documents = [
|
||||
{
|
||||
"id": path,
|
||||
"values": vector,
|
||||
"metadata": {
|
||||
constants.PAYLOAD_KEY: json.loads(serialized),
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: current_time,
|
||||
},
|
||||
}
|
||||
for path, vector, serialized in zip(paths, vectors, serialized)
|
||||
]
|
||||
utils.get_index().upsert(
|
||||
vectors=documents,
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
)
|
||||
return {"user_state": {}}
|
||||
|
||||
|
||||
def route_inbound(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> Literal["fetch_patched_state", "extract_insertion_memories"]:
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
update_mode = configurable["schemas"][state["function_name"]]["update_mode"]
|
||||
match update_mode:
|
||||
case "patch":
|
||||
return "fetch_patched_state"
|
||||
case "insert":
|
||||
return "extract_insertion_memories"
|
||||
case _:
|
||||
raise ValueError(f"Unknown update mode: {update_mode}")
|
||||
|
||||
|
||||
extraction_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
|
||||
extraction_builder.add_node(fetch_patched_state)
|
||||
extraction_builder.add_node(extract_patch_memories)
|
||||
extraction_builder.add_node(upsert_patched_state)
|
||||
extraction_builder.add_node(extract_insertion_memories)
|
||||
extraction_builder.add_node(insert_memories)
|
||||
|
||||
extraction_builder.add_conditional_edges(START, route_inbound)
|
||||
extraction_builder.add_edge("fetch_patched_state", "extract_patch_memories")
|
||||
|
||||
|
||||
def should_commit(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> Literal["insert_memories", "upsert_patched_state", "__end__"]:
|
||||
"""Whether there are things extracted to commit to the DB."""
|
||||
if not state["responses"]:
|
||||
return END
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
function = configurable["schemas"][state["function_name"]]
|
||||
commit_node_name = (
|
||||
"insert_memories"
|
||||
if function["update_mode"] == "insert"
|
||||
else "upsert_patched_state"
|
||||
)
|
||||
return commit_node_name
|
||||
|
||||
|
||||
extraction_builder.add_conditional_edges("extract_patch_memories", should_commit)
|
||||
extraction_builder.add_conditional_edges("extract_insertion_memories", should_commit)
|
||||
extraction_graph = extraction_builder.compile()
|
||||
|
||||
|
||||
# This graph is public facing. It receives conversations and distibutes them to the
|
||||
# memory types as needed.
|
||||
|
||||
|
||||
async def schedule(state: schemas.State, config: RunnableConfig) -> dict:
|
||||
"""Delay the start of processing to simulate run scheduling.
|
||||
|
||||
We only really need to process a conversation after it is completed.
|
||||
In general, we don't know when a conversation is completed, so we will
|
||||
delay the processing of the conversation for a set amount of time.
|
||||
|
||||
This is configurable at the assistant and run level, and to bypass this,
|
||||
you can set `eager` to True in the run inputs.
|
||||
|
||||
If a new message comes in before the delay is up, the run can be cancelled
|
||||
and a new one scheduled.
|
||||
"""
|
||||
if state.get("eager", False):
|
||||
return {"messages": state["messages"]}
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
if configurable["delay"]:
|
||||
await asyncio.sleep(configurable["delay"])
|
||||
return {"messages": []}
|
||||
|
||||
|
||||
# Create the graph + all nodes
|
||||
builder = StateGraph(schemas.State, schemas.GraphConfig)
|
||||
builder.add_node(schedule)
|
||||
builder.add_node("extract", extraction_graph)
|
||||
|
||||
# Add edges
|
||||
builder.add_edge(START, "schedule")
|
||||
|
||||
|
||||
def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]:
|
||||
"""Route the schemas for the memory assistant.
|
||||
|
||||
These will be executed in parallel.
|
||||
"""
|
||||
configuration = utils.ensure_configurable(config["configurable"])
|
||||
return [
|
||||
Send("extract", {**state, "function_name": k}) for k in configuration["schemas"]
|
||||
]
|
||||
|
||||
|
||||
builder.add_conditional_edges("schedule", scatter_schemas)
|
||||
|
||||
memgraph = builder.compile()
|
||||
|
||||
|
||||
__all__ = ["memgraph"]
|
||||
Generated
+2617
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,56 @@
|
||||
[tool.poetry]
|
||||
name = "memory-service"
|
||||
version = "0.0.1"
|
||||
description = "A simple memory service (for agents) on LangGraph cloud."
|
||||
authors = ["William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9.0,<3.13"
|
||||
langgraph = "^0.1.0"
|
||||
langchain-fireworks = "^0.1.3"
|
||||
# Feel free to swap out for postgres or your favorite database.
|
||||
langchain-pinecone = "^0.1.1"
|
||||
jsonpatch = "^1.33"
|
||||
dydantic = "^0.0.6"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
trustcall = "^0.0.4"
|
||||
langchain = "^0.2.6"
|
||||
langchain-openai = "^0.1.10"
|
||||
langchain-anthropic = "^0.1.15"
|
||||
langchain-nomic = "^0.1.2"
|
||||
pydantic-settings = "^2.3.4"
|
||||
langgraph-sdk = "^0.1.23"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.4.10"
|
||||
mypy = "^1.10.0"
|
||||
pytest = "^8.2.2"
|
||||
langgraph-cli = "^0.1.43"
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"D", # pydocstyle
|
||||
"D401", # First line should be in imperative mood
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
docstring-code-line-length = 80
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["D", "E501"]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_fake_env_vars():
|
||||
os.environ["PINECONE_API_KEY"] = "fake_key"
|
||||
os.environ["PINECONE_INDEX_NAME"] = "fake_index"
|
||||
yield
|
||||
@@ -0,0 +1,176 @@
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langsmith import expect, test
|
||||
|
||||
from memory_service._constants import PATCH_PATH
|
||||
from memory_service._schemas import GraphConfig, MemoryConfig
|
||||
from memory_service.graph import memgraph
|
||||
|
||||
|
||||
# To test the patch-based memory
|
||||
class CoreMemories(BaseModel):
|
||||
"""Core memories about the user."""
|
||||
|
||||
memories: List[str]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def core_memory_func() -> MemoryConfig:
|
||||
return {
|
||||
"function": {
|
||||
"name": "core_memories",
|
||||
"description": "A list of core memories about the user.",
|
||||
"parameters": CoreMemories.schema(),
|
||||
},
|
||||
"system_prompt": "You may add or remove memories that are core to the"
|
||||
" user's identity or that will help you better interact with the user.",
|
||||
"update_mode": "patch",
|
||||
}
|
||||
|
||||
|
||||
@test(output_keys="num_mems_expected")
|
||||
@pytest.mark.parametrize(
|
||||
"messages, existing, num_mems_expected",
|
||||
[
|
||||
([], {}, 0),
|
||||
([("user", "When I was young, I had a dog named spot")], {}, 1),
|
||||
(
|
||||
[("user", "When I was young, I had a dog named spot.")],
|
||||
{"memories": ["I am afraid of spiders."]},
|
||||
2,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_patch_memory(
|
||||
core_memory_func: MemoryConfig,
|
||||
messages: List[str],
|
||||
num_mems_expected: int,
|
||||
existing: dict,
|
||||
):
|
||||
# patch memory_service.graph.index with a mock
|
||||
user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
|
||||
thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
|
||||
function_name = "CoreMemories"
|
||||
with patch("memory_service._utils.get_index") as get_index:
|
||||
index = MagicMock()
|
||||
get_index.return_value = index
|
||||
# No existing memories
|
||||
if existing:
|
||||
path = PATCH_PATH.format(
|
||||
user_id=user_id,
|
||||
function_name=function_name,
|
||||
)
|
||||
index.fetch.return_value = {
|
||||
"vectors": {path: {"metadata": {"content": existing}}}
|
||||
}
|
||||
else:
|
||||
index.fetch.return_value = {}
|
||||
|
||||
# When the memories are patched
|
||||
await memgraph.ainvoke(
|
||||
{
|
||||
"messages": messages,
|
||||
},
|
||||
{
|
||||
"configurable": GraphConfig(
|
||||
delay=0.1,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
schemas={function_name: core_memory_func},
|
||||
),
|
||||
},
|
||||
)
|
||||
if num_mems_expected:
|
||||
# Check if index.upsert was called
|
||||
index.upsert.assert_called_once()
|
||||
# Get named call args
|
||||
vectors = index.upsert.call_args.kwargs["vectors"]
|
||||
assert len(vectors) == 1
|
||||
# Check if the memory was added
|
||||
memories = vectors[0]["metadata"]["content"]["memories"]
|
||||
expect(len(memories)).to_equal(num_mems_expected)
|
||||
|
||||
|
||||
# To test the insertion memory
|
||||
class MemorableEvent(BaseModel):
|
||||
"""A memorable event."""
|
||||
|
||||
description: str
|
||||
participants: List[str] = Field(
|
||||
description="Names of participants in the event and their relationship to the user."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memorable_event_func() -> MemoryConfig:
|
||||
return {
|
||||
"function": {
|
||||
"name": "memorable_event",
|
||||
"description": "Any event, observation, insight, or other detail that you may"
|
||||
" want to recall in later interactions with the user.",
|
||||
"parameters": MemorableEvent.schema(),
|
||||
},
|
||||
"system_prompt": "Extract all events that are memorable and relevant to the user."
|
||||
" using parallel tool calling. If nothing of interest occured in the diologue, simply reply 'None'.",
|
||||
"update_mode": "insert",
|
||||
}
|
||||
|
||||
|
||||
@test(output_keys="num_events_expected")
|
||||
@pytest.mark.parametrize(
|
||||
"messages, num_events_expected",
|
||||
[
|
||||
([], 0),
|
||||
(
|
||||
[
|
||||
("user", "I went to the beach with my friends."),
|
||||
("assistant", "That sounds like a fun day."),
|
||||
],
|
||||
1,
|
||||
),
|
||||
(
|
||||
[
|
||||
("user", "I went to the beach with my friends."),
|
||||
("assistant", "That sounds like a fun day."),
|
||||
("user", "I also went to the park with my family - I like the park."),
|
||||
],
|
||||
2,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_insert_memory(
|
||||
memorable_event_func: MemoryConfig,
|
||||
messages: List[str],
|
||||
num_events_expected: int,
|
||||
):
|
||||
# patch memory_service.graph.index with a mock
|
||||
user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
|
||||
thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
|
||||
function_name = "MemorableEvent"
|
||||
with patch("memory_service._utils.get_index") as get_index:
|
||||
index = MagicMock()
|
||||
get_index.return_value = index
|
||||
index.fetch.return_value = {}
|
||||
# When the events are inserted
|
||||
await memgraph.ainvoke(
|
||||
{
|
||||
"messages": messages,
|
||||
},
|
||||
{
|
||||
"configurable": GraphConfig(
|
||||
delay=0.1,
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
schemas={function_name: memorable_event_func},
|
||||
),
|
||||
},
|
||||
)
|
||||
if num_events_expected:
|
||||
# Check if index.upsert was called
|
||||
index.upsert.assert_called_once()
|
||||
# Get named call args
|
||||
vectors = index.upsert.call_args.kwargs["vectors"]
|
||||
assert len(vectors) == num_events_expected
|
||||
Reference in New Issue
Block a user