Merge pull request #1 from langchain-ai/wfh/first_commit

Draft MemGPT Service
This commit is contained in:
William FH
2024-06-26 13:57:38 -07:00
committed by William Fu-Hinthorn
14 changed files with 3507 additions and 0 deletions
+7
View File
@@ -0,0 +1,7 @@
PINECONE_API_KEY=...
PINECONE_INDEX_NAME=...
PINECONE_NAMESPACE=...
FIREWORKS_API_KEY=...
# You can add other keys as appropriate, depending on
# the services you are using.
+20
View File
@@ -0,0 +1,20 @@
# Every target here is a command, not a file, so declare them all phony.
.PHONY: tests lint format evals build publish

# Run the LLM-backed evaluation suite against the cached cassettes.
evals:
	LANGCHAIN_TEST_CACHE=tests/evals/cassettes poetry run python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals

# Static checks: lint with ruff, type-check with mypy.
lint:
	poetry run ruff check .
	poetry run mypy .

# Sort imports, then format and auto-fix. All steps run inside the poetry
# environment so the pinned ruff version is used consistently.
format:
	poetry run ruff check --select I --fix
	poetry run ruff format .
	poetry run ruff check . --fix

build:
	poetry build

# Dry run only: nothing is uploaded.
publish:
	poetry publish --dry-run
+8
View File
@@ -0,0 +1,8 @@
{
"dependencies": ["."],
"graphs": {
"agent": "./memory_service/graph.py:memgraph",
"chat": "./memory_service/chatbot.py:chat_graph"
},
"env": ".env"
}
+5
View File
@@ -0,0 +1,5 @@
"""Simple example memory extraction service."""
from memory_service.graph import extraction_graph
# Public API of the package: the compiled extraction graph imported above.
__all__ = ["extraction_graph"]
+5
View File
@@ -0,0 +1,5 @@
# Metadata key under which the extracted memory payload is stored on a vector.
PAYLOAD_KEY = "content"
# Metadata key that repeats the vector's own ID/path.
PATH_KEY = "path"
# Vector ID template for a patched memory: one record per (user, function).
PATCH_PATH = "user/{user_id}/patches/{function_name}"
# Vector ID template for an inserted memory: one record per extracted event.
INSERT_PATH = "user/{user_id}/inserts/{function_name}/{event_id}"
# Metadata key holding the time at which the memory was written.
TIMESTAMP_KEY = "timestamp"
+75
View File
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from langchain_core.pydantic_v1 import BaseModel
from langgraph.graph import add_messages
from typing_extensions import Annotated, Literal, TypedDict
class FunctionSchema(TypedDict):
    """Function (tool) definition describing a single memory type."""

    name: str
    """Identifier of the function; also used as the tool name."""

    description: str
    """Human-readable explanation of what the function captures."""

    parameters: dict
    """JSON Schema describing the shape of the memory."""
class MemoryConfig(TypedDict, total=False):
    """Configuration of one memory type; every key is optional."""

    function: FunctionSchema
    """Function definition the memory assistant is asked to call."""

    system_prompt: Optional[str]
    """System prompt given to the memory assistant, if any."""

    update_mode: Literal["patch", "insert"]
    """How newly extracted memories are written.

    "patch" keeps a single record and continuously updates it, which suits a
    structured profile or a core list of memories; its current state can be
    fetched at any time.

    "insert" stores every generation as a new memory, which preserves all
    interactions; the full history can then be queried.
    """
class GraphConfig(TypedDict, total=False):
    """Configurable values accepted by the memory graphs; all keys optional."""

    delay: float
    """Seconds to wait before a conversation is considered complete.

    Default is 60 seconds.
    """

    model: str
    """Model used to generate memories.

    Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
    """

    schemas: dict[str, MemoryConfig]
    """Memory configurations for the memory assistant, keyed by name."""

    thread_id: str
    """Thread ID of the conversation."""

    user_id: str
    """ID of the user whose memories are being recorded."""
class State(TypedDict):
    """Top-level graph state: the conversation plus a scheduling flag."""

    messages: Annotated[list, add_messages]
    """Conversation messages, combined through the `add_messages` reducer."""

    eager: bool
    """When truthy, the `schedule` node skips its delay."""
