Upgrade OpenGPTs (#361)

* Migrate pydantic

* Upgrade poetry

* Adapt to manage checkpoints using an AsyncSaver

* Adjust Tools model

* Add checkpoint

* Update poetry

* Format

* Fix tests

* Modify tables

* Fix gpt4o

* Fix bots

* Fix retrieval

* Add Eugene's suggestions

* Fix state handling inconsistency between different agent types

* Improve doc

* Update backend/pyproject.toml

* lint fix

* lint

---------

Co-authored-by: “lgesuellip” <“lgesuellipinto@uade.edu.ar”>
Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
lgesuellip
2025-01-29 18:51:41 -03:00
committed by GitHub
parent 541ae6f7f1
commit 2cf3bf75e1
27 changed files with 3408 additions and 2720 deletions
+13 -3
@@ -227,6 +227,16 @@ Navigate to [http://localhost:5173/](http://localhost:5173/) and enjoy!
Refer to this [guide](tools/redis_to_postgres/README.md) for migrating data from Redis to Postgres.
## Breaking Changes
### Migration 5 - Checkpoint Management Update
Version 5 of the database migrations introduces a significant change to how thread checkpoints are managed:
- Transitions from a pickle-based checkpointing system to a new multi-table checkpoint management system (breaking change)
- Aligns with LangGraph's new checkpoint architecture for better state management and persistence
- **Important**: Historical threads/checkpoints (created before this migration) will not be accessible in the UI
- Previous checkpoint data is preserved in the `old_checkpoints` table but cannot be accessed by the new system
- This architectural change improves how thread state is stored and managed, enabling more reliable state persistence in LangGraph-based agents.
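The preservation behavior described above can be illustrated with a minimal sketch. It assumes an in-memory SQLite database standing in for Postgres and a simplified two-column schema; both are illustrative assumptions, not the project's actual migration SQL. The `table_names` helper is likewise hypothetical.

```python
import sqlite3

# Assumption: SQLite stands in for Postgres, and the schema is simplified.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint BLOB)")
conn.execute("INSERT INTO checkpoints VALUES ('t1', x'00')")

# "Up" step: rename the table so existing rows are kept under a new name.
conn.execute("ALTER TABLE checkpoints RENAME TO old_checkpoints")


def table_names(c):
    """Hypothetical helper: list the tables currently in the database."""
    rows = c.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return sorted(r[0] for r in rows)


# The old row is still present, just no longer under the name the app reads.
preserved = conn.execute("SELECT COUNT(*) FROM old_checkpoints").fetchone()[0]
```

After the rename, nothing is deleted: the rows remain queryable under `old_checkpoints`, which is why historical data is preserved yet invisible to code that reads the new `checkpoints` tables.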
## Features
As much as possible, we are striving for feature parity with OpenAI.
@@ -309,14 +319,14 @@ Then, those documents are passed in the system message to a separate call to the
Compared to assistants, it is more structured (but less powerful). It ALWAYS looks up something - which is good if you
know you want to look things up, but potentially wasteful if the user is just trying to have a normal conversation.
Also importantly, this only looks up things once - so if it doesnt find the right results then it will yield a bad
Also importantly, this only looks up things once - so if it doesn't find the right results then it will yield a bad
result (compared to an assistant, which could decide to look things up again).
![](_static/rag.png)
Despite this being a more simple architecture, it is good for a few reasons. First, because it is simpler it can work
pretty well with a wider variety of models (including lots of open source models). Second, if you have a use case where
you dont NEED the flexibility of an assistant (eg you know users will be looking up information every time) then it
you don't NEED the flexibility of an assistant (eg you know users will be looking up information every time) then it
can be more focused. And third, compared to the final architecture below it can use external knowledge.
RAGBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph) `StateGraph`. A `StateGraph` is a generalized graph that can model arbitrary state (i.e. `dict`), not just a `list` of messages.
@@ -325,7 +335,7 @@ RAGBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph
The final architecture is dead simple - just a call to a language model, parameterized by a system message. This allows
the GPT to take on different personas and characters. This is clearly far less powerful than Assistants or RAGBots
(which have access to external sources of data/computation) - but its still valuable! A lot of popular GPTs are just
(which have access to external sources of data/computation) - but it's still valuable! A lot of popular GPTs are just
system messages at the end of the day, and CharacterAI is crushing it despite largely just being system messages as
well.
+25
@@ -1 +1,26 @@
# backend
## Database Migrations
### Migration 5 - Checkpoint Management Update
This migration introduces a significant change to thread checkpoint management:
#### Changes
- Transitions from single-table pickle storage to a robust multi-table checkpoint management system
- Implements LangGraph's latest checkpoint architecture for improved state persistence
- Preserves existing checkpoint data by renaming `checkpoints` table to `old_checkpoints`
- Introduces three new tables for better checkpoint management:
  - `checkpoints`: Core checkpoint metadata
  - `checkpoint_blobs`: Actual checkpoint data storage (compatible with LangGraph state serialization)
  - `checkpoint_writes`: Tracks checkpoint write operations
- Adds runtime initialization via `ensure_setup()` in the lifespan event
#### Impact
- **Breaking Change**: Historical threads/checkpoints (pre-migration) will not be accessible in the UI
- Previous checkpoint data remains preserved but inaccessible in the new system
- Designed to work seamlessly with LangGraph's state persistence requirements
#### Migration Details
- **Up Migration**: Safely preserves existing data by renaming the table
- **Down Migration**: Restores original table structure if needed
- New checkpoint management tables are automatically created at application startup
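The startup initialization mentioned above relies on a singleton whose setup runs only once. The following is a minimal sketch of that pattern using only the standard library. `SetupOnce`, `setup_calls`, and `startup` are hypothetical names for illustration, and the `setup` body is a placeholder rather than the real table-creation logic.

```python
import asyncio


class SetupOnce:
    """Hypothetical stand-in for a saver that needs one-time async setup."""

    _instance = None

    def __new__(cls):
        # Always hand back the same instance, creating it on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._setup_complete = False
            cls._instance.setup_calls = 0
        return cls._instance

    async def setup(self):
        # Placeholder for creating tables; here it only counts invocations.
        self.setup_calls += 1

    async def ensure_setup(self):
        # Run setup on the first call and skip it afterwards.
        if not self._setup_complete:
            await self.setup()
            self._setup_complete = True


async def startup():
    # Calling ensure_setup twice triggers the underlying setup only once.
    await SetupOnce().ensure_setup()
    await SetupOnce().ensure_setup()
    return SetupOnce().setup_calls


calls = asyncio.run(startup())
```

Because every `SetupOnce()` call returns the same object, a module-level instance and the one used at application startup share the same setup state. Note that this sketch does not guard against two coroutines racing through `ensure_setup` concurrently; an `asyncio.Lock` would be one way to handle that if it matters.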
+16 -13
@@ -1,4 +1,3 @@
import pickle
from enum import Enum
from typing import Any, Dict, Mapping, Optional, Sequence, Union
@@ -7,14 +6,13 @@ from langchain_core.runnables import (
ConfigurableField,
RunnableBinding,
)
from langgraph.checkpoint import CheckpointAt
from langgraph.graph.message import Messages
from langgraph.pregel import Pregel
from app.agent_types.tools_agent import get_tools_agent_executor
from app.agent_types.xml_agent import get_xml_agent_executor
from app.chatbot import get_chatbot_executor
from app.checkpoint import PostgresCheckpoint
from app.checkpoint import AsyncPostgresCheckpoint
from app.llms import (
get_anthropic_llm,
get_google_llm,
@@ -74,7 +72,7 @@ class AgentType(str, Enum):
DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
CHECKPOINTER = PostgresCheckpoint(serde=pickle, at=CheckpointAt.END_OF_STEP)
CHECKPOINTER = AsyncPostgresCheckpoint()
def get_agent_executor(
@@ -123,7 +121,6 @@ def get_agent_executor(
return get_tools_agent_executor(
tools, llm, system_message, interrupt_before_action, CHECKPOINTER
)
else:
raise ValueError("Unexpected agent type")
@@ -135,7 +132,7 @@ class ConfigurableAgent(RunnableBinding):
retrieval_description: str = RETRIEVAL_DESCRIPTION
interrupt_before_action: bool = False
assistant_id: Optional[str] = None
thread_id: Optional[str] = None
thread_id: Optional[str] = ""
user_id: Optional[str] = None
def __init__(
@@ -145,7 +142,7 @@ class ConfigurableAgent(RunnableBinding):
agent: AgentType = AgentType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
thread_id: Optional[str] = None,
thread_id: Optional[str] = "",
retrieval_description: str = RETRIEVAL_DESCRIPTION,
interrupt_before_action: bool = False,
kwargs: Optional[Mapping[str, Any]] = None,
@@ -204,7 +201,9 @@ def get_chatbot(
if llm_type == LLMType.GPT_35_TURBO:
llm = get_openai_llm()
elif llm_type == LLMType.GPT_4:
llm = get_openai_llm(gpt_4=True)
llm = get_openai_llm(model="gpt-4")
elif llm_type == LLMType.GPT_4O:
llm = get_openai_llm(model="gpt-4o")
elif llm_type == LLMType.AZURE_OPENAI:
llm = get_openai_llm(azure=True)
elif llm_type == LLMType.CLAUDE2:
@@ -265,7 +264,7 @@ class ConfigurableRetrieval(RunnableBinding):
llm_type: LLMType
system_message: str = DEFAULT_SYSTEM_MESSAGE
assistant_id: Optional[str] = None
thread_id: Optional[str] = None
thread_id: Optional[str] = ""
user_id: Optional[str] = None
def __init__(
@@ -274,7 +273,7 @@ class ConfigurableRetrieval(RunnableBinding):
llm_type: LLMType = LLMType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
thread_id: Optional[str] = None,
thread_id: Optional[str] = "",
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[Mapping[str, Any]] = None,
**others: Any,
@@ -319,7 +318,9 @@ chat_retrieval = (
assistant_id=ConfigurableField(
id="assistant_id", name="Assistant ID", is_shared=True
),
thread_id=ConfigurableField(id="thread_id", name="Thread ID", is_shared=True),
thread_id=ConfigurableField(
id="thread_id", name="Thread ID", annotation=str, is_shared=True
),
)
.with_types(
input_type=Dict[str, Any],
@@ -335,7 +336,7 @@ agent: Pregel = (
system_message=DEFAULT_SYSTEM_MESSAGE,
retrieval_description=RETRIEVAL_DESCRIPTION,
assistant_id=None,
thread_id=None,
thread_id="",
)
.configurable_fields(
agent=ConfigurableField(id="agent_type", name="Agent Type"),
@@ -348,7 +349,9 @@ agent: Pregel = (
assistant_id=ConfigurableField(
id="assistant_id", name="Assistant ID", is_shared=True
),
thread_id=ConfigurableField(id="thread_id", name="Thread ID", is_shared=True),
thread_id=ConfigurableField(
id="thread_id", name="Thread ID", annotation=str, is_shared=True
),
tools=ConfigurableField(id="tools", name="Tools"),
retrieval_description=ConfigurableField(
id="retrieval_description", name="Retrieval Description"
+1 -1
@@ -28,7 +28,7 @@ def get_tools_agent_executor(
msgs = []
for m in messages:
if isinstance(m, LiberalToolMessage):
_dict = m.dict()
_dict = m.model_dump()
_dict["content"] = str(_dict["content"])
m_c = ToolMessage(**_dict)
msgs.append(m_c)
+1 -1
@@ -45,7 +45,7 @@ def construct_chat_history(messages):
temp_messages = []
collapsed_messages.append(message)
elif isinstance(message, LiberalFunctionMessage):
_dict = message.dict()
_dict = message.model_dump()
_dict["content"] = str(_dict["content"])
m_c = FunctionMessage(**_dict)
temp_messages.append(m_c)
+10 -8
@@ -14,9 +14,11 @@ router = APIRouter()
class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""
name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")
name: Annotated[str, Field(description="The name of the assistant.")]
config: Annotated[dict, Field(description="The assistant config.")]
public: Annotated[
bool, Field(default=False, description="Whether the assistant is public.")
]
AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
@@ -25,7 +27,7 @@ AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
@router.get("/")
async def list_assistants(user: AuthedUser) -> List[Assistant]:
"""List all assistants for the current user."""
return await storage.list_assistants(user["user_id"])
return await storage.list_assistants(user.user_id)
@router.get("/public/")
@@ -40,7 +42,7 @@ async def get_assistant(
aid: AssistantID,
) -> Assistant:
"""Get an assistant by ID."""
assistant = await storage.get_assistant(user["user_id"], aid)
assistant = await storage.get_assistant(user.user_id, aid)
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return assistant
@@ -53,7 +55,7 @@ async def create_assistant(
) -> Assistant:
"""Create an assistant."""
return await storage.put_assistant(
user["user_id"],
user.user_id,
str(uuid4()),
name=payload.name,
config=payload.config,
@@ -69,7 +71,7 @@ async def upsert_assistant(
) -> Assistant:
"""Create or update an assistant."""
return await storage.put_assistant(
user["user_id"],
user.user_id,
aid,
name=payload.name,
config=payload.config,
@@ -83,5 +85,5 @@ async def delete_assistant(
aid: AssistantID,
):
"""Delete an assistant by ID."""
await storage.delete_assistant(user["user_id"], aid)
await storage.delete_assistant(user.user_id, aid)
return {"status": "ok"}
+23 -14
@@ -4,14 +4,13 @@ from uuid import UUID
import langsmith.client
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.exceptions import RequestValidationError
from langchain.pydantic_v1 import ValidationError
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from sse_starlette import EventSourceResponse
from app.agent import agent
from app.agent import agent, chat_retrieval, chatbot
from app.auth.handlers import AuthedUser
from app.storage import get_assistant, get_thread
from app.stream import astream_state, to_sse
@@ -34,24 +33,34 @@ async def _run_input_and_config(payload: CreateRunPayload, user_id: str):
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await get_assistant(user_id, str(thread["assistant_id"]))
assistant = await get_assistant(user_id, str(thread.assistant_id))
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
config: RunnableConfig = {
**assistant["config"],
**assistant.config,
"configurable": {
**assistant["config"]["configurable"],
**assistant.config["configurable"],
**((payload.config or {}).get("configurable") or {}),
"user_id": user_id,
"thread_id": str(thread["thread_id"]),
"assistant_id": str(assistant["assistant_id"]),
"thread_id": str(thread.thread_id),
"assistant_id": str(assistant.assistant_id),
},
}
try:
if payload.input is not None:
agent.get_input_schema(config).validate(payload.input)
# Get the bot type from config
bot_type = config["configurable"].get("type", "agent")
# Get the correct schema based on bot type
if bot_type == "chat_retrieval":
schema = chat_retrieval.get_input_schema()
elif bot_type == "chatbot":
schema = chatbot.get_input_schema()
else: # default to agent
schema = agent.get_input_schema()
# Validate against the correct schema
schema.model_validate(payload.input)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=payload)
@@ -65,7 +74,7 @@ async def create_run(
background_tasks: BackgroundTasks,
):
"""Create a run."""
input_, config = await _run_input_and_config(payload, user["user_id"])
input_, config = await _run_input_and_config(payload, user.user_id)
background_tasks.add_task(agent.ainvoke, input_, config)
return {"status": "ok"} # TODO add a run id
@@ -76,7 +85,7 @@ async def stream_run(
user: AuthedUser,
):
"""Create a run."""
input_, config = await _run_input_and_config(payload, user["user_id"])
input_, config = await _run_input_and_config(payload, user.user_id)
return EventSourceResponse(to_sse(astream_state(agent, input_, config)))
@@ -84,19 +93,19 @@ async def stream_run(
@router.get("/input_schema")
async def input_schema() -> dict:
"""Return the input schema of the runnable."""
return agent.get_input_schema().schema()
return agent.get_input_schema().model_json_schema()
@router.get("/output_schema")
async def output_schema() -> dict:
"""Return the output schema of the runnable."""
return agent.get_output_schema().schema()
return agent.get_output_schema().model_json_schema()
@router.get("/config_schema")
async def config_schema() -> dict:
"""Return the config schema of the runnable."""
return agent.config_schema().schema()
return agent.config_schema().model_json_schema()
if tracing_is_enabled():
+16 -16
@@ -18,8 +18,8 @@ ThreadID = Annotated[str, Path(description="The ID of the thread.")]
class ThreadPutRequest(BaseModel):
"""Payload for creating a thread."""
name: str = Field(..., description="The name of the thread.")
assistant_id: str = Field(..., description="The ID of the assistant to use.")
name: Annotated[str, Field(description="The name of the thread.")]
assistant_id: Annotated[str, Field(description="The ID of the assistant to use.")]
class ThreadPostRequest(BaseModel):
@@ -32,7 +32,7 @@ class ThreadPostRequest(BaseModel):
@router.get("/")
async def list_threads(user: AuthedUser) -> List[Thread]:
"""List all threads for the current user."""
return await storage.list_threads(user["user_id"])
return await storage.list_threads(user.user_id)
@router.get("/{tid}/state")
@@ -41,14 +41,14 @@ async def get_thread_state(
tid: ThreadID,
):
"""Get state for a thread."""
thread = await storage.get_thread(user["user_id"], tid)
thread = await storage.get_thread(user.user_id, tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.get_thread_state(
user_id=user["user_id"],
user_id=user.user_id,
thread_id=tid,
assistant=assistant,
)
@@ -61,16 +61,16 @@ async def add_thread_state(
payload: ThreadPostRequest,
):
"""Add state to a thread."""
thread = await storage.get_thread(user["user_id"], tid)
thread = await storage.get_thread(user.user_id, tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}},
payload.values,
user_id=user["user_id"],
user_id=user.user_id,
assistant=assistant,
)
@@ -81,14 +81,14 @@ async def get_thread_history(
tid: ThreadID,
):
"""Get all past states for a thread."""
thread = await storage.get_thread(user["user_id"], tid)
thread = await storage.get_thread(user.user_id, tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
if not assistant:
raise HTTPException(status_code=400, detail="Thread has no assistant")
return await storage.get_thread_history(
user_id=user["user_id"],
user_id=user.user_id,
thread_id=tid,
assistant=assistant,
)
@@ -100,7 +100,7 @@ async def get_thread(
tid: ThreadID,
) -> Thread:
"""Get a thread by ID."""
thread = await storage.get_thread(user["user_id"], tid)
thread = await storage.get_thread(user.user_id, tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return thread
@@ -113,7 +113,7 @@ async def create_thread(
) -> Thread:
"""Create a thread."""
return await storage.put_thread(
user["user_id"],
user.user_id,
str(uuid4()),
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
@@ -128,7 +128,7 @@ async def upsert_thread(
) -> Thread:
"""Update a thread."""
return await storage.put_thread(
user["user_id"],
user.user_id,
tid,
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
@@ -141,5 +141,5 @@ async def delete_thread(
tid: ThreadID,
):
"""Delete a thread by ID."""
await storage.delete_thread(user["user_id"], tid)
await storage.delete_thread(user.user_id, tid)
return {"status": "ok"}
+21 -11
@@ -1,9 +1,10 @@
import os
from base64 import b64decode
from enum import Enum
from typing import Optional, Union
from typing import List, Optional, Union
from pydantic import BaseSettings, root_validator, validator
from pydantic import ConfigDict, field_validator, model_validator
from pydantic_settings import BaseSettings
class AuthType(Enum):
@@ -16,12 +17,16 @@ class JWTSettingsBase(BaseSettings):
iss: str
aud: Union[str, list[str]]
@validator("aud", pre=True, always=True)
def set_aud(cls, v, values) -> Union[str, list[str]]:
return v.split(",") if "," in v else v
@field_validator("aud", mode="before")
@classmethod
def set_aud(cls, v) -> Union[str, List[str]]:
if isinstance(v, str) and "," in v:
return v.split(",")
return v
class Config:
env_prefix = "jwt_"
model_config = ConfigDict(
env_prefix="jwt_",
)
class JWTSettingsLocal(JWTSettingsBase):
@@ -29,14 +34,18 @@ class JWTSettingsLocal(JWTSettingsBase):
decode_key: str = None
alg: str
@validator("decode_key", pre=True, always=True)
def set_decode_key(cls, v, values):
@field_validator("decode_key", mode="before")
@classmethod
def set_decode_key(cls, v, info):
"""
Key may be a multiline string (e.g. in the case of a public key), so to
be able to set it from env, we set it as a base64 encoded string and
decode it here.
"""
return b64decode(values["decode_key_b64"]).decode("utf-8")
decode_key_b64 = info.data.get("decode_key_b64")
if decode_key_b64:
return b64decode(decode_key_b64).decode("utf-8")
return v
class JWTSettingsOIDC(JWTSettingsBase):
@@ -48,7 +57,8 @@ class Settings(BaseSettings):
jwt_local: Optional[JWTSettingsLocal] = None
jwt_oidc: Optional[JWTSettingsOIDC] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_jwt_settings(cls, values):
auth_type = values.get("auth_type")
if auth_type == AuthType.JWT_LOCAL and values.get("jwt_local") is None:
+96 -126
@@ -1,147 +1,117 @@
import pickle
from datetime import datetime
from typing import AsyncIterator, Optional
import os
from typing import Any, AsyncIterator, Optional, Sequence
from langchain_core.messages import BaseMessage
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint.base import BaseCheckpointSaver
import structlog
from langgraph.checkpoint.base import (
ChannelVersions,
Checkpoint,
CheckpointAt,
CheckpointThreadTs,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
RunnableConfig,
)
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.postgres.base import BasePostgresSaver
from langgraph.checkpoint.serde.base import SerializerProtocol
from psycopg import AsyncPipeline
from psycopg_pool import AsyncConnectionPool
from app.lifespan import get_pg_pool
logger = structlog.get_logger(__name__)
def loads(value: bytes) -> Checkpoint:
loaded: Checkpoint = pickle.loads(value)
for key, value in loaded["channel_values"].items():
if isinstance(value, list) and all(isinstance(v, BaseMessage) for v in value):
loaded["channel_values"][key] = [v.__class__(**v.__dict__) for v in value]
return loaded
class AsyncPostgresCheckpoint(BasePostgresSaver):
"""A singleton implementation of AsyncPostgresSaver with separate setup."""
_instance = None
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class PostgresCheckpoint(BaseCheckpointSaver):
def __init__(
self,
*,
pipe: Optional[AsyncPipeline] = None,
serde: Optional[SerializerProtocol] = None,
at: Optional[CheckpointAt] = None,
) -> None:
super().__init__(serde=serde, at=at)
if not hasattr(self, "_initialized"):
super().__init__(serde=serde)
# Initialize basic attributes
self.pipe = pipe
self.serde = serde
self._initialized = True
self._setup_complete = False
self.async_postgres_saver = None
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
return [
ConfigurableFieldSpec(
id="thread_id",
annotation=Optional[str],
name="Thread ID",
description=None,
default=None,
is_shared=True,
),
CheckpointThreadTs,
]
async def ensure_setup(self) -> None:
"""Ensure the instance is set up before use."""
if not self._setup_complete:
await self.setup()
self._setup_complete = True
def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
raise NotImplementedError
async def setup(self) -> None:
"""Internal setup method."""
try:
conninfo = (
f"postgresql://{os.environ['POSTGRES_USER']}:"
f"{os.environ['POSTGRES_PASSWORD']}@"
f"{os.environ['POSTGRES_HOST']}:"
f"{os.environ['POSTGRES_PORT']}/"
f"{os.environ['POSTGRES_DB']}"
)
def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
raise NotImplementedError
pool = AsyncConnectionPool(
conninfo=conninfo,
kwargs={"autocommit": True, "prepare_threshold": 0},
open=False, # Don't open in constructor
)
await pool.open()
async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
async with get_pg_pool().acquire() as db, db.transaction():
thread_id = config["configurable"]["thread_id"]
async for value in db.cursor(
"SELECT checkpoint, thread_ts, parent_ts FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC",
thread_id,
):
yield CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"thread_ts": value[1],
}
},
loads(value[0]),
{
"configurable": {
"thread_id": thread_id,
"thread_ts": value[2],
}
}
if value[2]
else None,
)
self.async_postgres_saver = AsyncPostgresSaver(
conn=pool, pipe=self.pipe, serde=self.serde
)
# Setup will create/migrate the tables if they don't exist
await self.async_postgres_saver.setup()
logger.warning("Checkpoint setup complete.")
except Exception as e:
logger.error(f"Failed to set up AsyncPostgresCheckpoint: {e}")
raise
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[CheckpointTuple]:
"""List checkpoints from the database asynchronously."""
return self.async_postgres_saver.alist(
config, filter=filter, before=before, limit=limit
)
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
thread_id = config["configurable"]["thread_id"]
thread_ts = config["configurable"].get("thread_ts")
async with get_pg_pool().acquire() as conn:
if thread_ts:
if value := await conn.fetchrow(
"SELECT checkpoint, parent_ts FROM checkpoints WHERE thread_id = $1 AND thread_ts = $2",
thread_id,
datetime.fromisoformat(thread_ts),
):
return CheckpointTuple(
config,
loads(value[0]),
{
"configurable": {
"thread_id": thread_id,
"thread_ts": value[1],
}
}
if value[1]
else None,
)
else:
if value := await conn.fetchrow(
"SELECT checkpoint, thread_ts, parent_ts FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC LIMIT 1",
thread_id,
):
return CheckpointTuple(
{
"configurable": {
"thread_id": thread_id,
"thread_ts": value[1],
}
},
loads(value[0]),
{
"configurable": {
"thread_id": thread_id,
"thread_ts": value[2],
}
}
if value[2]
else None,
)
"""Get a checkpoint tuple from the database asynchronously."""
return await self.async_postgres_saver.aget_tuple(config)
async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
thread_id = config["configurable"]["thread_id"]
async with get_pg_pool().acquire() as conn:
await conn.execute(
"""
INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint)
VALUES ($1, $2, $3, $4)
ON CONFLICT (thread_id, thread_ts)
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;""",
thread_id,
datetime.fromisoformat(checkpoint["ts"]),
datetime.fromisoformat(checkpoint.get("parent_ts"))
if checkpoint.get("parent_ts")
else None,
pickle.dumps(checkpoint),
)
return {
"configurable": {
"thread_id": thread_id,
"thread_ts": checkpoint["ts"],
}
}
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously."""
return await self.async_postgres_saver.aput(
config, checkpoint, metadata, new_versions
)
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously."""
await self.async_postgres_saver.aput_writes(config, writes, task_id)
+3
@@ -6,6 +6,8 @@ import orjson
import structlog
from fastapi import FastAPI
from app.checkpoint import AsyncPostgresCheckpoint
_pg_pool = None
@@ -56,6 +58,7 @@ async def lifespan(app: FastAPI):
port=os.environ["POSTGRES_PORT"],
init=_init_connection,
)
await AsyncPostgresCheckpoint().ensure_setup()
yield
await _pg_pool.close()
_pg_pool = None
+9 -8
@@ -1,33 +1,34 @@
from typing import Any, get_args
from typing import Any
from langchain_core.messages import (
AnyMessage,
FunctionMessage,
MessageLikeRepresentation,
ToolMessage,
_message_from_dict,
)
from langgraph.graph.message import Messages, add_messages
from pydantic import Field
class LiberalFunctionMessage(FunctionMessage):
content: Any
content: Any = Field(default="")
class LiberalToolMessage(ToolMessage):
content: Any
content: Any = Field(default="")
def _convert_pydantic_dict_to_message(
data: MessageLikeRepresentation
data: MessageLikeRepresentation,
) -> MessageLikeRepresentation:
"""Convert a dictionary to a message object if it matches message format."""
if (
isinstance(data, dict)
and "content" in data
and isinstance(data.get("type"), str)
):
for cls in get_args(AnyMessage):
if data["type"] == cls(content="").type:
return cls(**data)
_type = data.pop("type")
return _message_from_dict({"data": data, "type": _type})
return data
+3 -2
@@ -54,7 +54,7 @@ def get_retrieval_executor(
if isinstance(m, HumanMessage):
chat_history.append(m)
response = messages[-1].content
content = "\n".join([d.page_content for d in response])
content = "\n".join([d["page_content"] for d in response])
return [
SystemMessage(
content=response_prompt_template.format(
@@ -80,7 +80,7 @@ def get_retrieval_executor(
async def invoke_retrieval(state: AgentState):
messages = state["messages"]
if len(messages) == 1:
human_input = messages[-1]["content"]
human_input = messages[-1].content
return {
"messages": [
AIMessage(
@@ -118,6 +118,7 @@ def get_retrieval_executor(
params = messages[-1].tool_calls[0]
query = params["args"]["query"]
response = await retriever.ainvoke(query)
response = [doc.model_dump() for doc in response]
msg = LiberalToolMessage(
name="retrieval", content=response, tool_call_id=params["id"]
)
+7 -9
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Optional
from typing_extensions import TypedDict
from pydantic import BaseModel
class User(TypedDict):
class User(BaseModel):
user_id: str
"""The ID of the user."""
sub: str
@@ -13,9 +13,7 @@ class User(TypedDict):
"""The time the user was created."""
class Assistant(TypedDict):
"""Assistant model."""
class Assistant(BaseModel):
assistant_id: str
"""The ID of the assistant."""
user_id: str
@@ -26,19 +24,19 @@ class Assistant(TypedDict):
"""The assistant config."""
updated_at: datetime
"""The last time the assistant was updated."""
public: bool
public: bool = False
"""Whether the assistant is public."""
class Thread(TypedDict):
class Thread(BaseModel):
thread_id: str
"""The ID of the thread."""
user_id: str
"""The ID of the user that owns the thread."""
assistant_id: Optional[str]
assistant_id: Optional[str] = None
"""The assistant that was used in conjunction with this thread."""
name: str
"""The name of the thread."""
updated_at: datetime
"""The last time the thread was updated."""
metadata: Optional[dict]
metadata: Optional[dict] = None
+2 -2
@@ -34,13 +34,13 @@ async def ingest_files(
assistant_id = config["configurable"].get("assistant_id")
if assistant_id is not None:
assistant = await storage.get_assistant(user["user_id"], assistant_id)
assistant = await storage.get_assistant(user.user_id, assistant_id)
if assistant is None:
raise HTTPException(status_code=404, detail="Assistant not found.")
thread_id = config["configurable"].get("thread_id")
if thread_id is not None:
thread = await storage.get_thread(user["user_id"], thread_id)
thread = await storage.get_thread(user.user_id, thread_id)
if thread is None:
raise HTTPException(status_code=404, detail="Thread not found.")
+84 -36
@@ -12,23 +12,30 @@ from app.schema import Assistant, Thread, User
async def list_assistants(user_id: str) -> List[Assistant]:
"""List all assistants for the current user."""
async with get_pg_pool().acquire() as conn:
return await conn.fetch("SELECT * FROM assistant WHERE user_id = $1", user_id)
records = await conn.fetch(
"SELECT * FROM assistant WHERE user_id = $1", user_id
)
return [Assistant(**record) for record in records]
async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]:
"""Get an assistant by ID."""
async with get_pg_pool().acquire() as conn:
return await conn.fetchrow(
record = await conn.fetchrow(
"SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public IS true)",
assistant_id,
user_id,
)
if record is None:
return None
return Assistant(**record)
async def list_public_assistants() -> List[Assistant]:
"""List all the public assistants."""
async with get_pg_pool().acquire() as conn:
return await conn.fetch(("SELECT * FROM assistant WHERE public IS true;"))
records = await conn.fetch("SELECT * FROM assistant WHERE public IS true")
return [Assistant(**record) for record in records]
async def put_assistant(
@@ -66,14 +73,14 @@ async def put_assistant(
updated_at,
public,
)
return {
"assistant_id": assistant_id,
"user_id": user_id,
"name": name,
"config": config,
"updated_at": updated_at,
"public": public,
}
return Assistant(
assistant_id=assistant_id,
user_id=user_id,
name=name,
config=config,
updated_at=updated_at,
public=public,
)
async def delete_assistant(user_id: str, assistant_id: str) -> None:
@@ -89,17 +96,21 @@ async def delete_assistant(user_id: str, assistant_id: str) -> None:
async def list_threads(user_id: str) -> List[Thread]:
"""List all threads for the current user."""
async with get_pg_pool().acquire() as conn:
return await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id)
records = await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id)
return [Thread(**record) for record in records]
async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
"""Get a thread by ID."""
async with get_pg_pool().acquire() as conn:
return await conn.fetchrow(
record = await conn.fetchrow(
"SELECT * FROM thread WHERE thread_id = $1 AND user_id = $2",
thread_id,
user_id,
)
if record is None:
return None
return Thread(**record)
async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant):
@@ -107,14 +118,17 @@ async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant
state = await agent.aget_state(
{
"configurable": {
**assistant["config"]["configurable"],
**assistant.config["configurable"],
"thread_id": thread_id,
"assistant_id": assistant["assistant_id"],
"assistant_id": assistant.assistant_id,
}
}
)
# Return values unchanged, but normalize an empty state to None
values = state.values if state.values else None
return {
"values": state.values,
"values": values,
"next": state.next,
}
@@ -127,15 +141,39 @@ async def update_thread_state(
assistant: Assistant,
):
"""Add state to a thread."""
# Get the current state to determine the format
current_state = await agent.aget_state(
{
"configurable": {
**assistant.config["configurable"],
**config["configurable"],
"assistant_id": assistant.assistant_id,
}
}
)
# If current state is a dict (retrieval agent), maintain dict structure
if current_state.values and isinstance(current_state.values, dict):
if isinstance(values, dict):
state_values = values
else:
# Update just the messages in the existing state
state_values = {**current_state.values, "messages": values}
else:
# For message-only states (tools_agent, chatbot), pass the values through unchanged
state_values = values
await agent.aupdate_state(
{
"configurable": {
**assistant["config"]["configurable"],
**assistant.config["configurable"],
**config["configurable"],
"assistant_id": assistant["assistant_id"],
"assistant_id": assistant.assistant_id,
}
},
values,
state_values,
)
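The format-preserving branch in `update_thread_state` above can be read in isolation as a small pure function. The following is a minimal sketch of that logic only; `merge_state_values` is a hypothetical name that does not appear in the commit, and the real code reads the current state from `agent.aget_state(...)` rather than taking it as an argument.

```python
from typing import Any, Optional


def merge_state_values(current: Optional[Any], values: Any) -> Any:
    """Pick the value to hand to aupdate_state (hypothetical helper).

    - Dict-shaped current state (e.g. a retrieval agent): keep the dict
      structure. A dict update is used as is; anything else replaces only
      the "messages" key.
    - Empty or message-only current state: pass the update through unchanged.
    """
    if current and isinstance(current, dict):
        if isinstance(values, dict):
            return values
        return {**current, "messages": values}
    return values


# Dict state plus a plain message list: only "messages" is replaced.
dict_case = merge_state_values({"messages": [], "docs": ["a"]}, ["hi"])

# Message-only (or empty) state: the update is returned untouched.
list_case = merge_state_values(None, ["hi"])
```

Keeping the decision in one place like this would make the two agent families easy to unit test without a running checkpointer, though the commit itself inlines the logic.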
@@ -151,15 +189,27 @@ async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assista
async for c in agent.aget_state_history(
{
"configurable": {
**assistant["config"]["configurable"],
**assistant.config["configurable"],
"thread_id": thread_id,
"assistant_id": assistant["assistant_id"],
"assistant_id": assistant.assistant_id,
}
}
)
]
def get_assistant_type(config: dict) -> str:
"""Extract assistant type from config, handling both old and new formats."""
configurable = config.get("configurable", {})
# First try direct type key (old format)
if "type" in configurable:
return configurable["type"]
# Default fallback
return "chatbot"
async def put_thread(
user_id: str, thread_id: str, *, assistant_id: str, name: str
) -> Thread:
@@ -167,9 +217,7 @@ async def put_thread(
updated_at = datetime.now(timezone.utc)
assistant = await get_assistant(user_id, assistant_id)
metadata = (
{"assistant_type": assistant["config"]["configurable"]["type"]}
if assistant
else None
{"assistant_type": get_assistant_type(assistant.config)} if assistant else None
)
async with get_pg_pool().acquire() as conn:
await conn.execute(
@@ -189,14 +237,14 @@ async def put_thread(
updated_at,
metadata,
)
return {
"thread_id": thread_id,
"user_id": user_id,
"assistant_id": assistant_id,
"name": name,
"updated_at": updated_at,
"metadata": metadata,
}
return Thread(
thread_id=thread_id,
user_id=user_id,
assistant_id=assistant_id,
name=name,
updated_at=updated_at,
metadata=metadata,
)
async def delete_thread(user_id: str, thread_id: str):
@@ -212,9 +260,9 @@ async def delete_thread(user_id: str, thread_id: str):
async def get_or_create_user(sub: str) -> tuple[User, bool]:
"""Returns a tuple of the user and a boolean indicating whether the user was created."""
async with get_pg_pool().acquire() as conn:
if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub):
return user, False
user = await conn.fetchrow(
if record := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub):
return User(**record), False
record = await conn.fetchrow(
'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub
)
return user, True
return User(**record), True
+84 -88
@@ -1,8 +1,7 @@
from enum import Enum
from functools import lru_cache
from typing import Optional
from typing import Annotated, Literal
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.retriever import create_retriever_tool
from langchain_community.agent_toolkits.connery import ConneryToolkit
from langchain_community.retrievers.kay import KayAiRetriever
@@ -22,26 +21,26 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_core.tools import Tool
from langchain_robocorp import ActionServerToolkit
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from app.upload import vstore
class DDGInput(BaseModel):
query: str = Field(description="search query to look up")
query: Annotated[str, Field(description="search query to look up")]
class ArxivInput(BaseModel):
query: str = Field(description="search query to look up")
query: Annotated[str, Field(description="search query to look up")]
class PythonREPLInput(BaseModel):
query: str = Field(description="python command to run")
query: Annotated[str, Field(description="python command to run")]
class DallEInput(BaseModel):
query: str = Field(description="image description to generate image from")
query: Annotated[str, Field(description="image description to generate image from")]
class AvailableTools(str, Enum):
@@ -66,10 +65,10 @@ class ToolConfig(TypedDict):
class BaseTool(BaseModel):
type: AvailableTools
name: Optional[str]
description: Optional[str]
config: Optional[ToolConfig]
multi_use: Optional[bool] = False
name: str
description: str
config: ToolConfig = Field(default_factory=dict)
multi_use: bool = False
class ActionServerConfig(ToolConfig):
@@ -78,125 +77,133 @@ class ActionServerConfig(ToolConfig):
class ActionServer(BaseTool):
type: AvailableTools = Field(AvailableTools.ACTION_SERVER, const=True)
name: str = Field("Action Server by Sema4.ai", const=True)
description: str = Field(
type: Literal[AvailableTools.ACTION_SERVER] = AvailableTools.ACTION_SERVER
name: Literal["Action Server by Sema4.ai"] = "Action Server by Sema4.ai"
description: Literal[
(
"Run AI actions with "
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
),
const=True,
)
] = (
"Run AI actions with "
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
)
config: ActionServerConfig
multi_use: bool = Field(True, const=True)
multi_use: Literal[True] = True
class Connery(BaseTool):
type: AvailableTools = Field(AvailableTools.CONNERY, const=True)
name: str = Field("AI Action Runner by Connery", const=True)
description: str = Field(
type: Literal[AvailableTools.CONNERY] = AvailableTools.CONNERY
name: Literal["AI Action Runner by Connery"] = "AI Action Runner by Connery"
description: Literal[
(
"Connect OpenGPTs to the real world with "
"[Connery](https://github.com/connery-io/connery)."
),
const=True,
)
] = (
"Connect OpenGPTs to the real world with "
"[Connery](https://github.com/connery-io/connery)."
)
class DDGSearch(BaseTool):
type: AvailableTools = Field(AvailableTools.DDG_SEARCH, const=True)
name: str = Field("DuckDuckGo Search", const=True)
description: str = Field(
"Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/).",
const=True,
)
type: Literal[AvailableTools.DDG_SEARCH] = AvailableTools.DDG_SEARCH
name: Literal["DuckDuckGo Search"] = "DuckDuckGo Search"
description: Literal[
"Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."
] = "Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."
class Arxiv(BaseTool):
type: AvailableTools = Field(AvailableTools.ARXIV, const=True)
name: str = Field("Arxiv", const=True)
description: str = Field("Searches [Arxiv](https://arxiv.org/).", const=True)
type: Literal[AvailableTools.ARXIV] = AvailableTools.ARXIV
name: Literal["Arxiv"] = "Arxiv"
description: Literal[
"Searches [Arxiv](https://arxiv.org/)."
] = "Searches [Arxiv](https://arxiv.org/)."
class YouSearch(BaseTool):
type: AvailableTools = Field(AvailableTools.YOU_SEARCH, const=True)
name: str = Field("You.com Search", const=True)
description: str = Field(
"Uses [You.com](https://you.com/) search, optimized responses for LLMs.",
const=True,
)
type: Literal[AvailableTools.YOU_SEARCH] = AvailableTools.YOU_SEARCH
name: Literal["You.com Search"] = "You.com Search"
description: Literal[
"Uses [You.com](https://you.com/) search, optimized responses for LLMs."
] = "Uses [You.com](https://you.com/) search, optimized responses for LLMs."
class SecFilings(BaseTool):
type: AvailableTools = Field(AvailableTools.SEC_FILINGS, const=True)
name: str = Field("SEC Filings (Kay.ai)", const=True)
description: str = Field(
"Searches through SEC filings using [Kay.ai](https://www.kay.ai/).", const=True
)
type: Literal[AvailableTools.SEC_FILINGS] = AvailableTools.SEC_FILINGS
name: Literal["SEC Filings (Kay.ai)"] = "SEC Filings (Kay.ai)"
description: Literal[
"Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."
] = "Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."
class PressReleases(BaseTool):
type: AvailableTools = Field(AvailableTools.PRESS_RELEASES, const=True)
name: str = Field("Press Releases (Kay.ai)", const=True)
description: str = Field(
"Searches through press releases using [Kay.ai](https://www.kay.ai/).",
const=True,
)
type: Literal[AvailableTools.PRESS_RELEASES] = AvailableTools.PRESS_RELEASES
name: Literal["Press Releases (Kay.ai)"] = "Press Releases (Kay.ai)"
description: Literal[
"Searches through press releases using [Kay.ai](https://www.kay.ai/)."
] = "Searches through press releases using [Kay.ai](https://www.kay.ai/)."
class PubMed(BaseTool):
type: AvailableTools = Field(AvailableTools.PUBMED, const=True)
name: str = Field("PubMed", const=True)
description: str = Field(
"Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/).", const=True
)
type: Literal[AvailableTools.PUBMED] = AvailableTools.PUBMED
name: Literal["PubMed"] = "PubMed"
description: Literal[
"Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."
] = "Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."
class Wikipedia(BaseTool):
type: AvailableTools = Field(AvailableTools.WIKIPEDIA, const=True)
name: str = Field("Wikipedia", const=True)
description: str = Field(
"Searches [Wikipedia](https://pypi.org/project/wikipedia/).", const=True
)
type: Literal[AvailableTools.WIKIPEDIA] = AvailableTools.WIKIPEDIA
name: Literal["Wikipedia"] = "Wikipedia"
description: Literal[
"Searches [Wikipedia](https://pypi.org/project/wikipedia/)."
] = "Searches [Wikipedia](https://pypi.org/project/wikipedia/)."
class Tavily(BaseTool):
type: AvailableTools = Field(AvailableTools.TAVILY, const=True)
name: str = Field("Search (Tavily)", const=True)
description: str = Field(
type: Literal[AvailableTools.TAVILY] = AvailableTools.TAVILY
name: Literal["Search (Tavily)"] = "Search (Tavily)"
description: Literal[
(
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"Includes sources in the response."
),
const=True,
)
] = (
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"Includes sources in the response."
)
class TavilyAnswer(BaseTool):
type: AvailableTools = Field(AvailableTools.TAVILY_ANSWER, const=True)
name: str = Field("Search (short answer, Tavily)", const=True)
description: str = Field(
type: Literal[AvailableTools.TAVILY_ANSWER] = AvailableTools.TAVILY_ANSWER
name: Literal["Search (short answer, Tavily)"] = "Search (short answer, Tavily)"
description: Literal[
(
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"This returns only the answer, no supporting evidence."
),
const=True,
)
] = (
"Uses the [Tavily](https://app.tavily.com/) search engine. "
"This returns only the answer, no supporting evidence."
)
class Retrieval(BaseTool):
type: AvailableTools = Field(AvailableTools.RETRIEVAL, const=True)
name: str = Field("Retrieval", const=True)
description: str = Field("Look up information in uploaded files.", const=True)
type: Literal[AvailableTools.RETRIEVAL] = AvailableTools.RETRIEVAL
name: Literal["Retrieval"] = "Retrieval"
description: Literal[
"Look up information in uploaded files."
] = "Look up information in uploaded files."
class DallE(BaseTool):
type: AvailableTools = Field(AvailableTools.DALL_E, const=True)
name: str = Field("Generate Image (Dall-E)", const=True)
description: str = Field(
"Generates images from a text description using OpenAI's DALL-E model.",
const=True,
)
type: Literal[AvailableTools.DALL_E] = AvailableTools.DALL_E
name: Literal["Generate Image (Dall-E)"] = "Generate Image (Dall-E)"
description: Literal[
"Generates images from a text description using OpenAI's DALL-E model."
] = "Generates images from a text description using OpenAI's DALL-E model."
RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
@@ -286,16 +293,6 @@ def _get_tavily_answer():
return _TavilyAnswer(api_wrapper=tavily_search, name="search_tavily_answer")
def _get_action_server(**kwargs: ActionServerConfig):
toolkit = ActionServerToolkit(
url=kwargs["url"],
api_key=kwargs["api_key"],
additional_headers=kwargs.get("additional_headers", {}),
)
tools = toolkit.get_tools()
return tools
@lru_cache(maxsize=1)
def _get_connery_actions():
connery_service = ConneryService()
@@ -314,7 +311,6 @@ def _get_dalle_tools():
TOOLS = {
AvailableTools.ACTION_SERVER: _get_action_server,
AvailableTools.CONNERY: _get_connery_actions,
AvailableTools.DDG_SEARCH: _get_duck_duck_go,
AvailableTools.ARXIV: _get_arxiv,
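The tool classes above swap pydantic v1's `Field(..., const=True)` for `Literal[...]` annotations, since pydantic v2 no longer accepts the `const` argument. A minimal sketch of the pattern, assuming pydantic v2 is installed; `ToolKind` and `ArxivTool` are hypothetical stand-ins for `AvailableTools` and `Arxiv`, not the classes from the commit:

```python
from enum import Enum
from typing import Literal

from pydantic import BaseModel, ValidationError


class ToolKind(str, Enum):
    """Hypothetical stand-in for AvailableTools."""

    ARXIV = "arxiv"


class ArxivTool(BaseModel):
    # Literal pins each field to exactly one allowed value; the default
    # supplies that value when the caller omits the field.
    type: Literal[ToolKind.ARXIV] = ToolKind.ARXIV
    name: Literal["Arxiv"] = "Arxiv"
    multi_use: bool = False


tool = ArxivTool()

# Any other value for a Literal field fails validation.
try:
    ArxivTool(name="Something else")
    rejected = False
except ValidationError:
    rejected = True
```

One consequence of this design is that the constant also appears in the generated JSON schema as a fixed value, which is what a client would need in order to tell the tool variants apart.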
+6 -8
@@ -24,6 +24,7 @@ from langchain_core.runnables import (
from langchain_core.vectorstores import VectorStore
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from pydantic import ConfigDict
from app.ingest import ingest_blob
from app.parsing import MIMETYPE_BASED_PARSER
@@ -110,18 +111,15 @@ class IngestRunnable(RunnableSerializable[BinaryIO, List[str]]):
"""Runnable for ingesting files into a vectorstore."""
text_splitter: TextSplitter
"""Text splitter to use for splitting the text into chunks."""
vectorstore: VectorStore
"""Vectorstore to ingest into."""
assistant_id: Optional[str]
thread_id: Optional[str]
assistant_id: Optional[str] = None
thread_id: Optional[str] = None
"""Ingested documents will be associated with assistant_id or thread_id.
ID is used as the namespace, and is filtered on at query time.
"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def namespace(self) -> str:
@@ -161,12 +159,12 @@ ingest_runnable = IngestRunnable(
).configurable_fields(
assistant_id=ConfigurableField(
id="assistant_id",
annotation=str,
annotation=Optional[str],
name="Assistant ID",
),
thread_id=ConfigurableField(
id="thread_id",
annotation=str,
annotation=Optional[str],
name="Thread ID",
),
)
@@ -0,0 +1,12 @@
-- Drop the blob storage table
DROP TABLE IF EXISTS checkpoint_blobs;
-- Drop the writes tracking table
DROP TABLE IF EXISTS checkpoint_writes;
-- Drop the new checkpoints table that was created by the application
DROP TABLE IF EXISTS checkpoints;
-- Restore the original checkpoints table by renaming old_checkpoints back
-- This preserves the original data that was saved before the migration
ALTER TABLE old_checkpoints RENAME TO checkpoints;
@@ -0,0 +1,9 @@
-- BREAKING CHANGE WARNING:
-- This migration represents a transition from pickle-based checkpointing to a new checkpoint system.
-- As a result, any threads created before this migration will not be usable/clickable in the UI.
-- Old thread data remains in the old_checkpoints table but cannot be accessed by the new version.
-- Rename existing checkpoints table to preserve current data
-- This is necessary because the application will create a new checkpoints table
-- with an updated schema during runtime initialization.
ALTER TABLE checkpoints RENAME TO old_checkpoints;
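The up and down migrations above form a rename-to-preserve round trip: the up step moves the old table aside, the application creates the new tables at startup, and the down step drops the new tables and renames the old one back. The sketch below replays that sequence against an in-memory SQLite database purely for illustration. The project targets Postgres, and the column definitions here are invented placeholders, not the real checkpoint schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()


def table_names() -> list:
    """Return all table names, sorted, for inspection."""
    rows = cur.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return [row[0] for row in rows]


# State before the migration: one legacy table with one row.
cur.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint BLOB)")
cur.execute("INSERT INTO checkpoints VALUES ('t1', x'00')")

# Up migration: move the legacy table aside.
cur.execute("ALTER TABLE checkpoints RENAME TO old_checkpoints")

# Stand-in for what the application does at startup (placeholder columns).
cur.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint_id TEXT)")
cur.execute("CREATE TABLE checkpoint_blobs (thread_id TEXT, data BLOB)")
cur.execute("CREATE TABLE checkpoint_writes (thread_id TEXT, task_id TEXT)")
after_up = table_names()

# Down migration: drop the new tables, then restore the legacy one.
cur.execute("DROP TABLE IF EXISTS checkpoint_blobs")
cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
cur.execute("DROP TABLE IF EXISTS checkpoints")
cur.execute("ALTER TABLE old_checkpoints RENAME TO checkpoints")
after_down = table_names()
restored = cur.execute("SELECT thread_id FROM checkpoints").fetchall()
```

Note that the down step discards anything written to the new tables in the meantime; only the data that existed before the up migration comes back.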
+2928 -2352
File diff suppressed because it is too large
+12 -11
@@ -17,11 +17,12 @@ fastapi = "^0.103.2"
# langchain = { git = "git@github.com:langchain-ai/langchain.git/", branch = "nc/subclass-runnable-binding" , subdirectory = "libs/langchain"}
orjson = "^3.9.10"
python-multipart = "^0.0.6"
tiktoken = "^0.5.1"
langchain = ">=0.0.338"
langgraph = "^0.0.38"
pydantic = "<2.0"
langchain-openai = "^0.1.3"
tiktoken = "^0"
langchain = "^0.3"
langgraph = "0.2.45"
langgraph-checkpoint-postgres = "^2.0.2"
pydantic = "^2"
langchain-openai = "^0.2"
beautifulsoup4 = "^4.12.3"
boto3 = "^1.34.28"
duckduckgo-search = "^5.3.0"
@@ -29,19 +30,19 @@ arxiv = "^2.1.0"
kay = "^0.1.2"
xmltodict = "^0.13.0"
wikipedia = "^1.4.0"
langchain-google-vertexai = "^1.0.1"
langchain-google-vertexai = "^2.0"
langchain-google-community = "^2.0.1"
setuptools = "^69.0.3"
pdfminer-six = "^20231228"
langchain-robocorp = "^0.0.8"
fireworks-ai = "^0.11.2"
httpx = { version = "0.25.2", extras = ["socks"] }
unstructured = {extras = ["doc", "docx"], version = "^0.12.5"}
httpx = { version = "^0", extras = ["socks"] }
unstructured = {extras = ["doc", "docx"], version = "^0"}
pgvector = "^0.2.5"
psycopg2-binary = "^2.9.9"
asyncpg = "^0.29.0"
langchain-core = "^0.1.44"
langchain-core = "^0.3"
pyjwt = {extras = ["crypto"], version = "^2.8.0"}
langchain-anthropic = "^0.1.8"
langchain-anthropic = "^0.2"
structlog = "^24.1.0"
python-json-logger = "^2.0.7"
@@ -1,7 +1,8 @@
from io import BytesIO
from langchain.text_splitter import RecursiveCharacterTextSplitter
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from app.upload import IngestRunnable, _guess_mimetype, convert_ingestion_input_to_blob
from tests.unit_tests.fixtures import get_sample_paths
from tests.unit_tests.utils import InMemoryVectorStore
+14 -6
@@ -4,12 +4,15 @@ from typing import Optional, Sequence
from uuid import uuid4
import asyncpg
from pydantic import BaseModel
from app.schema import Assistant, Thread
from tests.unit_tests.app.helpers import get_client
def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict:
def _project(model: BaseModel, *, exclude_keys: Optional[Sequence[str]] = None) -> dict:
"""Dump the model to a dict, omitting any keys listed in exclude_keys."""
d = model.model_dump()
_exclude = set(exclude_keys) if exclude_keys else set()
return {k: v for k, v in d.items() if k not in _exclude}
@@ -38,7 +41,8 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
headers=headers,
)
assert response.status_code == 200
assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == {
assistant = Assistant.model_validate(response.json())
assert _project(assistant, exclude_keys=["updated_at", "user_id"]) == {
"assistant_id": aid,
"config": {},
"name": "bobby",
@@ -48,8 +52,9 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
assert len(await conn.fetch("SELECT * FROM assistant;")) == 1
response = await client.get("/assistants/", headers=headers)
assistants = [Assistant.model_validate(d) for d in response.json()]
assert [
_project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json()
_project(d, exclude_keys=["updated_at", "user_id"]) for d in assistants
] == [
{
"assistant_id": aid,
@@ -65,7 +70,8 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
headers=headers,
)
assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == {
assistant = Assistant.model_validate(response.json())
assert _project(assistant, exclude_keys=["updated_at", "user_id"]) == {
"assistant_id": aid,
"config": {},
"name": "bobby",
@@ -79,7 +85,7 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
assert response.json() == []
async def test_threads() -> None:
async def test_threads(pool: asyncpg.pool.Pool) -> None:
"""Test put thread."""
headers = {"Cookie": "opengpts_user_id=1"}
aid = str(uuid4())
@@ -102,6 +108,7 @@ async def test_threads() -> None:
headers=headers,
)
assert response.status_code == 200, response.text
_ = Thread.model_validate(response.json())
response = await client.get(f"/threads/{tid}/state", headers=headers)
assert response.status_code == 200
@@ -110,8 +117,9 @@ async def test_threads() -> None:
response = await client.get("/threads/", headers=headers)
assert response.status_code == 200
threads = [Thread.model_validate(d) for d in response.json()]
assert [
_project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json()
_project(d, exclude_keys=["updated_at", "user_id"]) for d in threads
] == [
{
"assistant_id": aid,
+1 -1
@@ -20,7 +20,7 @@ from tests.unit_tests.app.helpers import get_client
@app.get("/me")
async def me(user: AuthedUser) -> dict:
return user
return user.model_dump()
def _create_jwt(
+7 -1
@@ -53,10 +53,16 @@ def _migrate_test_db() -> None:
@pytest.fixture(scope="session")
async def pool():
async def _init_db():
"""Initialize the test database."""
await _drop_test_db() # In case previous test session was abruptly terminated
await _create_test_db()
_migrate_test_db()
@pytest.fixture(scope="session")
async def pool(_init_db):
"""Initialize database pool with checkpointer."""
async with lifespan(app):
yield get_pg_pool()
await _drop_test_db()
+3 -2
@@ -23,6 +23,8 @@ from redis.client import Redis as RedisType
from app.checkpoint import PostgresCheckpoint
from app.lifespan import get_pg_pool, lifespan
from app.server import app
from pydantic import ConfigDict
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -52,8 +54,7 @@ def load(keys: list[str], values: list[bytes]) -> dict:
class RedisCheckpoint(BaseCheckpointSaver):
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def config_specs(self) -> list[ConfigurableFieldSpec]: