This commit is contained in:
William Fu-Hinthorn
2024-09-23 11:01:14 -07:00
parent 738f1a3f77
commit 4f395547bb
10 changed files with 1669 additions and 166 deletions
+5 -5
View File
@@ -2,16 +2,16 @@
evals:
LANGCHAIN_TEST_CACHE=tests/evals/cassettes poetry run python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
LANGCHAIN_TEST_CACHE=tests/evals/cassettes python -m pytest -p no:asyncio --max-asyncio-tasks 4 tests/evals
lint:
poetry run ruff check .
poetry run mypy .
python -m ruff check .
python -m mypy .
format:
ruff check --select I --fix
poetry run ruff format .
poetry run ruff check . --fix
python -m ruff format .
python -m ruff check . --fix
build:
poetry build
+1 -1
View File
@@ -1,7 +1,7 @@
{
"dependencies": ["."],
"graphs": {
"memory": "./memory_service/graph.py:memgraph"
"memory": "./src/memory_service/graph.py:memgraph"
},
"env": ".env"
}
+1 -2
View File
@@ -9,13 +9,12 @@ from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.graph import START, StateGraph, add_messages
from langgraph_sdk import get_client
from typing_extensions import Annotated, TypedDict
from memory_service import _constants as constants
from memory_service import _configuration as settings
from memory_service import _constants as constants
from memory_service import _utils as utils
+60 -8
View File
@@ -1,22 +1,74 @@
import os
from dataclasses import dataclass, fields
from dataclasses import dataclass, field, fields
from typing import Literal, Optional
from langchain_core.runnables import RunnableConfig
@dataclass(kw_only=True)
class FunctionSchema:
name: str
"""Name of the function."""
description: str
"""A description of the function."""
parameters: dict
"""The JSON Schema for the memory."""
@dataclass(kw_only=True)
class MemoryConfig:
function: FunctionSchema
"""The function to use for the memory assistant."""
system_prompt: Optional[str] = ""
"""The system prompt to use for the memory assistant."""
update_mode: Literal["patch", "insert"] = field(default="patch")
"""Whether to continuously patch the memory, or treat each new
generation as a new memory.
Patching is useful for maintaining a structured profile or core list
of memories. Inserting is useful for maintaining all interactions and
not losing any information.
For patched memories, you can GET the current state at any given time.
For inserted memories, you can query the full history of interactions.
"""
def __post_init__(self):
if isinstance(self.function, dict):
self.function = FunctionSchema(**self.function)
@dataclass(kw_only=True)
class Configuration:
    """Per-run configuration for the memory graph.

    Instances are normally built with ``from_runnable_config``, which reads
    each field from the environment or from the run's ``configurable`` mapping.
    """

    pinecone_api_key: str = ""
    pinecone_index_name: str = ""
    pinecone_namespace: str = "ns1"

    delay: float = 60  # seconds
    """The delay in seconds to wait before considering a conversation complete.

    Default is 60 seconds.
    """

    model: str = "accounts/fireworks/models/firefunction-v2"
    """The model to use for generating memories.

    Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
    """

    schemas: dict[str, MemoryConfig] = field(default_factory=dict)
    """The schemas for the memory assistant."""

    thread_id: str
    """The thread ID of the conversation."""

    user_id: str
    """The ID of the user to remember in the conversation."""

    @classmethod
    def from_runnable_config(cls, config: RunnableConfig) -> "Configuration":
        """Build a ``Configuration`` from a runnable config.

        For every init field, the environment variable named after the
        upper-cased field name takes precedence; otherwise the value comes
        from ``config["configurable"]``. Fields found in neither place are
        left out so the dataclass defaults apply.

        Args:
            config: The runnable config; a missing or empty ``configurable``
                entry is treated as an empty mapping.

        Returns:
            The populated ``Configuration``.

        Raises:
            TypeError: If a field without a default (``thread_id``,
                ``user_id``) is not provided.
        """
        configurable = config.get("configurable") or {}
        values = {}
        for f in fields(cls):
            if not f.init:
                continue
            value = os.environ.get(f.name.upper(), configurable.get(f.name))
            # Skip absent values instead of passing None, which would
            # override the declared defaults.
            if value is not None:
                values[f.name] = value

        raw_schemas = values.get("schemas")
        if raw_schemas:
            # Accept both plain dicts and already-built MemoryConfig objects.
            values["schemas"] = {
                name: spec if isinstance(spec, MemoryConfig) else MemoryConfig(**spec)
                for name, spec in raw_schemas.items()
            }
        return cls(**values)