class SingleExtractorState(State):
    """State of the extraction sub-graph for one memory schema."""

    function_name: str
    """Key of the memory schema (in the configured `schemas`) being processed."""

    responses: list[BaseModel]
    """Structured outputs returned by the extractor."""

    user_state: Optional[Dict[str, Any]]
    """Previously stored payload for patched memories, or None if absent."""
# Names exported by `from memory_service._schemas import *`.
__all__ = [
    "State",
    "GraphConfig",
    "SingleExtractorState",
    "FunctionSchema",
    "MemoryConfig",
]
+10
View File
@@ -0,0 +1,10 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Pinecone connection settings.

    Values are resolved by pydantic-settings' `BaseSettings`; every field
    has a default so the class can be instantiated without configuration.
    """

    # API key passed to the Pinecone client.
    pinecone_api_key: str = ""
    # Name of the index that stores the memories.
    pinecone_index_name: str = ""
    # Namespace used for every fetch/query/upsert.
    pinecone_namespace: str = "ns1"


# Module-level instance shared by the rest of the package.
SETTINGS = Settings()
+58
View File
@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Sequence
from langchain_core.messages import (
AnyMessage,
HumanMessage,
SystemMessage,
merge_message_runs,
)
from pinecone import Pinecone
from memory_service import _schemas as schemas
from memory_service import _settings as settings
# Fallback for the `delay` config value when the caller does not supply one.
_DEFAULT_DELAY = 60  # seconds
def get_index():
    """Return a handle to the configured Pinecone index.

    A new Pinecone client is constructed from SETTINGS on every call.
    """
    client = Pinecone(api_key=settings.SETTINGS.pinecone_api_key)
    index_name = settings.SETTINGS.pinecone_index_name
    return client.Index(index_name)
def ensure_configurable(config: dict) -> schemas.GraphConfig:
    """Return `config` with the known GraphConfig keys filled in.

    `delay`, `model` and `schemas` fall back to defaults when absent;
    `thread_id` and `user_id` are required and raise KeyError if missing.
    Any extra keys in `config` are carried over unchanged.
    """
    known = schemas.GraphConfig(
        delay=config.get("delay", _DEFAULT_DELAY),
        model=config.get("model", "accounts/fireworks/models/firefunction-v2"),
        schemas=config.get("schemas", {}),
        thread_id=config["thread_id"],
        user_id=config["user_id"],
    )
    merged = dict(config)
    # The resolved values win over whatever the caller passed in.
    merged.update(known)
    return merged
def prepare_messages(
    messages: Sequence[AnyMessage], system_prompt: str
) -> list[AnyMessage]:
    """Wrap the conversation in memory-extraction instructions.

    A system message (the caller's prompt plus a reflection instruction) is
    placed before the conversation and a human reminder after it; the
    resulting list is then passed through `merge_message_runs`.
    """
    preamble = SystemMessage(
        content=system_prompt
        + """
<memory-system>Reflect on following interaction. Use the provided tools to \
retain any necessary memories about the user.</memory-system>
"""
    )
    closing = HumanMessage(
        content="## End of conversation\n\n"
        "<memory-system>Reflect on the interaction above."
        " What memories ought to be retained or updated?</memory-system>",
    )
    return merge_message_runs([preamble, *messages, closing])
# `get_index` is part of the public surface too: sibling modules call it as
# `utils.get_index()` and the tests patch it by that name.
__all__ = ["ensure_configurable", "get_index", "prepare_messages"]
+205
View File
@@ -0,0 +1,205 @@
"""Example chatbot that incorporates user memories."""
import os
from datetime import datetime, timezone
from typing import List
import langsmith
from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_nomic.embeddings import NomicEmbeddings
from langgraph.graph import START, StateGraph, add_messages
from langgraph_sdk import get_client
from typing_extensions import Annotated, TypedDict
from memory_service import _constants as constants
from memory_service import _settings as settings
from memory_service import _utils as utils
class ChatState(TypedDict):
    """State carried through the chatbot graph."""

    # Conversation so far; updates are combined by the `add_messages` reducer.
    messages: Annotated[List[AnyMessage], add_messages]
    # Memory payloads retrieved for the current user.
    user_memories: List[dict]
class ChatConfigurable(TypedDict):
    """The configurable fields for the chatbot."""

    # ID of the user whose memories are queried.
    user_id: str
    # ID of the conversation thread.
    thread_id: str
    # Base URL of the memory service. TypedDict fields cannot carry default
    # values; the fallback (the MEMORY_SERVICE_URL environment variable, else
    # "") is applied in `_ensure_configurable` instead.
    memory_service_url: str
def _ensure_configurable(config: RunnableConfig) -> ChatConfigurable:
    """Extract the chatbot's configurable values from a run config.

    `user_id` and `thread_id` are required (KeyError if absent).
    `memory_service_url` falls back to the MEMORY_SERVICE_URL environment
    variable, and finally to an empty string.
    """
    raw = config["configurable"]
    env_url = os.environ.get("MEMORY_SERVICE_URL", "")
    return ChatConfigurable(
        user_id=raw["user_id"],
        thread_id=raw["thread_id"],
        memory_service_url=raw.get("memory_service_url", env_url),
    )
# Chat prompt: a system message (optionally extended with the formatted user
# memories and stamped with the current UTC time) followed by the conversation.
PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful and friendly chatbot.{user_info}\n\nSystem Time: {time}",
        ),
        # Without this placeholder the `messages` input supplied by `bot` is
        # dropped and the model only ever sees the system message.
        ("placeholder", "{messages}"),
    ]
).partial(
    # Evaluated on each prompt invocation so the timestamp stays current.
    time=lambda: datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
)
@langsmith.traceable
def format_query(messages: List[AnyMessage]) -> str:
    """Build the memory-search query from the latest human messages.

    The content of the last five human messages is joined with spaces.
    """
    # This is quite naive :)
    human_texts = [str(msg.content) for msg in messages if msg.type == "human"]
    recent = human_texts[-5:]
    return " ".join(recent)
async def query_memories(state: ChatState, config: RunnableConfig) -> ChatState:
    """Query the user's memories.

    Embeds a query built from the recent human messages, searches the
    Pinecone index for the ten nearest vectors belonging to the user, and
    returns their stored payloads under `user_memories`.
    """
    configurable: ChatConfigurable = config["configurable"]
    user_id = configurable["user_id"]
    index = utils.get_index()
    embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
    query = format_query(state["messages"])
    # `aembed_query` is the awaitable variant; `embed_query` returns a plain
    # list and cannot be awaited.
    vec = await embeddings.aembed_query(query)
    # You can also filter by memory type, etc. here.
    # NOTE(review): this filter only matches vectors whose metadata carries a
    # `user` field - confirm the writers store one.
    response = index.query(
        vector=vec,
        filter={"user": {"$eq": user_id}},
        include_metadata=True,
        top_k=10,
        namespace=settings.SETTINGS.pinecone_namespace,
    )
    memories = []
    if matches := response.get("matches"):
        # The memory service stores the payload directly under PAYLOAD_KEY in
        # the vector metadata (there is no intermediate "memory" level).
        memories = [m["metadata"][constants.PAYLOAD_KEY] for m in matches]
    return {
        "user_memories": memories,
    }
@langsmith.traceable
def format_memories(memories: List[dict]) -> str:
    """Render the memories as a prompt section, or "" when there are none."""
    if memories:
        # Note Bene: You can format better than this....
        memories = "\n".join(map(str, memories))
        return f"""
## Memories
You have noted the following memorable events from previous interactions with the user.
<memories>
{memories}
</memories>
"""
    return ""
async def bot(state: ChatState, config: RunnableConfig) -> ChatState:
    """Prompt the bot to respond to the user, incorporating memories (if provided)."""
    llm = init_chat_model("claude-3-5-sonnet-20240620")
    user_info = format_memories(state["user_memories"])
    prompt_inputs = {
        "messages": state["messages"],
        "user_info": user_info,
    }
    reply = await (PROMPT | llm).ainvoke(prompt_inputs, config)
    # Returned as a one-element list; the state's `add_messages` reducer
    # merges it into the conversation.
    return {"messages": [reply]}
async def post_messages(state: ChatState, config: RunnableConfig) -> ChatState:
    """Send the conversation to the memory service for extraction.

    Ensures a thread with the conversation's ID exists on the memory service,
    then schedules a run on it with a single insert-mode "MemorableEvent"
    schema. Returns the messages unchanged and clears `user_memories`.
    """
    configurable = _ensure_configurable(config)
    langgraph_client = get_client(url=configurable["memory_service_url"])
    thread_id = configurable["thread_id"]
    try:
        thread = await langgraph_client.threads.create(thread_id=thread_id)
    except Exception:
        # Best effort: creation fails when the thread already exists, in
        # which case the existing thread is fetched instead.
        thread = await langgraph_client.threads.get(thread_id=thread_id)

    event_description = (
        "Any event, observation, insight, or "
        "other detail that you may want to recall in "
        "later interactions with the user."
    )
    # NOTE(review): `runs.create` in langgraph_sdk usually also expects an
    # assistant ID as its second argument - confirm against the SDK version
    # in use and pass the memory graph's assistant here if so.
    await langgraph_client.runs.create(
        thread["thread_id"],
        input={
            "messages": state["messages"],  # the service dedupes messages
        },
        config={
            "system_prompt": "You are a helpful and friendly chatbot.",
            "configurable": {
                # The user ID comes from this run's configuration; it is not
                # read from the thread object returned by the SDK.
                "user_id": configurable["user_id"],
                "thread_id": thread["thread_id"],
                # Every key under "schemas" is treated as a memory schema by
                # the memory graph, so the prompt and update mode live inside
                # the schema entry rather than next to it.
                "schemas": {
                    "MemorableEvent": {
                        "system_prompt": "Extract any memorable events from the user's"
                        " messages that you would like to remember.",
                        "update_mode": "insert",
                        "function": {
                            "name": "memorable_event",
                            "description": event_description,
                            "parameters": {
                                "description": event_description,
                                "properties": {
                                    "description": {
                                        "title": "Description",
                                        "type": "string",
                                    },
                                    "participants": {
                                        "description": "Names of participants in"
                                        " the event and their relationship to the "
                                        "user.",
                                        "items": {"type": "string"},
                                        "title": "Participants",
                                        "type": "array",
                                    },
                                },
                                "required": ["description", "participants"],
                                "title": "memorable_event",
                                "type": "object",
                            },
                        },
                    },
                },
            },
        },
        multitask_strategy="interrupt",
    )

    return {
        "messages": state["messages"],
        "user_memories": [],
    }
# Assemble the chatbot graph: retrieve memories, answer, then forward the
# conversation to the memory service.
builder = StateGraph(ChatState, ChatConfigurable)
# Nodes are registered under their function names.
builder.add_node(query_memories)
builder.add_node(bot)
builder.add_node(post_messages)
# Linear flow: START -> query_memories -> bot -> post_messages.
builder.add_edge(START, "query_memories")
builder.add_edge("query_memories", "bot")
builder.add_edge("bot", "post_messages")
# Compiled graph referenced as "chat" in langgraph.json.
chat_graph = builder.compile()
+255
View File
@@ -0,0 +1,255 @@
"""Graphs that extract memories on a schedule."""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from langchain.chat_models import init_chat_model
from langchain_core.runnables import RunnableConfig
from langchain_nomic.embeddings import NomicEmbeddings
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph
from trustcall import create_extractor
from typing_extensions import Literal
from memory_service import _constants as constants
from memory_service import _schemas as schemas
from memory_service import _settings as settings
from memory_service import _utils as utils
# Module logger, registered under the fixed name "memory".
logger = logging.getLogger("memory")
async def fetch_patched_state(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
    """Load the user's currently stored state for this memory function.

    The record is looked up in the Pinecone index by its patch path. Swap
    this out if you keep user state in a different database.

    Returns:
        `{"user_state": payload}` when a record exists, otherwise
        `{"user_state": None}`.
    """
    configurable = utils.ensure_configurable(config["configurable"])
    record_id = constants.PATCH_PATH.format(
        user_id=configurable["user_id"], function_name=state["function_name"]
    )
    namespace = settings.SETTINGS.pinecone_namespace
    # TODO: does pinecone have an async api in their SDK...?
    response = utils.get_index().fetch(ids=[record_id], namespace=namespace)
    vectors = response.get("vectors")
    if not vectors:
        return {"user_state": None}
    metadata = vectors[record_id]["metadata"]
    return {"user_state": metadata[constants.PAYLOAD_KEY]}
async def extract_patch_memories(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
    """Run the extractor in patch mode against the user's existing state.

    The extractor is forced to call the configured function and receives the
    previously stored payload (possibly None) as the existing value.
    """
    configurable = utils.ensure_configurable(config["configurable"])
    memory_config = configurable["schemas"][state["function_name"]]
    function = memory_config["function"]
    llm = init_chat_model(model=configurable["model"])
    prompt_messages = utils.prepare_messages(
        state["messages"], memory_config.get("system_prompt") or ""
    )
    extractor = create_extractor(
        llm,
        tools=[function],
        tool_choice=function["name"],
    )
    extractor_inputs = {
        "messages": prompt_messages,
        "existing": {function["name"]: state["user_state"]},
    }
    result = await extractor.ainvoke(extractor_inputs)
    return {"responses": result["responses"]}
async def upsert_patched_state(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
    """Upsert the user's state to the database.

    The first extractor response is serialized, embedded, and written to the
    Pinecone index under the user's patch path, replacing any previous record
    with that ID. Returns `{"user_state": {}}` to reset the in-graph state.
    """
    configurable = utils.ensure_configurable(config["configurable"])
    user_id = configurable["user_id"]
    path = constants.PATCH_PATH.format(
        user_id=user_id, function_name=state["function_name"]
    )
    serialized = state["responses"][0].model_dump_json()
    embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
    vector = await embeddings.aembed_query(serialized)
    utils.get_index().upsert(
        vectors=[
            {
                "id": path,
                "values": vector,
                "metadata": {
                    constants.PAYLOAD_KEY: json.loads(serialized),
                    constants.PATH_KEY: path,
                    constants.TIMESTAMP_KEY: datetime.now(tz=timezone.utc),
                    # The chatbot filters its memory query on this field, so
                    # it has to be stored alongside the payload.
                    "user": user_id,
                },
            }
        ],
        namespace=settings.SETTINGS.pinecone_namespace,
    )
    return {"user_state": {}}
async def extract_insertion_memories(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
    """Extract embeddable "events" from the conversation (insert mode)."""
    configurable = utils.ensure_configurable(config["configurable"])
    llm = init_chat_model(model=configurable["model"])
    memory_config = configurable["schemas"][state["function_name"]]
    system_prompt = memory_config.get("system_prompt") or ""
    prompt_messages = utils.prepare_messages(state["messages"], system_prompt)
    extractor = create_extractor(llm, tools=[memory_config["function"]])
    # No "existing" value is supplied: insert mode always adds new memories
    # instead of revising earlier ones.
    result = await extractor.ainvoke({"messages": prompt_messages})
    return {"responses": result["responses"]}
async def insert_memories(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
    """Insert the extracted memories into the database.

    Each extractor response is serialized, embedded, and written to the
    Pinecone index under a fresh UUID-based path, so earlier memories are
    never overwritten. Returns `{"user_state": {}}`.
    """
    configurable = utils.ensure_configurable(config["configurable"])
    user_id = configurable["user_id"]
    embeddings = NomicEmbeddings(model="nomic-embed-text-v1.5")
    payloads = [r.model_dump_json() for r in state["responses"]]
    # You could alternatively do multi-vector lookup based on the schema.
    vectors = await embeddings.aembed_documents(payloads)
    # One timestamp for the whole batch.
    current_time = datetime.now(tz=timezone.utc)
    paths = [
        constants.INSERT_PATH.format(
            user_id=user_id,
            function_name=state["function_name"],
            event_id=str(uuid.uuid4()),
        )
        for _ in vectors
    ]
    documents = [
        {
            "id": path,
            "values": vector,
            "metadata": {
                constants.PAYLOAD_KEY: json.loads(payload),
                constants.PATH_KEY: path,
                constants.TIMESTAMP_KEY: current_time,
                # The chatbot filters its memory query on this field, so it
                # has to be stored alongside the payload.
                "user": user_id,
            },
        }
        for path, vector, payload in zip(paths, vectors, payloads)
    ]
    utils.get_index().upsert(
        vectors=documents,
        namespace=settings.SETTINGS.pinecone_namespace,
    )
    return {"user_state": {}}
def route_inbound(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["fetch_patched_state", "extract_insertion_memories"]:
    """Pick the first node of the extraction graph from the schema's update mode.

    Returns:
        "fetch_patched_state" for "patch" schemas and
        "extract_insertion_memories" for "insert" schemas.

    Raises:
        ValueError: If the configured update mode is neither of the above.
    """
    configurable = utils.ensure_configurable(config["configurable"])
    update_mode = configurable["schemas"][state["function_name"]]["update_mode"]
    # Plain comparisons instead of `match`: the package declares support for
    # Python 3.9, where `match`/`case` is a syntax error.
    if update_mode == "patch":
        return "fetch_patched_state"
    if update_mode == "insert":
        return "extract_insertion_memories"
    raise ValueError(f"Unknown update mode: {update_mode}")
# Sub-graph that handles a single memory schema end to end.
extraction_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
# Patch-mode nodes: fetch the stored state, extract, upsert.
extraction_builder.add_node(fetch_patched_state)
extraction_builder.add_node(extract_patch_memories)
extraction_builder.add_node(upsert_patched_state)
# Insert-mode nodes: extract, insert.
extraction_builder.add_node(extract_insertion_memories)
extraction_builder.add_node(insert_memories)
# The entry point depends on the schema's update mode.
extraction_builder.add_conditional_edges(START, route_inbound)
extraction_builder.add_edge("fetch_patched_state", "extract_patch_memories")
def should_commit(
    state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["insert_memories", "upsert_patched_state", "__end__"]:
    """Decide where to go after extraction.

    Ends the run when nothing was extracted; otherwise routes to the commit
    node matching the schema's update mode.
    """
    if state["responses"]:
        configurable = utils.ensure_configurable(config["configurable"])
        memory_config = configurable["schemas"][state["function_name"]]
        if memory_config["update_mode"] == "insert":
            return "insert_memories"
        return "upsert_patched_state"
    return END
# After either extraction node, commit the results or end if there are none.
extraction_builder.add_conditional_edges("extract_patch_memories", should_commit)
extraction_builder.add_conditional_edges("extract_insertion_memories", should_commit)
extraction_graph = extraction_builder.compile()
# This graph is public facing. It receives conversations and distributes them
# to the memory types as needed.
async def schedule(state: schemas.State, config: RunnableConfig) -> dict:
    """Hold the run for the configured delay before extraction starts.

    A conversation only needs processing once it is finished, but there is no
    signal for that, so processing is postponed by `delay` seconds. The delay
    can be configured at the assistant and run level; set `eager` to True in
    the run inputs to skip it entirely.

    If another message arrives before the delay elapses, this run can be
    cancelled and a new one scheduled.
    """
    if not state.get("eager", False):
        delay = utils.ensure_configurable(config["configurable"])["delay"]
        if delay:
            await asyncio.sleep(delay)
        return {"messages": []}
    return {"messages": state["messages"]}
# Create the graph + all nodes
builder = StateGraph(schemas.State, schemas.GraphConfig)
builder.add_node(schedule)
# The compiled extraction sub-graph is mounted as a single node.
builder.add_node("extract", extraction_graph)
# Add edges
builder.add_edge(START, "schedule")
def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]:
    """Fan the conversation out to every configured memory schema.

    One `Send` to the "extract" node is produced per schema name; each carries
    the current state plus that schema's `function_name`. These will be
    executed in parallel.
    """
    schema_names = utils.ensure_configurable(config["configurable"])["schemas"]
    return [
        Send("extract", {**state, "function_name": name}) for name in schema_names
    ]
# After scheduling, dispatch one "extract" run per configured schema.
builder.add_conditional_edges("schedule", scatter_schemas)
# Compiled graph referenced as "agent" in langgraph.json.
memgraph = builder.compile()
__all__ = ["memgraph"]
Generated
+2617
View File
File diff suppressed because it is too large Load Diff
+56
View File
@@ -0,0 +1,56 @@
[tool.poetry]
name = "memory-service"
version = "0.0.1"
description = "A simple memory service (for agents) on LangGraph cloud."
authors = ["William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.9.0,<3.13"
langgraph = "^0.1.0"
langchain-fireworks = "^0.1.3"
# Feel free to swap out for postgres or your favorite database.
langchain-pinecone = "^0.1.1"
jsonpatch = "^1.33"
dydantic = "^0.0.6"
pytest-asyncio = "^0.23.7"
trustcall = "^0.0.4"
langchain = "^0.2.6"
langchain-openai = "^0.1.10"
langchain-anthropic = "^0.1.15"
langchain-nomic = "^0.1.2"
pydantic-settings = "^2.3.4"
langgraph-sdk = "^0.1.23"
[tool.poetry.group.dev.dependencies]
ruff = "^0.4.10"
mypy = "^1.10.0"
pytest = "^8.2.2"
langgraph-cli = "^0.1.43"
[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"D", # pydocstyle
"D401", # First line should be in imperative mood
]
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 80
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "E501"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
asyncio_mode = "auto"
+10
View File
@@ -0,0 +1,10 @@
import os
import pytest
@pytest.fixture(scope="session", autouse=True)
def set_fake_env_vars():
    """Provide placeholder Pinecone credentials for the whole test session."""
    os.environ.update(
        {
            "PINECONE_API_KEY": "fake_key",
            "PINECONE_INDEX_NAME": "fake_index",
        }
    )
    yield
+176
View File
@@ -0,0 +1,176 @@
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith import expect, test
from memory_service._constants import PATCH_PATH
from memory_service._schemas import GraphConfig, MemoryConfig
from memory_service.graph import memgraph
# To test the patch-based memory
# NOTE: pydantic includes the class docstring in the schema produced by
# `CoreMemories.schema()`, which is handed to the extractor, so the docstring
# text is effectively part of the test input.
class CoreMemories(BaseModel):
    """Core memories about the user."""

    # Free-form memory strings; the whole list is patched as one record.
    memories: List[str]
@pytest.fixture
def core_memory_func() -> MemoryConfig:
    """Patch-mode memory configuration built around `CoreMemories`."""
    function_schema = {
        "name": "core_memories",
        "description": "A list of core memories about the user.",
        "parameters": CoreMemories.schema(),
    }
    system_prompt = (
        "You may add or remove memories that are core to the"
        " user's identity or that will help you better interact with the user."
    )
    return {
        "function": function_schema,
        "system_prompt": system_prompt,
        "update_mode": "patch",
    }
# `output_keys` takes a sequence of parameter names; a bare string would be
# iterated character by character.
@test(output_keys=["num_mems_expected"])
@pytest.mark.parametrize(
    "messages, existing, num_mems_expected",
    [
        ([], {}, 0),
        ([("user", "When I was young, I had a dog named spot")], {}, 1),
        (
            [("user", "When I was young, I had a dog named spot.")],
            {"memories": ["I am afraid of spiders."]},
            2,
        ),
    ],
)
async def test_patch_memory(
    core_memory_func: MemoryConfig,
    messages: List[tuple],
    num_mems_expected: int,
    existing: dict,
):
    """Patch-mode extraction upserts one record holding the expected memories.

    `messages` are (role, content) pairs, `existing` is the payload already
    stored for the user (empty for none), and `num_mems_expected` is the
    number of memories the upserted record should contain (0 means no
    assertion is made on the upsert).
    """
    # Replace memory_service._utils.get_index so no real Pinecone index is used.
    user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
    thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
    function_name = "CoreMemories"
    with patch("memory_service._utils.get_index") as get_index:
        index = MagicMock()
        get_index.return_value = index
        if existing:
            # Simulate a previously stored record at the user's patch path.
            path = PATCH_PATH.format(
                user_id=user_id,
                function_name=function_name,
            )
            index.fetch.return_value = {
                "vectors": {path: {"metadata": {"content": existing}}}
            }
        else:
            # No existing memories
            index.fetch.return_value = {}
        # When the memories are patched
        await memgraph.ainvoke(
            {
                "messages": messages,
            },
            {
                "configurable": GraphConfig(
                    delay=0.1,
                    user_id=user_id,
                    thread_id=thread_id,
                    schemas={function_name: core_memory_func},
                ),
            },
        )
        if num_mems_expected:
            # Check if index.upsert was called
            index.upsert.assert_called_once()
            # Get named call args
            vectors = index.upsert.call_args.kwargs["vectors"]
            assert len(vectors) == 1
            # Check if the memory was added
            memories = vectors[0]["metadata"]["content"]["memories"]
            expect(len(memories)).to_equal(num_mems_expected)
# To test the insertion memory
# NOTE: pydantic includes the class docstring and the field description in the
# schema produced by `MemorableEvent.schema()`, which is handed to the
# extractor, so both texts are effectively part of the test input.
class MemorableEvent(BaseModel):
    """A memorable event."""

    # Free-text description of what happened.
    description: str
    participants: List[str] = Field(
        description="Names of participants in the event and their relationship to the user."
    )
@pytest.fixture
def memorable_event_func() -> MemoryConfig:
    """Insert-mode memory configuration built around `MemorableEvent`."""
    function_schema = {
        "name": "memorable_event",
        "description": "Any event, observation, insight, or other detail that you may"
        " want to recall in later interactions with the user.",
        "parameters": MemorableEvent.schema(),
    }
    # The prompt text is kept verbatim: it is sent to the model as-is.
    system_prompt = (
        "Extract all events that are memorable and relevant to the user."
        " using parallel tool calling. If nothing of interest occured in the diologue, simply reply 'None'."
    )
    return {
        "function": function_schema,
        "system_prompt": system_prompt,
        "update_mode": "insert",
    }
# `output_keys` takes a sequence of parameter names; a bare string would be
# iterated character by character.
@test(output_keys=["num_events_expected"])
@pytest.mark.parametrize(
    "messages, num_events_expected",
    [
        ([], 0),
        (
            [
                ("user", "I went to the beach with my friends."),
                ("assistant", "That sounds like a fun day."),
            ],
            1,
        ),
        (
            [
                ("user", "I went to the beach with my friends."),
                ("assistant", "That sounds like a fun day."),
                ("user", "I also went to the park with my family - I like the park."),
            ],
            2,
        ),
    ],
)
async def test_insert_memory(
    memorable_event_func: MemoryConfig,
    messages: List[tuple],
    num_events_expected: int,
):
    """Insert-mode extraction upserts one vector per extracted event.

    `messages` are (role, content) pairs and `num_events_expected` is the
    number of vectors the single upsert call should contain (0 means no
    assertion is made on the upsert).
    """
    # Replace memory_service._utils.get_index so no real Pinecone index is used.
    user_id = "4fddb3ef-fcc9-4ef7-91b6-89e4a3efd112"
    thread_id = "e1d0b7f7-0a8b-4c5f-8c4b-8a6c9f6e5c7a"
    function_name = "MemorableEvent"
    with patch("memory_service._utils.get_index") as get_index:
        index = MagicMock()
        get_index.return_value = index
        index.fetch.return_value = {}
        # When the events are inserted
        await memgraph.ainvoke(
            {
                "messages": messages,
            },
            {
                "configurable": GraphConfig(
                    delay=0.1,
                    user_id=user_id,
                    thread_id=thread_id,
                    schemas={function_name: memorable_event_func},
                ),
            },
        )
        if num_events_expected:
            # Check if index.upsert was called
            index.upsert.assert_called_once()
            # Get named call args
            vectors = index.upsert.call_args.kwargs["vectors"]
            assert len(vectors) == num_events_expected