mirror of
https://github.com/langchain-ai/opengpts.git
synced 2026-07-01 18:23:41 -04:00
Upgrade Opengpts (#361)
* Migrate pydantic
* Upgrade poetry
* Adapt to manage checkpoint using an AsyncSaver
* Adjust Tools model
* Add checkpoint
* Update poetry
* Format
* Fix tests
* Modify tables
* Fix gpt4o
* Fix bots
* Fix retrieval
* Adding eugenes suggestions
* Fix state handling inconsistency between different agent types
* Improve doc
* Update backend/pyproject.toml
* lint fix
* lint

---------

Co-authored-by: “lgesuellip” <“lgesuellipinto@uade.edu.ar”>
Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
@@ -227,6 +227,16 @@ Navigate to [http://localhost:5173/](http://localhost:5173/) and enjoy!
 Refer to this [guide](tools/redis_to_postgres/README.md) for migrating data from Redis to Postgres.

+## Breaking Changes
+
+### Migration 5 - Checkpoint Management Update
+Version 5 of the database migrations introduces a significant change to how thread checkpoints are managed:
+- Transitions from a pickle-based checkpointing system to a new multi-table checkpoint management system (breaking change)
+- Aligns with LangGraph's new checkpoint architecture for better state management and persistence
+- **Important**: Historical threads/checkpoints (created before this migration) will not be accessible in the UI
+- Previous checkpoint data is preserved in the `old_checkpoints` table but cannot be accessed by the new system
+- This architectural change improves how thread state is stored and managed, enabling more reliable state persistence in LangGraph-based agents.
+
 ## Features

 As much as possible, we are striving for feature parity with OpenAI.
@@ -309,14 +319,14 @@ Then, those documents are passed in the system message to a separate call to the
 Compared to assistants, it is more structured (but less powerful). It ALWAYS looks up something - which is good if you
 know you want to look things up, but potentially wasteful if the user is just trying to have a normal conversation.
-Also importantly, this only looks up things once - so if it doesn’t find the right results then it will yield a bad
+Also importantly, this only looks up things once - so if it doesn't find the right results then it will yield a bad
 result (compared to an assistant, which could decide to look things up again).

 Despite this being a more simple architecture, it is good for a few reasons. First, because it is simpler it can work
 pretty well with a wider variety of models (including lots of open source models). Second, if you have a use case where
-you don’t NEED the flexibility of an assistant (eg you know users will be looking up information every time) then it
+you don't NEED the flexibility of an assistant (eg you know users will be looking up information every time) then it
 can be more focused. And third, compared to the final architecture below it can use external knowledge.

 RAGBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph) `StateGraph`. A `StateGraph` is a generalized graph that can model arbitrary state (i.e. `dict`), not just a `list` of messages.
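The RAGBot flow described in the hunk above (one unconditional lookup, then one model call, with state held in a `dict` rather than a list of messages) can be sketched in plain Python. This is only an illustration and does not use LangGraph: `retrieve`, `respond`, `run_ragbot`, and the state keys are hypothetical names, and the `search` and `llm` callables stand in for a real retriever and model.

```python
from typing import Callable, Dict, List

State = Dict[str, object]  # arbitrary state, not just a list of messages


def retrieve(state: State, search: Callable[[str], List[str]]) -> State:
    """Always look something up, exactly once per turn."""
    docs = search(str(state["question"]))
    return {**state, "documents": docs}


def respond(state: State, llm: Callable[[str], str]) -> State:
    """Pass the retrieved documents to the model inside the system message."""
    system = "Answer using these documents:\n" + "\n".join(state["documents"])
    answer = llm(system + "\n\nUser: " + str(state["question"]))
    return {**state, "answer": answer}


def run_ragbot(
    question: str,
    search: Callable[[str], List[str]],
    llm: Callable[[str], str],
) -> State:
    state: State = {"question": question}
    state = retrieve(state, search)  # step 1: unconditional retrieval
    return respond(state, llm)       # step 2: single model call, no retry loop
```

Because there is no loop back to `retrieve`, a poor first lookup cannot be corrected, which is the trade-off the README text points out.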
@@ -325,7 +335,7 @@ RAGBot is implemented with [LangGraph](https://github.com/langchain-ai/langgraph
 The final architecture is dead simple - just a call to a language model, parameterized by a system message. This allows
 the GPT to take on different personas and characters. This is clearly far less powerful than Assistants or RAGBots
-(which have access to external sources of data/computation) - but it’s still valuable! A lot of popular GPTs are just
+(which have access to external sources of data/computation) - but it's still valuable! A lot of popular GPTs are just
 system messages at the end of the day, and CharacterAI is crushing it despite largely just being system messages as
 well.
@@ -1 +1,26 @@
 # backend
+
+## Database Migrations
+
+### Migration 5 - Checkpoint Management Update
+This migration introduces a significant change to thread checkpoint management:
+
+#### Changes
+- Transitions from single-table pickle storage to a robust multi-table checkpoint management system
+- Implements LangGraph's latest checkpoint architecture for improved state persistence
+- Preserves existing checkpoint data by renaming `checkpoints` table to `old_checkpoints`
+- Introduces three new tables for better checkpoint management:
+  - `checkpoints`: Core checkpoint metadata
+  - `checkpoint_blobs`: Actual checkpoint data storage (compatible with LangGraph state serialization)
+  - `checkpoint_writes`: Tracks checkpoint write operations
+- Adds runtime initialization via `ensure_setup()` in the lifespan event
+
+#### Impact
+- **Breaking Change**: Historical threads/checkpoints (pre-migration) will not be accessible in the UI
+- Previous checkpoint data remains preserved but inaccessible in the new system
+- Designed to work seamlessly with LangGraph's state persistence requirements
+
+#### Migration Details
+- **Up Migration**: Safely preserves existing data by renaming the table
+- **Down Migration**: Restores original table structure if needed
+- New checkpoint management tables are automatically created at application startup
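The rename-based up/down migration described above can be tried out with a small sketch. It uses an in-memory SQLite database from the Python standard library as a stand-in for Postgres, so the SQL is only approximately what the real migration files contain. The column layout and the `DROP TABLE IF EXISTS` step in the down migration are assumptions made for illustration.

```python
import sqlite3


def table_names(conn):
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return [r[0] for r in rows]


def migrate_up(conn):
    # Preserve the legacy data by renaming the table instead of dropping it.
    conn.execute("ALTER TABLE checkpoints RENAME TO old_checkpoints")


def migrate_down(conn):
    # Assumption: remove a new-style table if startup already created one,
    # then give the legacy table its original name back.
    conn.execute("DROP TABLE IF EXISTS checkpoints")
    conn.execute("ALTER TABLE old_checkpoints RENAME TO checkpoints")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint BLOB)")
conn.execute("INSERT INTO checkpoints VALUES ('t1', x'00')")

migrate_up(conn)
after_up = table_names(conn)

migrate_down(conn)
after_down = table_names(conn)
restored_rows = conn.execute("SELECT thread_id FROM checkpoints").fetchall()
```

After `migrate_up` the legacy rows still exist, but only under `old_checkpoints`, which matches the note that historical threads are preserved yet not visible to the new system.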
+16
-13
@@ -1,4 +1,3 @@
-import pickle
 from enum import Enum
 from typing import Any, Dict, Mapping, Optional, Sequence, Union

@@ -7,14 +6,13 @@ from langchain_core.runnables import (
     ConfigurableField,
     RunnableBinding,
 )
-from langgraph.checkpoint import CheckpointAt
 from langgraph.graph.message import Messages
 from langgraph.pregel import Pregel

 from app.agent_types.tools_agent import get_tools_agent_executor
 from app.agent_types.xml_agent import get_xml_agent_executor
 from app.chatbot import get_chatbot_executor
-from app.checkpoint import PostgresCheckpoint
+from app.checkpoint import AsyncPostgresCheckpoint
 from app.llms import (
     get_anthropic_llm,
     get_google_llm,
@@ -74,7 +72,7 @@ class AgentType(str, Enum):

 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

-CHECKPOINTER = PostgresCheckpoint(serde=pickle, at=CheckpointAt.END_OF_STEP)
+CHECKPOINTER = AsyncPostgresCheckpoint()


 def get_agent_executor(
@@ -123,7 +121,6 @@ def get_agent_executor(
         return get_tools_agent_executor(
             tools, llm, system_message, interrupt_before_action, CHECKPOINTER
         )
-
     else:
         raise ValueError("Unexpected agent type")

@@ -135,7 +132,7 @@ class ConfigurableAgent(RunnableBinding):
     retrieval_description: str = RETRIEVAL_DESCRIPTION
     interrupt_before_action: bool = False
     assistant_id: Optional[str] = None
-    thread_id: Optional[str] = None
+    thread_id: Optional[str] = ""
     user_id: Optional[str] = None

     def __init__(
@@ -145,7 +142,7 @@ class ConfigurableAgent(RunnableBinding):
         agent: AgentType = AgentType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         assistant_id: Optional[str] = None,
-        thread_id: Optional[str] = None,
+        thread_id: Optional[str] = "",
         retrieval_description: str = RETRIEVAL_DESCRIPTION,
         interrupt_before_action: bool = False,
         kwargs: Optional[Mapping[str, Any]] = None,
@@ -204,7 +201,9 @@ def get_chatbot(
     if llm_type == LLMType.GPT_35_TURBO:
         llm = get_openai_llm()
     elif llm_type == LLMType.GPT_4:
-        llm = get_openai_llm(gpt_4=True)
+        llm = get_openai_llm(model="gpt-4")
+    elif llm_type == LLMType.GPT_4O:
+        llm = get_openai_llm(model="gpt-4o")
     elif llm_type == LLMType.AZURE_OPENAI:
         llm = get_openai_llm(azure=True)
     elif llm_type == LLMType.CLAUDE2:
@@ -265,7 +264,7 @@ class ConfigurableRetrieval(RunnableBinding):
     llm_type: LLMType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     assistant_id: Optional[str] = None
-    thread_id: Optional[str] = None
+    thread_id: Optional[str] = ""
     user_id: Optional[str] = None

     def __init__(
@@ -274,7 +273,7 @@ class ConfigurableRetrieval(RunnableBinding):
         llm_type: LLMType = LLMType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         assistant_id: Optional[str] = None,
-        thread_id: Optional[str] = None,
+        thread_id: Optional[str] = "",
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[Mapping[str, Any]] = None,
         **others: Any,
@@ -319,7 +318,9 @@ chat_retrieval = (
         assistant_id=ConfigurableField(
             id="assistant_id", name="Assistant ID", is_shared=True
         ),
-        thread_id=ConfigurableField(id="thread_id", name="Thread ID", is_shared=True),
+        thread_id=ConfigurableField(
+            id="thread_id", name="Thread ID", annotation=str, is_shared=True
+        ),
     )
     .with_types(
         input_type=Dict[str, Any],
@@ -335,7 +336,7 @@ agent: Pregel = (
         system_message=DEFAULT_SYSTEM_MESSAGE,
         retrieval_description=RETRIEVAL_DESCRIPTION,
         assistant_id=None,
-        thread_id=None,
+        thread_id="",
     )
     .configurable_fields(
         agent=ConfigurableField(id="agent_type", name="Agent Type"),
@@ -348,7 +349,9 @@ agent: Pregel = (
         assistant_id=ConfigurableField(
             id="assistant_id", name="Assistant ID", is_shared=True
         ),
-        thread_id=ConfigurableField(id="thread_id", name="Thread ID", is_shared=True),
+        thread_id=ConfigurableField(
+            id="thread_id", name="Thread ID", annotation=str, is_shared=True
+        ),
         tools=ConfigurableField(id="tools", name="Tools"),
         retrieval_description=ConfigurableField(
             id="retrieval_description", name="Retrieval Description"
@@ -28,7 +28,7 @@ def get_tools_agent_executor(
         msgs = []
         for m in messages:
             if isinstance(m, LiberalToolMessage):
-                _dict = m.dict()
+                _dict = m.model_dump()
                 _dict["content"] = str(_dict["content"])
                 m_c = ToolMessage(**_dict)
                 msgs.append(m_c)
@@ -45,7 +45,7 @@ def construct_chat_history(messages):
                 temp_messages = []
             collapsed_messages.append(message)
         elif isinstance(message, LiberalFunctionMessage):
-            _dict = message.dict()
+            _dict = message.model_dump()
             _dict["content"] = str(_dict["content"])
             m_c = FunctionMessage(**_dict)
             temp_messages.append(m_c)
@@ -14,9 +14,11 @@ router = APIRouter()
 class AssistantPayload(BaseModel):
     """Payload for creating an assistant."""

-    name: str = Field(..., description="The name of the assistant.")
-    config: dict = Field(..., description="The assistant config.")
-    public: bool = Field(default=False, description="Whether the assistant is public.")
+    name: Annotated[str, Field(description="The name of the assistant.")]
+    config: Annotated[dict, Field(description="The assistant config.")]
+    public: Annotated[
+        bool, Field(default=False, description="Whether the assistant is public.")
+    ]


 AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
@@ -25,7 +27,7 @@ AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
 @router.get("/")
 async def list_assistants(user: AuthedUser) -> List[Assistant]:
     """List all assistants for the current user."""
-    return await storage.list_assistants(user["user_id"])
+    return await storage.list_assistants(user.user_id)


 @router.get("/public/")
@@ -40,7 +42,7 @@ async def get_assistant(
     aid: AssistantID,
 ) -> Assistant:
     """Get an assistant by ID."""
-    assistant = await storage.get_assistant(user["user_id"], aid)
+    assistant = await storage.get_assistant(user.user_id, aid)
     if not assistant:
         raise HTTPException(status_code=404, detail="Assistant not found")
     return assistant
@@ -53,7 +55,7 @@ async def create_assistant(
 ) -> Assistant:
     """Create an assistant."""
     return await storage.put_assistant(
-        user["user_id"],
+        user.user_id,
         str(uuid4()),
         name=payload.name,
         config=payload.config,
@@ -69,7 +71,7 @@ async def upsert_assistant(
 ) -> Assistant:
     """Create or update an assistant."""
     return await storage.put_assistant(
-        user["user_id"],
+        user.user_id,
         aid,
         name=payload.name,
         config=payload.config,
@@ -83,5 +85,5 @@ async def delete_assistant(
     aid: AssistantID,
 ):
     """Delete an assistant by ID."""
-    await storage.delete_assistant(user["user_id"], aid)
+    await storage.delete_assistant(user.user_id, aid)
     return {"status": "ok"}
+23
-14
@@ -4,14 +4,13 @@ from uuid import UUID
 import langsmith.client
 from fastapi import APIRouter, BackgroundTasks, HTTPException
 from fastapi.exceptions import RequestValidationError
-from langchain.pydantic_v1 import ValidationError
 from langchain_core.messages import AnyMessage
 from langchain_core.runnables import RunnableConfig
 from langsmith.utils import tracing_is_enabled
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ValidationError
 from sse_starlette import EventSourceResponse

-from app.agent import agent
+from app.agent import agent, chat_retrieval, chatbot
 from app.auth.handlers import AuthedUser
 from app.storage import get_assistant, get_thread
 from app.stream import astream_state, to_sse
@@ -34,24 +33,34 @@ async def _run_input_and_config(payload: CreateRunPayload, user_id: str):
     if not thread:
         raise HTTPException(status_code=404, detail="Thread not found")

-    assistant = await get_assistant(user_id, str(thread["assistant_id"]))
+    assistant = await get_assistant(user_id, str(thread.assistant_id))
     if not assistant:
         raise HTTPException(status_code=404, detail="Assistant not found")

     config: RunnableConfig = {
-        **assistant["config"],
+        **assistant.config,
         "configurable": {
-            **assistant["config"]["configurable"],
+            **assistant.config["configurable"],
             **((payload.config or {}).get("configurable") or {}),
             "user_id": user_id,
-            "thread_id": str(thread["thread_id"]),
-            "assistant_id": str(assistant["assistant_id"]),
+            "thread_id": str(thread.thread_id),
+            "assistant_id": str(assistant.assistant_id),
         },
     }

     try:
         if payload.input is not None:
-            agent.get_input_schema(config).validate(payload.input)
+            # Get the bot type from config
+            bot_type = config["configurable"].get("type", "agent")
+            # Get the correct schema based on bot type
+            if bot_type == "chat_retrieval":
+                schema = chat_retrieval.get_input_schema()
+            elif bot_type == "chatbot":
+                schema = chatbot.get_input_schema()
+            else:  # default to agent
+                schema = agent.get_input_schema()
+            # Validate against the correct schema
+            schema.model_validate(payload.input)
     except ValidationError as e:
         raise RequestValidationError(e.errors(), body=payload)
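The hunk above relies on dict-unpacking order: later keys override earlier ones, so the server-set `user_id`, `thread_id`, and `assistant_id` always win over anything in the stored assistant config or the request payload. The following is a minimal sketch of that precedence and of the bot-type fallback, using plain dicts. `build_run_config` and `pick_bot_type` are hypothetical helper names, not functions from the codebase.

```python
from typing import Any, Dict, Optional


def build_run_config(
    assistant_config: Dict[str, Any],
    payload_config: Optional[Dict[str, Any]],
    user_id: str,
    thread_id: str,
    assistant_id: str,
) -> Dict[str, Any]:
    """Merge configs; later keys win, and identity fields are set last."""
    return {
        **assistant_config,
        "configurable": {
            **assistant_config["configurable"],
            **((payload_config or {}).get("configurable") or {}),
            "user_id": user_id,
            "thread_id": thread_id,
            "assistant_id": assistant_id,
        },
    }


def pick_bot_type(config: Dict[str, Any]) -> str:
    """Name the runnable whose input schema should validate the payload."""
    bot_type = config["configurable"].get("type", "agent")
    # Anything other than the two known alternatives falls back to "agent".
    return bot_type if bot_type in ("chat_retrieval", "chatbot") else "agent"
```

A payload can therefore change ordinary configurable values, but it cannot spoof the identity fields, because those keys are written after the payload is unpacked.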
@@ -65,7 +74,7 @@ async def create_run(
     background_tasks: BackgroundTasks,
 ):
     """Create a run."""
-    input_, config = await _run_input_and_config(payload, user["user_id"])
+    input_, config = await _run_input_and_config(payload, user.user_id)
     background_tasks.add_task(agent.ainvoke, input_, config)
     return {"status": "ok"}  # TODO add a run id

@@ -76,7 +85,7 @@ async def stream_run(
     user: AuthedUser,
 ):
     """Create a run."""
-    input_, config = await _run_input_and_config(payload, user["user_id"])
+    input_, config = await _run_input_and_config(payload, user.user_id)

     return EventSourceResponse(to_sse(astream_state(agent, input_, config)))

@@ -84,19 +93,19 @@ async def stream_run(
 @router.get("/input_schema")
 async def input_schema() -> dict:
     """Return the input schema of the runnable."""
-    return agent.get_input_schema().schema()
+    return agent.get_input_schema().model_json_schema()


 @router.get("/output_schema")
 async def output_schema() -> dict:
     """Return the output schema of the runnable."""
-    return agent.get_output_schema().schema()
+    return agent.get_output_schema().model_json_schema()


 @router.get("/config_schema")
 async def config_schema() -> dict:
     """Return the config schema of the runnable."""
-    return agent.config_schema().schema()
+    return agent.config_schema().model_json_schema()


 if tracing_is_enabled():
+16
-16
@@ -18,8 +18,8 @@ ThreadID = Annotated[str, Path(description="The ID of the thread.")]
 class ThreadPutRequest(BaseModel):
     """Payload for creating a thread."""

-    name: str = Field(..., description="The name of the thread.")
-    assistant_id: str = Field(..., description="The ID of the assistant to use.")
+    name: Annotated[str, Field(description="The name of the thread.")]
+    assistant_id: Annotated[str, Field(description="The ID of the assistant to use.")]


 class ThreadPostRequest(BaseModel):
@@ -32,7 +32,7 @@ class ThreadPostRequest(BaseModel):
 @router.get("/")
 async def list_threads(user: AuthedUser) -> List[Thread]:
     """List all threads for the current user."""
-    return await storage.list_threads(user["user_id"])
+    return await storage.list_threads(user.user_id)


 @router.get("/{tid}/state")
@@ -41,14 +41,14 @@ async def get_thread_state(
     tid: ThreadID,
 ):
     """Get state for a thread."""
-    thread = await storage.get_thread(user["user_id"], tid)
+    thread = await storage.get_thread(user.user_id, tid)
     if not thread:
         raise HTTPException(status_code=404, detail="Thread not found")
-    assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
+    assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
     if not assistant:
         raise HTTPException(status_code=400, detail="Thread has no assistant")
     return await storage.get_thread_state(
-        user_id=user["user_id"],
+        user_id=user.user_id,
         thread_id=tid,
         assistant=assistant,
     )
@@ -61,16 +61,16 @@ async def add_thread_state(
     payload: ThreadPostRequest,
 ):
     """Add state to a thread."""
-    thread = await storage.get_thread(user["user_id"], tid)
+    thread = await storage.get_thread(user.user_id, tid)
     if not thread:
         raise HTTPException(status_code=404, detail="Thread not found")
-    assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
+    assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
     if not assistant:
         raise HTTPException(status_code=400, detail="Thread has no assistant")
     return await storage.update_thread_state(
         payload.config or {"configurable": {"thread_id": tid}},
         payload.values,
-        user_id=user["user_id"],
+        user_id=user.user_id,
         assistant=assistant,
     )

@@ -81,14 +81,14 @@ async def get_thread_history(
     tid: ThreadID,
 ):
     """Get all past states for a thread."""
-    thread = await storage.get_thread(user["user_id"], tid)
+    thread = await storage.get_thread(user.user_id, tid)
     if not thread:
         raise HTTPException(status_code=404, detail="Thread not found")
-    assistant = await storage.get_assistant(user["user_id"], thread["assistant_id"])
+    assistant = await storage.get_assistant(user.user_id, thread.assistant_id)
     if not assistant:
         raise HTTPException(status_code=400, detail="Thread has no assistant")
     return await storage.get_thread_history(
-        user_id=user["user_id"],
+        user_id=user.user_id,
         thread_id=tid,
         assistant=assistant,
     )
@@ -100,7 +100,7 @@ async def get_thread(
     tid: ThreadID,
 ) -> Thread:
     """Get a thread by ID."""
-    thread = await storage.get_thread(user["user_id"], tid)
+    thread = await storage.get_thread(user.user_id, tid)
     if not thread:
         raise HTTPException(status_code=404, detail="Thread not found")
     return thread
@@ -113,7 +113,7 @@ async def create_thread(
 ) -> Thread:
     """Create a thread."""
     return await storage.put_thread(
-        user["user_id"],
+        user.user_id,
         str(uuid4()),
         assistant_id=thread_put_request.assistant_id,
         name=thread_put_request.name,
@@ -128,7 +128,7 @@ async def upsert_thread(
 ) -> Thread:
     """Update a thread."""
     return await storage.put_thread(
-        user["user_id"],
+        user.user_id,
         tid,
         assistant_id=thread_put_request.assistant_id,
         name=thread_put_request.name,
@@ -141,5 +141,5 @@ async def delete_thread(
     tid: ThreadID,
 ):
     """Delete a thread by ID."""
-    await storage.delete_thread(user["user_id"], tid)
+    await storage.delete_thread(user.user_id, tid)
     return {"status": "ok"}
@@ -1,9 +1,10 @@
 import os
 from base64 import b64decode
 from enum import Enum
-from typing import Optional, Union
+from typing import List, Optional, Union

-from pydantic import BaseSettings, root_validator, validator
+from pydantic import ConfigDict, field_validator, model_validator
+from pydantic_settings import BaseSettings


 class AuthType(Enum):
@@ -16,12 +17,16 @@ class JWTSettingsBase(BaseSettings):
     iss: str
     aud: Union[str, list[str]]

-    @validator("aud", pre=True, always=True)
-    def set_aud(cls, v, values) -> Union[str, list[str]]:
-        return v.split(",") if "," in v else v
+    @field_validator("aud", mode="before")
+    @classmethod
+    def set_aud(cls, v) -> Union[str, List[str]]:
+        if isinstance(v, str) and "," in v:
+            return v.split(",")
+        return v

-    class Config:
-        env_prefix = "jwt_"
+    model_config = ConfigDict(
+        env_prefix="jwt_",
+    )


 class JWTSettingsLocal(JWTSettingsBase):
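The rewritten `set_aud` validator only splits when the raw value is a string containing a comma, and passes everything else through. The same logic, pulled out of pydantic into a plain function, is sketched below so the behaviour can be checked in isolation. `normalize_aud` is a hypothetical name used only for this illustration.

```python
from typing import List, Union


def normalize_aud(v: object) -> Union[object, List[str]]:
    """Mirror of the new validator body, without the pydantic decorator."""
    if isinstance(v, str) and "," in v:
        return v.split(",")
    return v


single = normalize_aud("api")             # plain string is left alone
several = normalize_aud("api,web")        # comma-separated string becomes a list
already_list = normalize_aud(["api", "web"])  # non-string input passes through
missing = normalize_aud(None)             # no TypeError on None
```

Note that the split does not strip whitespace, so a value such as `"api, web"` would yield `["api", " web"]`.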
@@ -29,14 +34,18 @@ class JWTSettingsLocal(JWTSettingsBase):
     decode_key: str = None
     alg: str

-    @validator("decode_key", pre=True, always=True)
-    def set_decode_key(cls, v, values):
+    @field_validator("decode_key", mode="before")
+    @classmethod
+    def set_decode_key(cls, v, info):
         """
         Key may be a multiline string (e.g. in the case of a public key), so to
         be able to set it from env, we set it as a base64 encoded string and
         decode it here.
         """
-        return b64decode(values["decode_key_b64"]).decode("utf-8")
+        decode_key_b64 = info.data.get("decode_key_b64")
+        if decode_key_b64:
+            return b64decode(decode_key_b64).decode("utf-8")
+        return v


 class JWTSettingsOIDC(JWTSettingsBase):
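The docstring above explains why the key travels as base64: a multiline PEM string is awkward to put in an environment variable. The round trip is sketched below with the standard library only. `encode_key_for_env` and `resolve_decode_key` are hypothetical helpers that mirror the validator's fallback logic; they are not part of the codebase.

```python
from base64 import b64decode, b64encode
from typing import Optional


def encode_key_for_env(key: str) -> str:
    """Turn a (possibly multiline) key into a single-line base64 string."""
    return b64encode(key.encode("utf-8")).decode("ascii")


def resolve_decode_key(
    decode_key: Optional[str], decode_key_b64: Optional[str]
) -> Optional[str]:
    """Prefer the base64 form when present, else keep the raw value."""
    if decode_key_b64:
        return b64decode(decode_key_b64).decode("utf-8")
    return decode_key


pem = "-----BEGIN PUBLIC KEY-----\nabc\n-----END PUBLIC KEY-----\n"
env_value = encode_key_for_env(pem)
round_trip = resolve_decode_key(None, env_value)
```

Unlike the old validator, which indexed `values["decode_key_b64"]` unconditionally, the new form falls back to the raw `decode_key` when no base64 value is supplied.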
@@ -48,7 +57,8 @@ class Settings(BaseSettings):
     jwt_local: Optional[JWTSettingsLocal] = None
     jwt_oidc: Optional[JWTSettingsOIDC] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def check_jwt_settings(cls, values):
         auth_type = values.get("auth_type")
         if auth_type == AuthType.JWT_LOCAL and values.get("jwt_local") is None:
+96
-126
@@ -1,147 +1,117 @@
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from typing import AsyncIterator, Optional
|
||||
import os
|
||||
from typing import Any, AsyncIterator, Optional, Sequence
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
import structlog
|
||||
from langgraph.checkpoint.base import (
|
||||
ChannelVersions,
|
||||
Checkpoint,
|
||||
CheckpointAt,
|
||||
CheckpointThreadTs,
|
||||
CheckpointMetadata,
|
||||
CheckpointTuple,
|
||||
SerializerProtocol,
|
||||
RunnableConfig,
|
||||
)
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.checkpoint.postgres.base import BasePostgresSaver
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
from psycopg import AsyncPipeline
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from app.lifespan import get_pg_pool
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def loads(value: bytes) -> Checkpoint:
|
||||
loaded: Checkpoint = pickle.loads(value)
|
||||
for key, value in loaded["channel_values"].items():
|
||||
if isinstance(value, list) and all(isinstance(v, BaseMessage) for v in value):
|
||||
loaded["channel_values"][key] = [v.__class__(**v.__dict__) for v in value]
|
||||
return loaded
|
||||
class AsyncPostgresCheckpoint(BasePostgresSaver):
|
||||
"""A singleton implementation of AsyncPostgresSaver with separate setup."""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
class PostgresCheckpoint(BaseCheckpointSaver):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pipe: Optional[AsyncPipeline] = None,
|
||||
serde: Optional[SerializerProtocol] = None,
|
||||
at: Optional[CheckpointAt] = None,
|
||||
) -> None:
|
||||
super().__init__(serde=serde, at=at)
|
||||
if not hasattr(self, "_initialized"):
|
||||
super().__init__(serde=serde)
|
||||
# Initialize basic attributes
|
||||
self.pipe = pipe
|
||||
self.serde = serde
|
||||
self._initialized = True
|
||||
self._setup_complete = False
|
||||
self.async_postgres_saver = None
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return [
|
||||
ConfigurableFieldSpec(
|
||||
id="thread_id",
|
||||
annotation=Optional[str],
|
||||
name="Thread ID",
|
||||
description=None,
|
||||
default=None,
|
||||
is_shared=True,
|
||||
),
|
||||
CheckpointThreadTs,
|
||||
]
|
||||
async def ensure_setup(self) -> None:
|
||||
"""Ensure the instance is set up before use."""
|
||||
if not self._setup_complete:
|
||||
await self.setup()
|
||||
self._setup_complete = True
|
||||
|
||||
def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
|
||||
raise NotImplementedError
|
||||
async def setup(self) -> None:
|
||||
"""Internal setup method."""
|
||||
try:
|
||||
conninfo = (
|
||||
f"postgresql://{os.environ['POSTGRES_USER']}:"
|
||||
f"{os.environ['POSTGRES_PASSWORD']}@"
|
||||
f"{os.environ['POSTGRES_HOST']}:"
|
||||
f"{os.environ['POSTGRES_PORT']}/"
|
||||
f"{os.environ['POSTGRES_DB']}"
|
||||
)
|
||||
|
||||
def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
|
||||
raise NotImplementedError
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=conninfo,
|
||||
kwargs={"autocommit": True, "prepare_threshold": 0},
|
||||
open=False, # Don't open in constructor
|
||||
)
|
||||
await pool.open()
|
||||
|
||||
async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
|
||||
async with get_pg_pool().acquire() as db, db.transaction():
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
async for value in db.cursor(
|
||||
"SELECT checkpoint, thread_ts, parent_ts FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC",
|
||||
thread_id,
|
||||
):
|
||||
yield CheckpointTuple(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": value[1],
|
||||
}
|
||||
},
|
||||
loads(value[0]),
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": value[2],
|
||||
}
|
||||
}
|
||||
if value[2]
|
||||
else None,
|
||||
)
|
||||
self.async_postgres_saver = AsyncPostgresSaver(
|
||||
conn=pool, pipe=self.pipe, serde=self.serde
|
||||
)
|
||||
|
||||
# Setup will create/migrate the tables if they don't exist
|
||||
await self.async_postgres_saver.setup()
|
||||
|
||||
logger.warning("Checkpoint setup complete.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set up AsyncPostgresCheckpoint: {e}")
|
||||
raise
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: Optional[RunnableConfig],
|
||||
*,
|
||||
filter: Optional[dict[str, Any]] = None,
|
||||
before: Optional[RunnableConfig] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> AsyncIterator[CheckpointTuple]:
|
||||
"""List checkpoints from the database asynchronously."""
|
||||
return self.async_postgres_saver.alist(
|
||||
config, filter=filter, before=before, limit=limit
|
||||
)
|
||||
|
||||
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
thread_ts = config["configurable"].get("thread_ts")
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
if thread_ts:
|
||||
if value := await conn.fetchrow(
|
||||
"SELECT checkpoint, parent_ts FROM checkpoints WHERE thread_id = $1 AND thread_ts = $2",
|
||||
thread_id,
|
||||
datetime.fromisoformat(thread_ts),
|
||||
):
|
||||
return CheckpointTuple(
|
||||
config,
|
||||
loads(value[0]),
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": value[1],
|
||||
}
|
||||
}
|
||||
if value[1]
|
||||
else None,
|
||||
)
|
||||
else:
|
||||
if value := await conn.fetchrow(
|
||||
"SELECT checkpoint, thread_ts, parent_ts FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC LIMIT 1",
|
||||
thread_id,
|
||||
):
|
||||
return CheckpointTuple(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": value[1],
|
||||
}
|
||||
},
|
||||
loads(value[0]),
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": value[2],
|
||||
}
|
||||
}
|
||||
if value[2]
|
||||
else None,
|
||||
)
|
||||
"""Get a checkpoint tuple from the database asynchronously."""
|
||||
return await self.async_postgres_saver.aget_tuple(config)
|
||||
|
||||
async def aput(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (thread_id, thread_ts)
|
||||
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;""",
|
||||
thread_id,
|
||||
datetime.fromisoformat(checkpoint["ts"]),
|
||||
datetime.fromisoformat(checkpoint.get("parent_ts"))
|
||||
if checkpoint.get("parent_ts")
|
||||
else None,
|
||||
pickle.dumps(checkpoint),
|
||||
)
|
||||
return {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"thread_ts": checkpoint["ts"],
|
||||
}
|
||||
}
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Save a checkpoint to the database asynchronously."""
|
||||
return await self.async_postgres_saver.aput(
|
||||
config, checkpoint, metadata, new_versions
|
||||
)
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Store intermediate writes linked to a checkpoint asynchronously."""
|
||||
await self.async_postgres_saver.aput_writes(config, writes, task_id)
|
||||
|
||||
@@ -6,6 +6,8 @@ import orjson
|
||||
import structlog
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.checkpoint import AsyncPostgresCheckpoint
|
||||
|
||||
_pg_pool = None
|
||||
|
||||
|
||||
@@ -56,6 +58,7 @@ async def lifespan(app: FastAPI):
|
||||
port=os.environ["POSTGRES_PORT"],
|
||||
init=_init_connection,
|
||||
)
|
||||
await AsyncPostgresCheckpoint().ensure_setup()
|
||||
yield
|
||||
await _pg_pool.close()
|
||||
_pg_pool = None
|
||||
|
||||
@@ -1,33 +1,34 @@
|
||||
from typing import Any, get_args
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
FunctionMessage,
|
||||
MessageLikeRepresentation,
|
||||
ToolMessage,
|
||||
_message_from_dict,
|
||||
)
|
||||
from langgraph.graph.message import Messages, add_messages
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class LiberalFunctionMessage(FunctionMessage):
|
||||
content: Any
|
||||
content: Any = Field(default="")
|
||||
|
||||
|
||||
class LiberalToolMessage(ToolMessage):
|
||||
content: Any
|
||||
content: Any = Field(default="")
|
||||
|
||||
|
||||
def _convert_pydantic_dict_to_message(
|
||||
data: MessageLikeRepresentation
|
||||
data: MessageLikeRepresentation,
|
||||
) -> MessageLikeRepresentation:
|
||||
"""Convert a dictionary to a message object if it matches message format."""
|
||||
if (
|
||||
isinstance(data, dict)
|
||||
and "content" in data
|
||||
and isinstance(data.get("type"), str)
|
||||
):
|
||||
for cls in get_args(AnyMessage):
|
||||
if data["type"] == cls(content="").type:
|
||||
return cls(**data)
|
||||
_type = data.pop("type")
|
||||
return _message_from_dict({"data": data, "type": _type})
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def get_retrieval_executor(
|
||||
if isinstance(m, HumanMessage):
|
||||
chat_history.append(m)
|
||||
response = messages[-1].content
|
||||
content = "\n".join([d.page_content for d in response])
|
||||
content = "\n".join([d["page_content"] for d in response])
|
||||
return [
|
||||
SystemMessage(
|
||||
content=response_prompt_template.format(
|
||||
@@ -80,7 +80,7 @@ def get_retrieval_executor(
|
||||
async def invoke_retrieval(state: AgentState):
|
||||
messages = state["messages"]
|
||||
if len(messages) == 1:
|
||||
human_input = messages[-1]["content"]
|
||||
human_input = messages[-1].content
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
@@ -118,6 +118,7 @@ def get_retrieval_executor(
|
||||
params = messages[-1].tool_calls[0]
|
||||
query = params["args"]["query"]
|
||||
response = await retriever.ainvoke(query)
|
||||
response = [doc.model_dump() for doc in response]
|
||||
msg = LiberalToolMessage(
|
||||
name="retrieval", content=response, tool_call_id=params["id"]
|
||||
)
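The retrieval hunks above convert each retrieved document to a plain dict before it is stored in the tool message, which is why the prompt-building code now reads `d["page_content"]` rather than `d.page_content`. A minimal sketch of that round trip follows. It uses a stdlib dataclass as a stand-in: the `Doc` class, `serialize_docs`, and `join_page_content` are hypothetical names, and `dataclasses.asdict` merely plays the role of the real model's `model_dump()`.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for a retrieved document."""

    page_content: str
    metadata: dict = field(default_factory=dict)


def serialize_docs(docs: list[Doc]) -> list[dict]:
    # Plays the role of `[doc.model_dump() for doc in response]`:
    # the message content becomes a list of plain dicts.
    return [asdict(doc) for doc in docs]


def join_page_content(serialized: list[dict]) -> str:
    # Consumers must now use key access, since the items are dicts.
    return "\n".join(d["page_content"] for d in serialized)


docs = [Doc("first chunk"), Doc("second chunk", {"source": "a.txt"})]
payload = serialize_docs(docs)
context = join_page_content(payload)
```

Storing plain dicts keeps the message content serializable without relying on the checkpointer knowing how to encode document objects, which is presumably the motivation for the change.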
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class User(TypedDict):
|
||||
class User(BaseModel):
|
||||
user_id: str
|
||||
"""The ID of the user."""
|
||||
sub: str
|
||||
@@ -13,9 +13,7 @@ class User(TypedDict):
|
||||
"""The time the user was created."""
|
||||
|
||||
|
||||
class Assistant(TypedDict):
|
||||
"""Assistant model."""
|
||||
|
||||
class Assistant(BaseModel):
|
||||
assistant_id: str
|
||||
"""The ID of the assistant."""
|
||||
user_id: str
|
||||
@@ -26,19 +24,19 @@ class Assistant(TypedDict):
|
||||
"""The assistant config."""
|
||||
updated_at: datetime
|
||||
"""The last time the assistant was updated."""
|
||||
public: bool
|
||||
public: bool = False
|
||||
"""Whether the assistant is public."""
|
||||
|
||||
|
||||
class Thread(TypedDict):
|
||||
class Thread(BaseModel):
|
||||
thread_id: str
|
||||
"""The ID of the thread."""
|
||||
user_id: str
|
||||
"""The ID of the user that owns the thread."""
|
||||
assistant_id: Optional[str]
|
||||
assistant_id: Optional[str] = None
|
||||
"""The assistant that was used in conjunction with this thread."""
|
||||
name: str
|
||||
"""The name of the thread."""
|
||||
updated_at: datetime
|
||||
"""The last time the thread was updated."""
|
||||
metadata: Optional[dict]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@@ -34,13 +34,13 @@ async def ingest_files(
|
||||
|
||||
assistant_id = config["configurable"].get("assistant_id")
|
||||
if assistant_id is not None:
|
||||
assistant = await storage.get_assistant(user["user_id"], assistant_id)
|
||||
assistant = await storage.get_assistant(user.user_id, assistant_id)
|
||||
if assistant is None:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found.")
|
||||
|
||||
thread_id = config["configurable"].get("thread_id")
|
||||
if thread_id is not None:
|
||||
thread = await storage.get_thread(user["user_id"], thread_id)
|
||||
thread = await storage.get_thread(user.user_id, thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail="Thread not found.")
|
||||
|
||||
|
||||
+84
-36
@@ -12,23 +12,30 @@ from app.schema import Assistant, Thread, User
|
||||
async def list_assistants(user_id: str) -> List[Assistant]:
|
||||
"""List all assistants for the current user."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
return await conn.fetch("SELECT * FROM assistant WHERE user_id = $1", user_id)
|
||||
records = await conn.fetch(
|
||||
"SELECT * FROM assistant WHERE user_id = $1", user_id
|
||||
)
|
||||
return [Assistant(**record) for record in records]
|
||||
|
||||
|
||||
async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]:
|
||||
"""Get an assistant by ID."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
return await conn.fetchrow(
|
||||
record = await conn.fetchrow(
|
||||
"SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public IS true)",
|
||||
assistant_id,
|
||||
user_id,
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
return Assistant(**record)
|
||||
|
||||
|
||||
async def list_public_assistants() -> List[Assistant]:
|
||||
"""List all the public assistants."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
return await conn.fetch(("SELECT * FROM assistant WHERE public IS true;"))
|
||||
records = await conn.fetch("SELECT * FROM assistant WHERE public IS true")
|
||||
return [Assistant(**record) for record in records]
|
||||
|
||||
|
||||
async def put_assistant(
|
||||
@@ -66,14 +73,14 @@ async def put_assistant(
|
||||
updated_at,
|
||||
public,
|
||||
)
|
||||
return {
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"config": config,
|
||||
"updated_at": updated_at,
|
||||
"public": public,
|
||||
}
|
||||
return Assistant(
|
||||
assistant_id=assistant_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
config=config,
|
||||
updated_at=updated_at,
|
||||
public=public,
|
||||
)
|
||||
|
||||
|
||||
async def delete_assistant(user_id: str, assistant_id: str) -> None:
|
||||
@@ -89,17 +96,21 @@ async def delete_assistant(user_id: str, assistant_id: str) -> None:
|
||||
async def list_threads(user_id: str) -> List[Thread]:
|
||||
"""List all threads for the current user."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
return await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id)
|
||||
records = await conn.fetch("SELECT * FROM thread WHERE user_id = $1", user_id)
|
||||
return [Thread(**record) for record in records]
|
||||
|
||||
|
||||
async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
|
||||
"""Get a thread by ID."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
return await conn.fetchrow(
|
||||
record = await conn.fetchrow(
|
||||
"SELECT * FROM thread WHERE thread_id = $1 AND user_id = $2",
|
||||
thread_id,
|
||||
user_id,
|
||||
)
|
||||
if record is None:
|
||||
return None
|
||||
return Thread(**record)
|
||||
|
||||
|
||||
async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant):
|
||||
@@ -107,14 +118,17 @@ async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant
|
||||
state = await agent.aget_state(
|
||||
{
|
||||
"configurable": {
|
||||
**assistant["config"]["configurable"],
|
||||
**assistant.config["configurable"],
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant["assistant_id"],
|
||||
"assistant_id": assistant.assistant_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
# Keep original format - return values as is
|
||||
values = state.values if state.values else None
|
||||
|
||||
return {
|
||||
"values": state.values,
|
||||
"values": values,
|
||||
"next": state.next,
|
||||
}
|
||||
|
||||
@@ -127,15 +141,39 @@ async def update_thread_state(
|
||||
assistant: Assistant,
|
||||
):
|
||||
"""Add state to a thread."""
|
||||
# Get the current state to determine the format
|
||||
current_state = await agent.aget_state(
|
||||
{
|
||||
"configurable": {
|
||||
**assistant.config["configurable"],
|
||||
**config["configurable"],
|
||||
"assistant_id": assistant.assistant_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# If current state is a dict (retrieval agent), maintain dict structure
|
||||
if current_state.values and isinstance(current_state.values, dict):
|
||||
if isinstance(values, dict):
|
||||
state_values = values
|
||||
else:
|
||||
# Update just the messages in the existing state
|
||||
state_values = {**current_state.values, "messages": values}
|
||||
else:
|
||||
# For message-only states (tools_agent, chatbot), just use the messages
|
||||
state_values = (
|
||||
values if isinstance(values, dict) and "messages" in values else values
|
||||
)
|
||||
|
||||
await agent.aupdate_state(
|
||||
{
|
||||
"configurable": {
|
||||
**assistant["config"]["configurable"],
|
||||
**assistant.config["configurable"],
|
||||
**config["configurable"],
|
||||
"assistant_id": assistant["assistant_id"],
|
||||
"assistant_id": assistant.assistant_id,
|
||||
}
|
||||
},
|
||||
values,
|
||||
state_values,
|
||||
)
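The branching added to `update_thread_state` can be isolated into a pure function to make the intended behaviour easier to see. This is only a sketch: `resolve_state_values` is a hypothetical helper, not part of the commit, and it takes the already-fetched state values instead of calling the agent.

```python
from typing import Any, Optional, Union


def resolve_state_values(
    current_values: Optional[Union[dict, list]],
    values: Union[dict, list],
) -> Any:
    """Pick the payload to pass to the state update."""
    if current_values and isinstance(current_values, dict):
        if isinstance(values, dict):
            # Caller already supplied a full dict state.
            return values
        # Caller supplied bare messages: keep the other keys of the
        # existing dict state and replace only "messages".
        return {**current_values, "messages": values}
    # Message-only state: pass the values through unchanged. In the diff
    # this branch is a conditional whose two arms are both `values`.
    return values


dict_state = {"messages": ["hi"], "msg_count": 1}
merged = resolve_state_values(dict_state, ["hi", "there"])
passthrough = resolve_state_values(["hi"], ["hi", "there"])
```

Here `merged` keeps `msg_count` from the existing state while swapping in the new message list, and `passthrough` is returned as-is because the current state is not a dict.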
|
||||
|
||||
|
||||
@@ -151,15 +189,27 @@ async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assista
|
||||
async for c in agent.aget_state_history(
|
||||
{
|
||||
"configurable": {
|
||||
**assistant["config"]["configurable"],
|
||||
**assistant.config["configurable"],
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant["assistant_id"],
|
||||
"assistant_id": assistant.assistant_id,
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_assistant_type(config: dict) -> str:
|
||||
"""Extract assistant type from config, handling both old and new formats."""
|
||||
configurable = config.get("configurable", {})
|
||||
|
||||
# First try direct type key (old format)
|
||||
if "type" in configurable:
|
||||
return configurable["type"]
|
||||
|
||||
# Default fallback
|
||||
return "chatbot"
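To illustrate the fallback, the helper can be exercised with a few configs. The function body below is copied from the diff with indentation restored; the sample configs are made up for illustration.

```python
def get_assistant_type(config: dict) -> str:
    """Extract assistant type from config, defaulting to "chatbot"."""
    configurable = config.get("configurable", {})

    # Use the explicit type when the config carries one.
    if "type" in configurable:
        return configurable["type"]

    # Otherwise fall back to the default.
    return "chatbot"


with_type = get_assistant_type({"configurable": {"type": "agent"}})
without_type = get_assistant_type({"configurable": {"thread_id": "t1"}})
empty = get_assistant_type({})
```

The previous expression, `assistant["config"]["configurable"]["type"]`, would raise a `KeyError` for the second and third configs, whereas the helper returns `"chatbot"` for both.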
|
||||
|
||||
|
||||
async def put_thread(
|
||||
user_id: str, thread_id: str, *, assistant_id: str, name: str
|
||||
) -> Thread:
|
||||
@@ -167,9 +217,7 @@ async def put_thread(
|
||||
updated_at = datetime.now(timezone.utc)
|
||||
assistant = await get_assistant(user_id, assistant_id)
|
||||
metadata = (
|
||||
{"assistant_type": assistant["config"]["configurable"]["type"]}
|
||||
if assistant
|
||||
else None
|
||||
{"assistant_type": get_assistant_type(assistant.config)} if assistant else None
|
||||
)
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
await conn.execute(
|
||||
@@ -189,14 +237,14 @@ async def put_thread(
|
||||
updated_at,
|
||||
metadata,
|
||||
)
|
||||
return {
|
||||
"thread_id": thread_id,
|
||||
"user_id": user_id,
|
||||
"assistant_id": assistant_id,
|
||||
"name": name,
|
||||
"updated_at": updated_at,
|
||||
"metadata": metadata,
|
||||
}
|
||||
return Thread(
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
assistant_id=assistant_id,
|
||||
name=name,
|
||||
updated_at=updated_at,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def delete_thread(user_id: str, thread_id: str):
|
||||
@@ -212,9 +260,9 @@ async def delete_thread(user_id: str, thread_id: str):
|
||||
async def get_or_create_user(sub: str) -> tuple[User, bool]:
|
||||
"""Returns a tuple of the user and a boolean indicating whether the user was created."""
|
||||
async with get_pg_pool().acquire() as conn:
|
||||
if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub):
|
||||
return user, False
|
||||
user = await conn.fetchrow(
|
||||
if record := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub):
|
||||
return User(**record), False
|
||||
record = await conn.fetchrow(
|
||||
'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub
|
||||
)
|
||||
return user, True
|
||||
return User(**record), True
|
||||
|
||||
+84
-88
@@ -1,8 +1,7 @@
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.tools.retriever import create_retriever_tool
|
||||
from langchain_community.agent_toolkits.connery import ConneryToolkit
|
||||
from langchain_community.retrievers.kay import KayAiRetriever
|
||||
@@ -22,26 +21,26 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
||||
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
|
||||
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_robocorp import ActionServerToolkit
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.upload import vstore
|
||||
|
||||
|
||||
class DDGInput(BaseModel):
|
||||
query: str = Field(description="search query to look up")
|
||||
query: Annotated[str, Field(description="search query to look up")]
|
||||
|
||||
|
||||
class ArxivInput(BaseModel):
|
||||
query: str = Field(description="search query to look up")
|
||||
query: Annotated[str, Field(description="search query to look up")]
|
||||
|
||||
|
||||
class PythonREPLInput(BaseModel):
|
||||
query: str = Field(description="python command to run")
|
||||
query: Annotated[str, Field(description="python command to run")]
|
||||
|
||||
|
||||
class DallEInput(BaseModel):
|
||||
query: str = Field(description="image description to generate image from")
|
||||
query: Annotated[str, Field(description="image description to generate image from")]
|
||||
|
||||
|
||||
class AvailableTools(str, Enum):
|
||||
@@ -66,10 +65,10 @@ class ToolConfig(TypedDict):
|
||||
|
||||
class BaseTool(BaseModel):
|
||||
type: AvailableTools
|
||||
name: Optional[str]
|
||||
description: Optional[str]
|
||||
config: Optional[ToolConfig]
|
||||
multi_use: Optional[bool] = False
|
||||
name: str
|
||||
description: str
|
||||
config: ToolConfig = Field(default_factory=dict)
|
||||
multi_use: bool = False
|
||||
|
||||
|
||||
class ActionServerConfig(ToolConfig):
|
||||
@@ -78,125 +77,133 @@ class ActionServerConfig(ToolConfig):
|
||||
|
||||
|
||||
class ActionServer(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.ACTION_SERVER, const=True)
|
||||
name: str = Field("Action Server by Sema4.ai", const=True)
|
||||
description: str = Field(
|
||||
type: Literal[AvailableTools.ACTION_SERVER] = AvailableTools.ACTION_SERVER
|
||||
name: Literal["Action Server by Sema4.ai"] = "Action Server by Sema4.ai"
|
||||
description: Literal[
|
||||
(
|
||||
"Run AI actions with "
|
||||
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
|
||||
),
|
||||
const=True,
|
||||
)
|
||||
] = (
|
||||
"Run AI actions with "
|
||||
"[Sema4.ai Action Server](https://github.com/Sema4AI/actions)."
|
||||
)
|
||||
config: ActionServerConfig
|
||||
multi_use: bool = Field(True, const=True)
|
||||
multi_use: Literal[True] = True
|
||||
|
||||
|
||||
class Connery(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.CONNERY, const=True)
|
||||
name: str = Field("AI Action Runner by Connery", const=True)
|
||||
description: str = Field(
|
||||
type: Literal[AvailableTools.CONNERY] = AvailableTools.CONNERY
|
||||
name: Literal["AI Action Runner by Connery"] = "AI Action Runner by Connery"
|
||||
description: Literal[
|
||||
(
|
||||
"Connect OpenGPTs to the real world with "
|
||||
"[Connery](https://github.com/connery-io/connery)."
|
||||
),
|
||||
const=True,
|
||||
)
|
||||
] = (
|
||||
"Connect OpenGPTs to the real world with "
|
||||
"[Connery](https://github.com/connery-io/connery)."
|
||||
)
|
||||
|
||||
|
||||
class DDGSearch(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.DDG_SEARCH, const=True)
|
||||
name: str = Field("DuckDuckGo Search", const=True)
|
||||
description: str = Field(
|
||||
"Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/).",
|
||||
const=True,
|
||||
)
|
||||
type: Literal[AvailableTools.DDG_SEARCH] = AvailableTools.DDG_SEARCH
|
||||
name: Literal["DuckDuckGo Search"] = "DuckDuckGo Search"
|
||||
description: Literal[
|
||||
"Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."
|
||||
] = "Search the web with [DuckDuckGo](https://pypi.org/project/duckduckgo-search/)."
|
||||
|
||||
|
||||
class Arxiv(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.ARXIV, const=True)
|
||||
name: str = Field("Arxiv", const=True)
|
||||
description: str = Field("Searches [Arxiv](https://arxiv.org/).", const=True)
|
||||
type: Literal[AvailableTools.ARXIV] = AvailableTools.ARXIV
|
||||
name: Literal["Arxiv"] = "Arxiv"
|
||||
description: Literal[
|
||||
"Searches [Arxiv](https://arxiv.org/)."
|
||||
] = "Searches [Arxiv](https://arxiv.org/)."
|
||||
|
||||
|
||||
class YouSearch(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.YOU_SEARCH, const=True)
|
||||
name: str = Field("You.com Search", const=True)
|
||||
description: str = Field(
|
||||
"Uses [You.com](https://you.com/) search, optimized responses for LLMs.",
|
||||
const=True,
|
||||
)
|
||||
type: Literal[AvailableTools.YOU_SEARCH] = AvailableTools.YOU_SEARCH
|
||||
name: Literal["You.com Search"] = "You.com Search"
|
||||
description: Literal[
|
||||
"Uses [You.com](https://you.com/) search, optimized responses for LLMs."
|
||||
] = "Uses [You.com](https://you.com/) search, optimized responses for LLMs."
|
||||
|
||||
|
||||
class SecFilings(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.SEC_FILINGS, const=True)
|
||||
name: str = Field("SEC Filings (Kay.ai)", const=True)
|
||||
description: str = Field(
|
||||
"Searches through SEC filings using [Kay.ai](https://www.kay.ai/).", const=True
|
||||
)
|
||||
type: Literal[AvailableTools.SEC_FILINGS] = AvailableTools.SEC_FILINGS
|
||||
name: Literal["SEC Filings (Kay.ai)"] = "SEC Filings (Kay.ai)"
|
||||
description: Literal[
|
||||
"Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."
|
||||
] = "Searches through SEC filings using [Kay.ai](https://www.kay.ai/)."
|
||||
|
||||
|
||||
class PressReleases(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.PRESS_RELEASES, const=True)
|
||||
name: str = Field("Press Releases (Kay.ai)", const=True)
|
||||
description: str = Field(
|
||||
"Searches through press releases using [Kay.ai](https://www.kay.ai/).",
|
||||
const=True,
|
||||
)
|
||||
type: Literal[AvailableTools.PRESS_RELEASES] = AvailableTools.PRESS_RELEASES
|
||||
name: Literal["Press Releases (Kay.ai)"] = "Press Releases (Kay.ai)"
|
||||
description: Literal[
|
||||
"Searches through press releases using [Kay.ai](https://www.kay.ai/)."
|
||||
] = "Searches through press releases using [Kay.ai](https://www.kay.ai/)."
|
||||
|
||||
|
||||
class PubMed(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.PUBMED, const=True)
|
||||
name: str = Field("PubMed", const=True)
|
||||
description: str = Field(
|
||||
"Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/).", const=True
|
||||
)
|
||||
type: Literal[AvailableTools.PUBMED] = AvailableTools.PUBMED
|
||||
name: Literal["PubMed"] = "PubMed"
|
||||
description: Literal[
|
||||
"Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."
|
||||
] = "Searches [PubMed](https://pubmed.ncbi.nlm.nih.gov/)."
|
||||
|
||||
|
||||
class Wikipedia(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.WIKIPEDIA, const=True)
|
||||
name: str = Field("Wikipedia", const=True)
|
||||
description: str = Field(
|
||||
"Searches [Wikipedia](https://pypi.org/project/wikipedia/).", const=True
|
||||
)
|
||||
type: Literal[AvailableTools.WIKIPEDIA] = AvailableTools.WIKIPEDIA
|
||||
name: Literal["Wikipedia"] = "Wikipedia"
|
||||
description: Literal[
|
||||
"Searches [Wikipedia](https://pypi.org/project/wikipedia/)."
|
||||
] = "Searches [Wikipedia](https://pypi.org/project/wikipedia/)."
|
||||
|
||||
|
||||
class Tavily(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.TAVILY, const=True)
|
||||
name: str = Field("Search (Tavily)", const=True)
|
||||
description: str = Field(
|
||||
type: Literal[AvailableTools.TAVILY] = AvailableTools.TAVILY
|
||||
name: Literal["Search (Tavily)"] = "Search (Tavily)"
|
||||
description: Literal[
|
||||
(
|
||||
"Uses the [Tavily](https://app.tavily.com/) search engine. "
|
||||
"Includes sources in the response."
|
||||
),
|
||||
const=True,
|
||||
)
|
||||
] = (
|
||||
"Uses the [Tavily](https://app.tavily.com/) search engine. "
|
||||
"Includes sources in the response."
|
||||
)
|
||||
|
||||
|
||||
class TavilyAnswer(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.TAVILY_ANSWER, const=True)
|
||||
name: str = Field("Search (short answer, Tavily)", const=True)
|
||||
description: str = Field(
|
||||
type: Literal[AvailableTools.TAVILY_ANSWER] = AvailableTools.TAVILY_ANSWER
|
||||
name: Literal["Search (short answer, Tavily)"] = "Search (short answer, Tavily)"
|
||||
description: Literal[
|
||||
(
|
||||
"Uses the [Tavily](https://app.tavily.com/) search engine. "
|
||||
"This returns only the answer, no supporting evidence."
|
||||
),
|
||||
const=True,
|
||||
)
|
||||
] = (
|
||||
"Uses the [Tavily](https://app.tavily.com/) search engine. "
|
||||
"This returns only the answer, no supporting evidence."
|
||||
)
|
||||
|
||||
|
||||
class Retrieval(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.RETRIEVAL, const=True)
|
||||
name: str = Field("Retrieval", const=True)
|
||||
description: str = Field("Look up information in uploaded files.", const=True)
|
||||
type: Literal[AvailableTools.RETRIEVAL] = AvailableTools.RETRIEVAL
|
||||
name: Literal["Retrieval"] = "Retrieval"
|
||||
description: Literal[
|
||||
"Look up information in uploaded files."
|
||||
] = "Look up information in uploaded files."
|
||||
|
||||
|
||||
class DallE(BaseTool):
|
||||
type: AvailableTools = Field(AvailableTools.DALL_E, const=True)
|
||||
name: str = Field("Generate Image (Dall-E)", const=True)
|
||||
description: str = Field(
|
||||
"Generates images from a text description using OpenAI's DALL-E model.",
|
||||
const=True,
|
||||
)
|
||||
type: Literal[AvailableTools.DALL_E] = AvailableTools.DALL_E
|
||||
name: Literal["Generate Image (Dall-E)"] = "Generate Image (Dall-E)"
|
||||
description: Literal[
|
||||
"Generates images from a text description using OpenAI's DALL-E model."
|
||||
] = "Generates images from a text description using OpenAI's DALL-E model."
|
||||
|
||||
|
||||
RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
|
||||
@@ -286,16 +293,6 @@ def _get_tavily_answer():
|
||||
return _TavilyAnswer(api_wrapper=tavily_search, name="search_tavily_answer")
|
||||
|
||||
|
||||
def _get_action_server(**kwargs: ActionServerConfig):
|
||||
toolkit = ActionServerToolkit(
|
||||
url=kwargs["url"],
|
||||
api_key=kwargs["api_key"],
|
||||
additional_headers=kwargs.get("additional_headers", {}),
|
||||
)
|
||||
tools = toolkit.get_tools()
|
||||
return tools
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_connery_actions():
|
||||
connery_service = ConneryService()
|
||||
@@ -314,7 +311,6 @@ def _get_dalle_tools():
|
||||
|
||||
|
||||
TOOLS = {
|
||||
AvailableTools.ACTION_SERVER: _get_action_server,
|
||||
AvailableTools.CONNERY: _get_connery_actions,
|
||||
AvailableTools.DDG_SEARCH: _get_duck_duck_go,
|
||||
AvailableTools.ARXIV: _get_arxiv,
|
||||
|
||||
@@ -24,6 +24,7 @@ from langchain_core.runnables import (
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from app.ingest import ingest_blob
|
||||
from app.parsing import MIMETYPE_BASED_PARSER
|
||||
@@ -110,18 +111,15 @@ class IngestRunnable(RunnableSerializable[BinaryIO, List[str]]):
|
||||
"""Runnable for ingesting files into a vectorstore."""
|
||||
|
||||
text_splitter: TextSplitter
|
||||
"""Text splitter to use for splitting the text into chunks."""
|
||||
vectorstore: VectorStore
|
||||
"""Vectorstore to ingest into."""
|
||||
assistant_id: Optional[str]
|
||||
thread_id: Optional[str]
|
||||
assistant_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
"""Ingested documents will be associated with assistant_id or thread_id.
|
||||
|
||||
ID is used as the namespace, and is filtered on at query time.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
@@ -161,12 +159,12 @@ ingest_runnable = IngestRunnable(
|
||||
).configurable_fields(
|
||||
assistant_id=ConfigurableField(
|
||||
id="assistant_id",
|
||||
annotation=str,
|
||||
annotation=Optional[str],
|
||||
name="Assistant ID",
|
||||
),
|
||||
thread_id=ConfigurableField(
|
||||
id="thread_id",
|
||||
annotation=str,
|
||||
annotation=Optional[str],
|
||||
name="Thread ID",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,12 @@
-- Drop the blob storage table
DROP TABLE IF EXISTS checkpoint_blobs;

-- Drop the writes tracking table
DROP TABLE IF EXISTS checkpoint_writes;

-- Drop the new checkpoints table that was created by the application
DROP TABLE IF EXISTS checkpoints;

-- Restore the original checkpoints table by renaming old_checkpoints back
-- This preserves the original data that was saved before the migration
ALTER TABLE old_checkpoints RENAME TO checkpoints;
@@ -0,0 +1,9 @@
-- BREAKING CHANGE WARNING:
-- This migration represents a transition from pickle-based checkpointing to a new checkpoint system.
-- As a result, any threads created before this migration will not be usable/clickable in the UI.
-- old thread data remains in old_checkpoints table but cannot be accessed by the new version.

-- Rename existing checkpoints table to preserve current data
-- This is necessary because the application will create a new checkpoints table
-- with an updated schema during runtime initialization.
ALTER TABLE checkpoints RENAME TO old_checkpoints;
|
||||
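The two migration files above form a rename-and-restore pair. The up migration moves the legacy table aside so the application can create its new tables at startup (the `ensure_setup()` call in the lifespan hook), and the down migration drops those new tables and renames the legacy one back. The sketch below replays that sequence against an in-memory SQLite database as a stand-in for Postgres. The table schemas are simplified placeholders, not the real ones, and `table_names` is a hypothetical helper.

```python
import sqlite3

# Statements taken from the up and down migrations in the diff.
UP = "ALTER TABLE checkpoints RENAME TO old_checkpoints;"
DOWN = """
DROP TABLE IF EXISTS checkpoint_blobs;
DROP TABLE IF EXISTS checkpoint_writes;
DROP TABLE IF EXISTS checkpoints;
ALTER TABLE old_checkpoints RENAME TO checkpoints;
"""


def table_names(conn: sqlite3.Connection) -> set[str]:
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return {name for (name,) in rows}


conn = sqlite3.connect(":memory:")
# Legacy table with one row (simplified, hypothetical schema).
conn.execute("CREATE TABLE checkpoints (thread_id TEXT, checkpoint BLOB)")
conn.execute("INSERT INTO checkpoints VALUES ('t1', x'00')")

conn.executescript(UP)
after_up = table_names(conn)

# Stand-in for the tables the application creates during startup.
conn.executescript(
    "CREATE TABLE checkpoints (thread_id TEXT);"
    "CREATE TABLE checkpoint_blobs (thread_id TEXT);"
    "CREATE TABLE checkpoint_writes (thread_id TEXT);"
)

conn.executescript(DOWN)
after_down = table_names(conn)
restored_rows = conn.execute("SELECT thread_id FROM checkpoints").fetchall()
```

After the rollback only the legacy `checkpoints` table remains, with its original row intact. Note the asymmetry: any checkpoints written by the new system between the up and down migrations are discarded by the `DROP TABLE` statements.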
Generated
+2928
-2352
File diff suppressed because it is too large
+12
-11
@@ -17,11 +17,12 @@ fastapi = "^0.103.2"
|
||||
# langchain = { git = "git@github.com:langchain-ai/langchain.git/", branch = "nc/subclass-runnable-binding" , subdirectory = "libs/langchain"}
|
||||
orjson = "^3.9.10"
|
||||
python-multipart = "^0.0.6"
|
||||
tiktoken = "^0.5.1"
|
||||
langchain = ">=0.0.338"
|
||||
langgraph = "^0.0.38"
|
||||
pydantic = "<2.0"
|
||||
langchain-openai = "^0.1.3"
|
||||
tiktoken = "^0"
|
||||
langchain = "^0.3"
|
||||
langgraph = "0.2.45"
|
||||
langgraph-checkpoint-postgres = "^2.0.2"
|
||||
pydantic = "^2"
|
||||
langchain-openai = "^0.2"
|
||||
beautifulsoup4 = "^4.12.3"
|
||||
boto3 = "^1.34.28"
|
||||
duckduckgo-search = "^5.3.0"
|
||||
@@ -29,19 +30,19 @@ arxiv = "^2.1.0"
|
||||
kay = "^0.1.2"
|
||||
xmltodict = "^0.13.0"
|
||||
wikipedia = "^1.4.0"
|
||||
langchain-google-vertexai = "^1.0.1"
|
||||
langchain-google-vertexai = "^2.0"
|
||||
langchain-google-community = "^2.0.1"
|
||||
setuptools = "^69.0.3"
|
||||
pdfminer-six = "^20231228"
|
||||
langchain-robocorp = "^0.0.8"
|
||||
fireworks-ai = "^0.11.2"
|
||||
httpx = { version = "0.25.2", extras = ["socks"] }
|
||||
unstructured = {extras = ["doc", "docx"], version = "^0.12.5"}
|
||||
httpx = { version = "^0", extras = ["socks"] }
|
||||
unstructured = {extras = ["doc", "docx"], version = "^0"}
|
||||
pgvector = "^0.2.5"
|
||||
psycopg2-binary = "^2.9.9"
|
||||
asyncpg = "^0.29.0"
|
||||
langchain-core = "^0.1.44"
|
||||
langchain-core = "^0.3"
|
||||
pyjwt = {extras = ["crypto"], version = "^2.8.0"}
|
||||
langchain-anthropic = "^0.1.8"
|
||||
langchain-anthropic = "^0.2"
|
||||
structlog = "^24.1.0"
|
||||
python-json-logger = "^2.0.7"
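Several constraints in this hunk are loosened to short caret forms such as `"^0"`, `"^0.3"` and `"^2"`. As a rough guide to how wide those ranges are, the sketch below computes the bounds for a caret spec following Poetry's documented caret rule as I understand it: the leftmost non-zero component may not change, and if every given component is zero, the last given one is the one that gets bumped. `caret_range` is a hypothetical helper that handles only plain numeric versions.

```python
def _fmt(version: list[int]) -> str:
    return ".".join(str(n) for n in version)


def caret_range(spec: str) -> tuple[str, str]:
    """Return (inclusive lower bound, exclusive upper bound) for a caret spec."""
    parts = [int(p) for p in spec.lstrip("^").split(".")]
    # Pad to major.minor.patch for the lower bound.
    lower = parts + [0] * (3 - len(parts))
    # Index of the component to bump: the leftmost non-zero one, or the
    # last component that was written if all of them are zero.
    bump = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    upper = lower[:bump] + [lower[bump] + 1] + [0] * (len(lower) - bump - 1)
    return _fmt(lower), _fmt(upper)


ranges = {spec: caret_range(spec) for spec in ["^0", "^0.3", "^2", "^2.0.2"]}
```

Under this rule `"^0"` admits any release below `1.0.0`, which is far broader than the previous `httpx = "0.25.2"` pin, while `langgraph = "0.2.45"` has no caret and stays an exact pin.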
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from io import BytesIO
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from fastapi import UploadFile
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from app.upload import IngestRunnable, _guess_mimetype, convert_ingestion_input_to_blob
|
||||
from tests.unit_tests.fixtures import get_sample_paths
|
||||
from tests.unit_tests.utils import InMemoryVectorStore
|
||||
|
||||
@@ -4,12 +4,15 @@ from typing import Optional, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schema import Assistant, Thread
|
||||
from tests.unit_tests.app.helpers import get_client
|
||||
|
||||
|
||||
def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict:
|
||||
def _project(model: BaseModel, *, exclude_keys: Optional[Sequence[str]] = None) -> dict:
|
||||
"""Return a dict with only the keys specified."""
|
||||
d = model.model_dump()
|
||||
_exclude = set(exclude_keys) if exclude_keys else set()
|
||||
return {k: v for k, v in d.items() if k not in _exclude}
|
||||
|
||||
@@ -38,7 +41,8 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == {
|
||||
assistant = Assistant.model_validate(response.json())
|
||||
assert _project(assistant, exclude_keys=["updated_at", "user_id"]) == {
|
||||
"assistant_id": aid,
|
||||
"config": {},
|
||||
"name": "bobby",
|
||||
@@ -48,8 +52,9 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
|
||||
assert len(await conn.fetch("SELECT * FROM assistant;")) == 1
|
||||
|
||||
response = await client.get("/assistants/", headers=headers)
|
||||
assistants = [Assistant.model_validate(d) for d in response.json()]
|
||||
assert [
|
||||
_project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json()
|
||||
_project(d, exclude_keys=["updated_at", "user_id"]) for d in assistants
|
||||
] == [
|
||||
{
|
||||
"assistant_id": aid,
|
||||
@@ -65,7 +70,8 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
             headers=headers,
         )
 
-        assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == {
+        assistant = Assistant.model_validate(response.json())
+        assert _project(assistant, exclude_keys=["updated_at", "user_id"]) == {
             "assistant_id": aid,
             "config": {},
             "name": "bobby",
@@ -79,7 +85,7 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None:
         assert response.json() == []
 
 
-async def test_threads() -> None:
+async def test_threads(pool: asyncpg.pool.Pool) -> None:
     """Test put thread."""
     headers = {"Cookie": "opengpts_user_id=1"}
     aid = str(uuid4())
@@ -102,6 +108,7 @@ async def test_threads() -> None:
             headers=headers,
         )
         assert response.status_code == 200, response.text
+        _ = Thread.model_validate(response.json())
 
         response = await client.get(f"/threads/{tid}/state", headers=headers)
         assert response.status_code == 200
@@ -110,8 +117,9 @@ async def test_threads() -> None:
         response = await client.get("/threads/", headers=headers)
 
         assert response.status_code == 200
+        threads = [Thread.model_validate(d) for d in response.json()]
         assert [
-            _project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json()
+            _project(d, exclude_keys=["updated_at", "user_id"]) for d in threads
         ] == [
             {
                 "assistant_id": aid,

@@ -20,7 +20,7 @@ from tests.unit_tests.app.helpers import get_client
 
 @app.get("/me")
 async def me(user: AuthedUser) -> dict:
-    return user
+    return user.model_dump()
 
 
 def _create_jwt(

@@ -53,10 +53,16 @@ def _migrate_test_db() -> None:
 
 
 @pytest.fixture(scope="session")
-async def pool():
+async def _init_db():
+    """Initialize the test database."""
     await _drop_test_db()  # In case previous test session was abruptly terminated
     await _create_test_db()
     _migrate_test_db()
+
+
+@pytest.fixture(scope="session")
+async def pool(_init_db):
+    """Initialize database pool with checkpointer."""
     async with lifespan(app):
         yield get_pg_pool()
     await _drop_test_db()

@@ -23,6 +23,8 @@ from redis.client import Redis as RedisType
 from app.checkpoint import PostgresCheckpoint
 from app.lifespan import get_pg_pool, lifespan
 from app.server import app
+from pydantic import ConfigDict
+
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -52,8 +54,7 @@ def load(keys: list[str], values: list[bytes]) -> dict:
 
 
 class RedisCheckpoint(BaseCheckpointSaver):
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True,)
 
     @property
     def config_specs(self) -> list[ConfigurableFieldSpec]:
|
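The test changes in this diff share one pattern: validate the JSON response into a Pydantic model, dump it back to a dict, and drop volatile keys before comparing. The following is a minimal sketch of that pattern, assuming Pydantic v2. `AssistantSketch`, `Opaque`, and `project` are hypothetical stand-ins for illustration, not the repository's `Assistant` schema or `_project` helper.

```python
from datetime import datetime
from typing import Optional, Sequence

from pydantic import BaseModel, ConfigDict


class Opaque:
    """Hypothetical non-Pydantic type; it exists only to require arbitrary_types_allowed."""


class AssistantSketch(BaseModel):
    # Pydantic v2 spelling of the v1 `class Config: arbitrary_types_allowed = True`.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    assistant_id: str
    name: str
    updated_at: datetime
    handle: Optional[Opaque] = None


def project(model: BaseModel, *, exclude_keys: Optional[Sequence[str]] = None) -> dict:
    """Dump the model to a dict and drop the excluded keys."""
    excluded = set(exclude_keys) if exclude_keys else set()
    return {k: v for k, v in model.model_dump().items() if k not in excluded}


# Validate a JSON-like payload; the ISO timestamp string is parsed into a datetime.
payload = {"assistant_id": "a1", "name": "bobby", "updated_at": "2024-01-01T00:00:00Z"}
assistant = AssistantSketch.model_validate(payload)

# Compare only the stable fields, as the updated tests do.
stable = project(assistant, exclude_keys=["updated_at", "handle"])
```

Validating first means a malformed response fails at `model_validate` with a field-level error instead of surfacing later as a dict mismatch.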
||||