use prebuilt agent for simplicity

This commit is contained in:
Eugene Yurtsev
2025-03-06 15:35:48 -05:00
parent 9569f9ef86
commit 2cc238be9c
2 changed files with 6 additions and 157 deletions
+6 -97
View File
@@ -1,38 +1,18 @@
"""Define a custom Reasoning and Action agent.
Works with a chat model with tool calling support.
"""
"""Create a ReAct agent with access to tools defined in a tool server."""
from datetime import UTC, datetime
from typing import Dict, List, Literal, cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import create_react_agent
from react_agent.configuration import APP_STATE
from react_agent.state import InputState, State
from react_agent.tools import TOOLBOX
from react_agent.utils import load_chat_model
# Define the function that calls the model
async def call_model(
state: State, config: RunnableConfig
) -> Dict[str, List[AIMessage]]:
"""Call the LLM powering our "agent".
This function prepares the prompt, initializes the model, and processes the response.
Args:
state (State): The current state of the conversation.
config (RunnableConfig): Configuration for the model run.
Returns:
dict: A dictionary containing the model's response message.
"""
async def make_graph(config: RunnableConfig) -> CompiledStateGraph:
"""Create a custom state graph for the Reasoning and Action agent."""
configuration = APP_STATE.configurable.from_runnable_config(config)
# Add logic to select tools
if configuration.selected_tools:
@@ -52,79 +32,8 @@ async def call_model(
system_time=datetime.now(tz=UTC).isoformat()
)
# Get the model's response
response = cast(
AIMessage,
await model.ainvoke(
[{"role": "system", "content": system_message}, *state.messages], config
),
graph = create_react_agent(
model, system_message, config_schema=APP_STATE.configurable
)
# Handle the case when it's the last step and the model still wants to use a tool
if state.is_last_step and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, I could not find an answer to your question in the specified number of steps.",
)
]
}
# Return the model's response as a list to be added to existing messages
return {"messages": [response]}
def route_model_output(state: State) -> Literal["__end__", "tools"]:
"""Determine the next node based on the model's output.
This function checks if the model's last message contains tool calls.
Args:
state (State): The current state of the conversation.
Returns:
str: The name of the next node to call ("__end__" or "tools").
"""
last_message = state.messages[-1]
if not isinstance(last_message, AIMessage):
raise ValueError(
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
)
# If there is no tool call, then we finish
if not last_message.tool_calls:
return "__end__"
# Otherwise we execute the requested actions
return "tools"
async def make_graph(config: RunnableConfig) -> CompiledStateGraph:
"""Create a custom state graph for the Reasoning and Action agent."""
tools = TOOLBOX.get_tools()
builder = StateGraph(State, input=InputState, config_schema=APP_STATE.configurable)
# Define the two nodes we will cycle between
builder.add_node(call_model)
builder.add_node("tools", ToolNode(tools))
# Set the entrypoint as `call_model`
# This means that this node is the first one called
builder.add_edge("__start__", "call_model")
# Add a conditional edge to determine the next step after `call_model`
builder.add_conditional_edges(
"call_model",
# After call_model finishes running, the next node(s) are scheduled
# based on the output from route_model_output
route_model_output,
)
# Add a normal edge from `tools` to `call_model`
# This creates a cycle: after using tools, we always return to the model
builder.add_edge("tools", "call_model")
# Compile the builder into an executable graph
# You can customize this by adding interrupt points for state updates
graph = builder.compile()
graph.name = "ReAct Agent" # This customizes the name in LangSmith
return graph
-60
View File
@@ -1,60 +0,0 @@
"""Define the state structures for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Sequence
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from langgraph.managed import IsLastStep
from typing_extensions import Annotated
@dataclass
class InputState:
"""Defines the input state for the agent, representing a narrower interface to the outside world.
This class is used to define the initial state and structure of incoming data.
"""
messages: Annotated[Sequence[AnyMessage], add_messages] = field(
default_factory=list
)
"""
Messages tracking the primary execution state of the agent.
Typically accumulates a pattern of:
1. HumanMessage - user input
2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
3. ToolMessage(s) - the responses (or errors) from the executed tools
4. AIMessage without .tool_calls - agent responding in unstructured format to the user
5. HumanMessage - user responds with the next conversational turn
Steps 2-5 may repeat as needed.
The `add_messages` annotation ensures that new messages are merged with existing ones,
updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
"""
@dataclass
class State(InputState):
"""Represents the complete state of the agent, extending InputState with additional attributes.
This class can be used to store any information needed throughout the agent's lifecycle.
"""
is_last_step: IsLastStep = field(default=False)
"""
Indicates whether the current step is the last one before the graph raises an error.
This is a 'managed' variable, controlled by the state machine rather than user code.
It is set to 'True' when the step count reaches recursion_limit - 1.
"""
# Additional attributes can be added here as needed.
# Common examples include:
# retrieved_documents: List[Document] = field(default_factory=list)
# extracted_entities: Dict[str, Any] = field(default_factory=dict)
# api_connections: Dict[str, Any] = field(default_factory=dict)