-76
View File
@@ -1,76 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_core.messages import AnyMessage
from langchain_core.pydantic_v1 import BaseModel
from langgraph.graph import add_messages
from typing_extensions import Annotated, Literal, TypedDict
class FunctionSchema(TypedDict):
    """Typed mapping describing a function: name, description, parameters."""

    name: str
    """Name of the function."""

    description: str
    """A description of the function."""

    parameters: dict
    """The JSON Schema for the memory."""
class MemoryConfig(TypedDict, total=False):
    """Typed mapping of per-schema memory settings; every key is optional."""

    function: FunctionSchema
    """The function to use for the memory assistant."""

    system_prompt: Optional[str]
    """The system prompt to use for the memory assistant."""

    update_mode: Literal["patch", "insert"]
    """Whether to continuously patch the memory, or treat each new
    generation as a new memory.

    Patching is useful for maintaining a structured profile or core list
    of memories. Inserting is useful for maintaining all interactions and
    not losing any information.

    For patched memories, you can GET the current state at any given time.
    For inserted memories, you can query the full history of interactions.
    """
class GraphConfig(TypedDict, total=False):
    """Typed mapping of graph-level configuration; every key is optional."""

    delay: float
    """The delay in seconds to wait before considering a conversation complete.

    Default is 60 seconds.
    """

    model: str
    """The model to use for generating memories.

    Defaults to Fireworks's "accounts/fireworks/models/firefunction-v2"
    """

    schemas: dict[str, MemoryConfig]
    """The schemas for the memory assistant."""

    thread_id: str
    """The thread ID of the conversation."""

    user_id: str
    """The ID of the user to remember in the conversation."""
class State(TypedDict):
    """Typed mapping holding the conversation messages and the eager flag."""

    messages: Annotated[List[AnyMessage], add_messages]
    """The messages in the conversation."""

    eager: bool
class SingleExtractorState(State):
    """``State`` extended with the fields used by a single extractor run."""

    function_name: str

    responses: list[BaseModel]

    user_state: Optional[Dict[str, Any]]
# Names exported by `from <this module> import *`.
__all__ = [
"State",
"GraphConfig",
"SingleExtractorState",
"FunctionSchema",
"MemoryConfig",
]
+5 -37
View File
@@ -3,56 +3,24 @@ from __future__ import annotations
from functools import lru_cache
from typing import Sequence
import langsmith
from langchain_core.embeddings import Embeddings
from langchain_core.messages import (
AnyMessage,
HumanMessage,
SystemMessage,
merge_message_runs,
)
from langchain_fireworks import FireworksEmbeddings
from langchain_openai import OpenAIEmbeddings
from pinecone import Pinecone
from memory_service import _schemas as schemas
from memory_service import _configuration as settings
_DEFAULT_DELAY = 60 # seconds
def get_index():
    """Return the Pinecone index named by the module-level settings."""
    client = Pinecone(api_key=settings.SETTINGS.pinecone_api_key)
    index_name = settings.SETTINGS.pinecone_index_name
    return client.Index(index_name)
@langsmith.traceable
def ensure_memory_config(config: dict) -> schemas.MemoryConfig:
    """Merge the user-provided config with default values.

    Keys of ``config`` outside the memory schema are passed through; the
    ``function``, ``system_prompt`` and ``update_mode`` keys are always
    present in the result.
    """
    defaults = schemas.MemoryConfig(
        function=config.get("function", {}),
        system_prompt=config.get("system_prompt"),
        update_mode=config.get("update_mode", "patch"),
    )
    merged = dict(config)
    # The normalized entries win over whatever the caller passed in.
    merged.update(defaults)
    return merged
@langsmith.traceable
def ensure_configurable(config: dict) -> schemas.GraphConfig:
    """Merge the user-provided config with default values.

    ``thread_id`` and ``user_id`` must be present in ``config`` (a missing
    key raises ``KeyError``); ``delay``, ``model`` and ``schemas`` fall back
    to defaults.
    """
    raw_schemas = config.get("schemas") or {}
    resolved = schemas.GraphConfig(
        delay=config.get("delay", _DEFAULT_DELAY),
        model=config.get("model", settings.SETTINGS.model),
        schemas={
            name: ensure_memory_config(spec) for name, spec in raw_schemas.items()
        },
        thread_id=config["thread_id"],
        user_id=config["user_id"],
    )
    merged = dict(config)
    # The resolved entries win over whatever the caller passed in.
    merged.update(resolved)
    return merged
