mirror of
https://github.com/langchain-ai/langgraph-memory.git
synced 2026-07-01 23:44:01 -04:00
shuffle
This commit is contained in:
@@ -2,16 +2,16 @@
|
||||
|
||||
|
||||
evals:
|
||||
LANGCHAIN_TEST_CACHE=tests/evals/cassettes poetry run python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
|
||||
LANGCHAIN_TEST_CACHE=tests/evals/cassettes python -m python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
|
||||
|
||||
lint:
|
||||
poetry run ruff check .
|
||||
poetry run mypy .
|
||||
python -m ruff check .
|
||||
python -m mypy .
|
||||
|
||||
format:
|
||||
ruff check --select I --fix
|
||||
poetry run ruff format .
|
||||
poetry run ruff check . --fix
|
||||
python -m ruff format .
|
||||
python -m ruff check . --fix
|
||||
|
||||
build:
|
||||
poetry build
|
||||
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"dependencies": ["."],
|
||||
"graphs": {
|
||||
"memory": "./memory_service/graph.py:memgraph"
|
||||
"memory": "./src/memory_service/graph.py:memgraph"
|
||||
},
|
||||
"env": ".env"
|
||||
}
|
||||
|
||||
@@ -9,13 +9,12 @@ from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.graph import START, StateGraph, add_messages
|
||||
from langgraph_sdk import get_client
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from memory_service import _constants as constants
|
||||
from memory_service import _configuration as settings
|
||||
from memory_service import _constants as constants
|
||||
from memory_service import _utils as utils
|
||||
|
||||
|
||||
|
||||
@@ -1,22 +1,74 @@
|
||||
import os
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Literal, Optional
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FunctionSchema:
|
||||
name: str
|
||||
"""Name of the function."""
|
||||
description: str
|
||||
"""A description of the function."""
|
||||
parameters: dict
|
||||
"""The JSON Schema for the memory."""
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MemoryConfig:
|
||||
function: FunctionSchema
|
||||
"""The function to use for the memory assistant."""
|
||||
system_prompt: Optional[str] = ""
|
||||
"""The system prompt to use for the memory assistant."""
|
||||
update_mode: Literal["patch", "insert"] = field(default="patch")
|
||||
"""Whether to continuously patch the memory, or treat each new
|
||||
|
||||
generation as a new memory.
|
||||
|
||||
Patching is useful for maintaining a structured profile or core list
|
||||
of memories. Inserting is useful for maintaining all interactions and
|
||||
not losing any information.
|
||||
|
||||
For patched memories, you can GET the current state at any given time.
|
||||
For inserted memories, you can query the full history of interactions.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.function, dict):
|
||||
self.function = FunctionSchema(**self.function)
|
||||
|
||||
|
||||
@dataclass(kv_only=True)
|
||||
class Settings:
|
||||
class Configuration:
|
||||
pinecone_api_key: str = ""
|
||||
pinecone_index_name: str = ""
|
||||
pinecone_namespace: str = "ns1"
|
||||
model: str = "accounts/fireworks/models/firefunction-v2"
|
||||
delay: float = 60 # seconds
|
||||
"""The delay in seconds to wait before considering a conversation complete.
|
||||
|
||||
Default is 60 seconds.
|
||||
"""
|
||||
model: str
|
||||
"""The model to use for generating memories.
|
||||
|
||||
Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
|
||||
"""
|
||||
schemas: dict[str, MemoryConfig] = field(default_factory=dict)
|
||||
"""The schemas for the memory assistant."""
|
||||
thread_id: str
|
||||
"""The thread ID of the conversation."""
|
||||
user_id: str
|
||||
"""The ID of the user to remember in the conversation."""
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(cls, config: RunnableConfig):
|
||||
configurable = config.get("configurable") or {}
|
||||
_env_env_defaults = {
|
||||
f.name: os.environ.get(f.name.upper(), "") for f in fields(cls) if f.init
|
||||
configurable = config["configurable"]
|
||||
values = {
|
||||
f.name: os.environ.get(f.name.upper(), configurable.get(f.name))
|
||||
for f in fields(cls)
|
||||
if f.init
|
||||
}
|
||||
return cls(
|
||||
**{**_env_env_defaults, **{k: v for k, v in configurable.items() if k in _env_env_defaults}}
|
||||
)
|
||||
values["schemas"] = {k: MemoryConfig(**v) for k, v in values["schemas"].items()}
|
||||
return cls(**values)
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langgraph.graph import add_messages
|
||||
from typing_extensions import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class FunctionSchema(TypedDict):
|
||||
name: str
|
||||
"""Name of the function."""
|
||||
description: str
|
||||
"""A description of the function."""
|
||||
parameters: dict
|
||||
"""The JSON Schema for the memory."""
|
||||
|
||||
|
||||
class MemoryConfig(TypedDict, total=False):
|
||||
function: FunctionSchema
|
||||
"""The function to use for the memory assistant."""
|
||||
system_prompt: Optional[str]
|
||||
"""The system prompt to use for the memory assistant."""
|
||||
update_mode: Literal["patch", "insert"]
|
||||
"""Whether to continuously patch the memory, or treat each new
|
||||
|
||||
generation as a new memory.
|
||||
|
||||
Patching is useful for maintaining a structured profile or core list
|
||||
of memories. Inserting is useful for maintaining all interactions and
|
||||
not losing any information.
|
||||
|
||||
For patched memories, you can GET the current state at any given time.
|
||||
For inserted memories, you can query the full history of interactions.
|
||||
"""
|
||||
|
||||
|
||||
class GraphConfig(TypedDict, total=False):
|
||||
delay: float
|
||||
"""The delay in seconds to wait before considering a conversation complete.
|
||||
|
||||
Default is 60 seconds.
|
||||
"""
|
||||
model: str
|
||||
"""The model to use for generating memories.
|
||||
|
||||
Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
|
||||
"""
|
||||
schemas: dict[str, MemoryConfig]
|
||||
"""The schemas for the memory assistant."""
|
||||
thread_id: str
|
||||
"""The thread ID of the conversation."""
|
||||
user_id: str
|
||||
"""The ID of the user to remember in the conversation."""
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
"""The messages in the conversation."""
|
||||
eager: bool
|
||||
|
||||
|
||||
class SingleExtractorState(State):
|
||||
function_name: str
|
||||
responses: list[BaseModel]
|
||||
user_state: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"State",
|
||||
"GraphConfig",
|
||||
"SingleExtractorState",
|
||||
"FunctionSchema",
|
||||
"MemoryConfig",
|
||||
]
|
||||
@@ -3,56 +3,24 @@ from __future__ import annotations
|
||||
from functools import lru_cache
|
||||
from typing import Sequence
|
||||
|
||||
import langsmith
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
merge_message_runs,
|
||||
)
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from pinecone import Pinecone
|
||||
|
||||
from memory_service import _schemas as schemas
|
||||
from memory_service import _configuration as settings
|
||||
|
||||
_DEFAULT_DELAY = 60 # seconds
|
||||
|
||||
|
||||
def get_index():
|
||||
pc = Pinecone(api_key=settings.SETTINGS.pinecone_api_key)
|
||||
return pc.Index(settings.SETTINGS.pinecone_index_name)
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def ensure_memory_config(config: dict) -> schemas.MemoryConfig:
|
||||
"""Merge the user-provided config with default values."""
|
||||
return {
|
||||
**config,
|
||||
**schemas.MemoryConfig(
|
||||
function=config.get("function", {}),
|
||||
system_prompt=config.get("system_prompt"),
|
||||
update_mode=config.get("update_mode", "patch"),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@langsmith.traceable
|
||||
def ensure_configurable(config: dict) -> schemas.GraphConfig:
|
||||
"""Merge the user-provided config with default values."""
|
||||
function_schemas = config.get("schemas") or {}
|
||||
return {
|
||||
**config,
|
||||
**schemas.GraphConfig(
|
||||
delay=config.get("delay", _DEFAULT_DELAY),
|
||||
model=config.get("model", settings.SETTINGS.model),
|
||||
schemas={k: ensure_memory_config(v) for k, v in function_schemas.items()},
|
||||
thread_id=config["thread_id"],
|
||||
user_id=config["user_id"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def prepare_messages(
|
||||
messages: Sequence[AnyMessage], system_prompt: str
|
||||
) -> list[AnyMessage]:
|
||||
@@ -74,8 +42,8 @@ def prepare_messages(
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_embeddings():
|
||||
return FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
|
||||
def get_embeddings() -> Embeddings:
|
||||
return OpenAIEmbeddings(model="text-embedding-3-small")
|
||||
|
||||
|
||||
__all__ = ["ensure_configurable", "prepare_messages"]
|
||||
__all__ = ["prepare_messages"]
|
||||
|
||||
+40
-36
@@ -14,10 +14,10 @@ from langgraph.graph import END, START, StateGraph
|
||||
from trustcall import create_extractor
|
||||
from typing_extensions import Literal
|
||||
|
||||
from memory_service import _configuration as configuration
|
||||
from memory_service import _constants as constants
|
||||
from memory_service import _schemas as schemas
|
||||
from memory_service import _configuration as settings
|
||||
from memory_service import _utils as utils
|
||||
from memory_service import state as schemas
|
||||
|
||||
logger = logging.getLogger("memory")
|
||||
# Handle patch memory, where we update a single document in the database.
|
||||
@@ -33,13 +33,13 @@ async def fetch_patched_state(
|
||||
This is a placeholder function. You should replace this with a function
|
||||
that fetches the user's state from the database.
|
||||
"""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
path = constants.PATCH_PATH.format(
|
||||
user_id=configurable["user_id"], function_name=state["function_name"]
|
||||
user_id=configurable.user_id, function_name=state.function_name
|
||||
)
|
||||
# TODO: does pinecone have an async api in their SDK...?
|
||||
response = utils.get_index().fetch(
|
||||
ids=[path], namespace=settings.SETTINGS.pinecone_namespace
|
||||
ids=[path], namespace=configurable.pinecone_namespace
|
||||
)
|
||||
if vectors := response.get("vectors"):
|
||||
document = vectors[path]
|
||||
@@ -52,12 +52,12 @@ async def extract_patch_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Extract the user's state from the conversation."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
schemas = configurable["schemas"]
|
||||
memory_config = schemas[state["function_name"]]
|
||||
llm = init_chat_model(model=configurable["model"])
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
schemas = configurable.schemas
|
||||
memory_config = schemas[state.function_name]
|
||||
llm = init_chat_model(model=configurable.model)
|
||||
messages = utils.prepare_messages(
|
||||
state["messages"], memory_config.get("system_prompt") or ""
|
||||
state.messages, memory_config.get("system_prompt") or ""
|
||||
)
|
||||
extractor = create_extractor(
|
||||
llm,
|
||||
@@ -67,7 +67,7 @@ async def extract_patch_memories(
|
||||
inputs = {
|
||||
"messages": messages,
|
||||
}
|
||||
if existing := state["user_state"]:
|
||||
if existing := state.user_state:
|
||||
inputs["existing"] = {memory_config["function"]["name"]: existing}
|
||||
result = await extractor.ainvoke(inputs, config)
|
||||
return {"responses": result["responses"]}
|
||||
@@ -77,11 +77,11 @@ async def upsert_patched_state(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Upsert the user's state to the database."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
path = constants.PATCH_PATH.format(
|
||||
user_id=configurable["user_id"], function_name=state["function_name"]
|
||||
user_id=configurable.user_id, function_name=state.function_name
|
||||
)
|
||||
serialized = state["responses"][0].model_dump_json()
|
||||
serialized = state.responses[0].model_dump_json()
|
||||
embeddings = utils.get_embeddings()
|
||||
vector = await embeddings.aembed_query(serialized)
|
||||
utils.get_index().upsert(
|
||||
@@ -93,16 +93,18 @@ async def upsert_patched_state(
|
||||
constants.PAYLOAD_KEY: serialized,
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: datetime.now(tz=timezone.utc),
|
||||
"user_id": configurable["user_id"],
|
||||
"user_id": configurable.user_id,
|
||||
},
|
||||
}
|
||||
],
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
namespace=configurable.pinecone_namespace,
|
||||
)
|
||||
return {"user_state": {}}
|
||||
|
||||
|
||||
patch_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
|
||||
patch_builder = StateGraph(
|
||||
schemas.SingleExtractorState, config_schema=configuration.Configuration
|
||||
)
|
||||
patch_builder.add_node(fetch_patched_state)
|
||||
patch_builder.add_node(extract_patch_memories)
|
||||
patch_builder.add_node(upsert_patched_state)
|
||||
@@ -114,7 +116,7 @@ def should_commit_patch(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> Literal["upsert_patched_state", "__end__"]:
|
||||
"""Whether there are things extracted to commit to the DB."""
|
||||
return "upsert_patched_state" if state["responses"] else END
|
||||
return "upsert_patched_state" if state.responses else END
|
||||
|
||||
|
||||
patch_builder.add_conditional_edges("extract_patch_memories", should_commit_patch)
|
||||
@@ -128,11 +130,11 @@ async def extract_semantic_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Extract embeddable "events"."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
llm = init_chat_model(model=configurable["model"])
|
||||
memory_config = configurable["schemas"][state["function_name"]]
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
llm = init_chat_model(model=configurable.model)
|
||||
memory_config = configurable.schemas[state.function_name]
|
||||
messages = utils.prepare_messages(
|
||||
state["messages"], memory_config.get("system_prompt") or ""
|
||||
state.messages, memory_config.get("system_prompt") or ""
|
||||
)
|
||||
|
||||
extractor = create_extractor(llm, tools=[memory_config["function"]])
|
||||
@@ -146,16 +148,16 @@ async def insert_memories(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> dict:
|
||||
"""Insert the user's state to the database."""
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
embeddings = utils.get_embeddings()
|
||||
serialized = [r.model_dump_json() for r in state["responses"]]
|
||||
serialized = [r.model_dump_json() for r in state.responses]
|
||||
# You could alternatively do multi-vector lookup based on the schema.
|
||||
vectors = await embeddings.aembed_documents(serialized)
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
paths = [
|
||||
constants.INSERT_PATH.format(
|
||||
user_id=configurable["user_id"],
|
||||
function_name=state["function_name"],
|
||||
user_id=configurable.user_id,
|
||||
function_name=configurable.function_name,
|
||||
event_id=str(uuid.uuid4()),
|
||||
)
|
||||
for _ in range(len(vectors))
|
||||
@@ -168,19 +170,21 @@ async def insert_memories(
|
||||
constants.PAYLOAD_KEY: serialized,
|
||||
constants.PATH_KEY: path,
|
||||
constants.TIMESTAMP_KEY: current_time,
|
||||
"user_id": configurable["user_id"],
|
||||
"user_id": configurable.user_id,
|
||||
},
|
||||
}
|
||||
for path, vector, serialized in zip(paths, vectors, serialized)
|
||||
]
|
||||
utils.get_index().upsert(
|
||||
vectors=documents,
|
||||
namespace=settings.SETTINGS.pinecone_namespace,
|
||||
namespace=configurable.pinecone_namespace,
|
||||
)
|
||||
return {"user_state": {}}
|
||||
|
||||
|
||||
semantic_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
|
||||
semantic_builder = StateGraph(
|
||||
schemas.SingleExtractorState, config_schema=configuration.GraphConfig
|
||||
)
|
||||
# Lots of quality improvements can be made here, such as:
|
||||
# - Fetch similar memories and prompt model to combine or extrapolate
|
||||
# - Adding advanced indexing by the memory schema (like importance, relevance, etc.)
|
||||
@@ -193,7 +197,7 @@ def should_insert(
|
||||
state: schemas.SingleExtractorState, config: RunnableConfig
|
||||
) -> Literal["insert_memories", "__end__"]:
|
||||
"""Whether there are things extracted to commit to the DB."""
|
||||
return "insert_memories" if state["responses"] else END
|
||||
return "insert_memories" if state.responses else END
|
||||
|
||||
|
||||
semantic_builder.add_conditional_edges("extract_semantic_memories", should_insert)
|
||||
@@ -218,10 +222,10 @@ async def schedule(state: schemas.State, config: RunnableConfig) -> dict:
|
||||
and a new one scheduled.
|
||||
"""
|
||||
if state.get("eager", False):
|
||||
return {"messages": state["messages"]}
|
||||
configurable = utils.ensure_configurable(config["configurable"])
|
||||
if configurable["delay"]:
|
||||
await asyncio.sleep(configurable["delay"])
|
||||
return {"messages": state.messages}
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
if configurable.delay:
|
||||
await asyncio.sleep(configurable.delay)
|
||||
return {"messages": []}
|
||||
|
||||
|
||||
@@ -240,9 +244,9 @@ def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]:
|
||||
|
||||
These will be executed in parallel.
|
||||
"""
|
||||
configuration = utils.ensure_configurable(config["configurable"])
|
||||
configurable = configuration.Configuration.from_runnable_config(config)
|
||||
sends = []
|
||||
for k, v in configuration["schemas"].items():
|
||||
for k, v in configurable["schemas"].items():
|
||||
update_mode = v["update_mode"]
|
||||
match update_mode:
|
||||
case "patch":
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Define the shared values."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langgraph.graph import add_messages
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class State:
|
||||
"""Main graph state."""
|
||||
|
||||
messages: Annotated[List[AnyMessage], add_messages]
|
||||
"""The messages in the conversation."""
|
||||
eager: bool
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SingleExtractorState(State):
|
||||
"""Extractor state."""
|
||||
|
||||
function_name: str
|
||||
responses: list[BaseModel]
|
||||
user_state: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"State",
|
||||
"SingleExtractorState",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,8 @@ from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langsmith import expect, get_current_run_tree, test
|
||||
|
||||
from memory_service._constants import PATCH_PATH
|
||||
from memory_service._schemas import GraphConfig, MemoryConfig
|
||||
from memory_service.graph import memgraph
|
||||
from memory_service.state import GraphConfig, MemoryConfig
|
||||
|
||||
|
||||
# To test the patch-based memory
|
||||
|
||||
Reference in New Issue
Block a user