def prepare_messages(
messages: Sequence[AnyMessage], system_prompt: str
) -> list[AnyMessage]:
@@ -74,8 +42,8 @@ def prepare_messages(
@lru_cache
def get_embeddings():
return FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
def get_embeddings() -> Embeddings:
return OpenAIEmbeddings(model="text-embedding-3-small")
__all__ = ["ensure_configurable", "prepare_messages"]
__all__ = ["prepare_messages"]
+40 -36
View File
@@ -14,10 +14,10 @@ from langgraph.graph import END, START, StateGraph
from trustcall import create_extractor
from typing_extensions import Literal
from memory_service import _configuration as configuration
from memory_service import _constants as constants
from memory_service import _schemas as schemas
from memory_service import _configuration as settings
from memory_service import _utils as utils
from memory_service import state as schemas
logger = logging.getLogger("memory")
# Handle patch memory, where we update a single document in the database.
@@ -33,13 +33,13 @@ async def fetch_patched_state(
This is a placeholder function. You should replace this with a function
that fetches the user's state from the database.
"""
configurable = utils.ensure_configurable(config["configurable"])
configurable = configuration.Configuration.from_runnable_config(config)
path = constants.PATCH_PATH.format(
user_id=configurable["user_id"], function_name=state["function_name"]
user_id=configurable.user_id, function_name=state.function_name
)
# TODO: does pinecone have an async api in their SDK...?
response = utils.get_index().fetch(
ids=[path], namespace=settings.SETTINGS.pinecone_namespace
ids=[path], namespace=configurable.pinecone_namespace
)
if vectors := response.get("vectors"):
document = vectors[path]
@@ -52,12 +52,12 @@ async def extract_patch_memories(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
"""Extract the user's state from the conversation."""
configurable = utils.ensure_configurable(config["configurable"])
schemas = configurable["schemas"]
memory_config = schemas[state["function_name"]]
llm = init_chat_model(model=configurable["model"])
configurable = configuration.Configuration.from_runnable_config(config)
schemas = configurable.schemas
memory_config = schemas[state.function_name]
llm = init_chat_model(model=configurable.model)
messages = utils.prepare_messages(
state["messages"], memory_config.get("system_prompt") or ""
state.messages, memory_config.get("system_prompt") or ""
)
extractor = create_extractor(
llm,
@@ -67,7 +67,7 @@ async def extract_patch_memories(
inputs = {
"messages": messages,
}
if existing := state["user_state"]:
if existing := state.user_state:
inputs["existing"] = {memory_config["function"]["name"]: existing}
result = await extractor.ainvoke(inputs, config)
return {"responses": result["responses"]}
@@ -77,11 +77,11 @@ async def upsert_patched_state(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
"""Upsert the user's state to the database."""
configurable = utils.ensure_configurable(config["configurable"])
configurable = configuration.Configuration.from_runnable_config(config)
path = constants.PATCH_PATH.format(
user_id=configurable["user_id"], function_name=state["function_name"]
user_id=configurable.user_id, function_name=state.function_name
)
serialized = state["responses"][0].model_dump_json()
serialized = state.responses[0].model_dump_json()
embeddings = utils.get_embeddings()
vector = await embeddings.aembed_query(serialized)
utils.get_index().upsert(
@@ -93,16 +93,18 @@ async def upsert_patched_state(
constants.PAYLOAD_KEY: serialized,
constants.PATH_KEY: path,
constants.TIMESTAMP_KEY: datetime.now(tz=timezone.utc),
"user_id": configurable["user_id"],
"user_id": configurable.user_id,
},
}
],
namespace=settings.SETTINGS.pinecone_namespace,
namespace=configurable.pinecone_namespace,
)
return {"user_state": {}}
patch_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
patch_builder = StateGraph(
schemas.SingleExtractorState, config_schema=configuration.Configuration
)
patch_builder.add_node(fetch_patched_state)
patch_builder.add_node(extract_patch_memories)
patch_builder.add_node(upsert_patched_state)
@@ -114,7 +116,7 @@ def should_commit_patch(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["upsert_patched_state", "__end__"]:
"""Whether there are things extracted to commit to the DB."""
return "upsert_patched_state" if state["responses"] else END
return "upsert_patched_state" if state.responses else END
patch_builder.add_conditional_edges("extract_patch_memories", should_commit_patch)
@@ -128,11 +130,11 @@ async def extract_semantic_memories(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
"""Extract embeddable "events"."""
configurable = utils.ensure_configurable(config["configurable"])
llm = init_chat_model(model=configurable["model"])
memory_config = configurable["schemas"][state["function_name"]]
configurable = configuration.Configuration.from_runnable_config(config)
llm = init_chat_model(model=configurable.model)
memory_config = configurable.schemas[state.function_name]
messages = utils.prepare_messages(
state["messages"], memory_config.get("system_prompt") or ""
state.messages, memory_config.get("system_prompt") or ""
)
extractor = create_extractor(llm, tools=[memory_config["function"]])
@@ -146,16 +148,16 @@ async def insert_memories(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> dict:
"""Insert the user's state to the database."""
configurable = utils.ensure_configurable(config["configurable"])
configurable = configuration.Configuration.from_runnable_config(config)
embeddings = utils.get_embeddings()
serialized = [r.model_dump_json() for r in state["responses"]]
serialized = [r.model_dump_json() for r in state.responses]
# You could alternatively do multi-vector lookup based on the schema.
vectors = await embeddings.aembed_documents(serialized)
current_time = datetime.now(tz=timezone.utc)
paths = [
constants.INSERT_PATH.format(
user_id=configurable["user_id"],
function_name=state["function_name"],
user_id=configurable.user_id,
function_name=configurable.function_name,
event_id=str(uuid.uuid4()),
)
for _ in range(len(vectors))
@@ -168,19 +170,21 @@ async def insert_memories(
constants.PAYLOAD_KEY: serialized,
constants.PATH_KEY: path,
constants.TIMESTAMP_KEY: current_time,
"user_id": configurable["user_id"],
"user_id": configurable.user_id,
},
}
for path, vector, serialized in zip(paths, vectors, serialized)
]
utils.get_index().upsert(
vectors=documents,
namespace=settings.SETTINGS.pinecone_namespace,
namespace=configurable.pinecone_namespace,
)
return {"user_state": {}}
semantic_builder = StateGraph(schemas.SingleExtractorState, schemas.GraphConfig)
# Use the same config schema as patch_builder: `Configuration` is the class
# that `_configuration` defines (it has no `GraphConfig`).
semantic_builder = StateGraph(
    schemas.SingleExtractorState, config_schema=configuration.Configuration
)
# Lots of quality improvements can be made here, such as:
# - Fetch similar memories and prompt model to combine or extrapolate
# - Adding advanced indexing by the memory schema (like importance, relevance, etc.)
@@ -193,7 +197,7 @@ def should_insert(
state: schemas.SingleExtractorState, config: RunnableConfig
) -> Literal["insert_memories", "__end__"]:
"""Whether there are things extracted to commit to the DB."""
return "insert_memories" if state["responses"] else END
return "insert_memories" if state.responses else END
semantic_builder.add_conditional_edges("extract_semantic_memories", should_insert)
@@ -218,10 +222,10 @@ async def schedule(state: schemas.State, config: RunnableConfig) -> dict:
and a new one scheduled.
"""
if state.get("eager", False):
return {"messages": state["messages"]}
configurable = utils.ensure_configurable(config["configurable"])
if configurable["delay"]:
await asyncio.sleep(configurable["delay"])
return {"messages": state.messages}
configurable = configuration.Configuration.from_runnable_config(config)
if configurable.delay:
await asyncio.sleep(configurable.delay)
return {"messages": []}
@@ -240,9 +244,9 @@ def scatter_schemas(state: schemas.State, config: RunnableConfig) -> list[Send]:
These will be executed in parallel.
"""
configuration = utils.ensure_configurable(config["configurable"])
configurable = configuration.Configuration.from_runnable_config(config)
sends = []
for k, v in configuration["schemas"].items():
for k, v in configurable["schemas"].items():
update_mode = v["update_mode"]
match update_mode:
case "patch":
+35
View File
@@ -0,0 +1,35 @@
"""Define the shared values."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from langchain_core.messages import AnyMessage
from langchain_core.pydantic_v1 import BaseModel
from langgraph.graph import add_messages
from typing_extensions import Annotated
@dataclass(kw_only=True)
class State:
"""Main graph state."""
messages: Annotated[List[AnyMessage], add_messages]
"""The messages in the conversation."""
eager: bool
@dataclass(kw_only=True)
class SingleExtractorState(State):
    """Extractor state."""

    function_name: str

    responses: list[BaseModel]

    user_state: Optional[Dict[str, Any]]
# Names exported by `from <this module> import *`.
__all__ = [
"State",
"SingleExtractorState",
]
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -7,8 +7,8 @@ from langchain_core.pydantic_v1 import BaseModel, Field
from langsmith import expect, get_current_run_tree, test
from memory_service._constants import PATCH_PATH
from memory_service._schemas import GraphConfig, MemoryConfig
from memory_service.graph import memgraph
from memory_service.state import GraphConfig, MemoryConfig
# To test the patch-based memory