mirror of
https://github.com/langchain-ai/react-agent-tool-server.git
synced 2026-07-01 18:28:34 -04:00
use prebuilt agent for simplicity
This commit is contained in:
@@ -1,38 +1,18 @@
|
||||
"""Define a custom Reasoning and Action agent.
|
||||
|
||||
Works with a chat model with tool calling support.
|
||||
"""
|
||||
"""Create a ReAct agent with access to tools defined in a tool server."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Dict, List, Literal, cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from react_agent.configuration import APP_STATE
|
||||
from react_agent.state import InputState, State
|
||||
from react_agent.tools import TOOLBOX
|
||||
from react_agent.utils import load_chat_model
|
||||
|
||||
|
||||
# Define the function that calls the model
|
||||
async def call_model(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Dict[str, List[AIMessage]]:
|
||||
"""Call the LLM powering our "agent".
|
||||
|
||||
This function prepares the prompt, initializes the model, and processes the response.
|
||||
|
||||
Args:
|
||||
state (State): The current state of the conversation.
|
||||
config (RunnableConfig): Configuration for the model run.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the model's response message.
|
||||
"""
|
||||
async def make_graph(config: RunnableConfig) -> CompiledStateGraph:
|
||||
"""Create a custom state graph for the Reasoning and Action agent."""
|
||||
configuration = APP_STATE.configurable.from_runnable_config(config)
|
||||
# Add logic to select tools
|
||||
if configuration.selected_tools:
|
||||
@@ -52,79 +32,8 @@ async def call_model(
|
||||
system_time=datetime.now(tz=UTC).isoformat()
|
||||
)
|
||||
|
||||
# Get the model's response
|
||||
response = cast(
|
||||
AIMessage,
|
||||
await model.ainvoke(
|
||||
[{"role": "system", "content": system_message}, *state.messages], config
|
||||
),
|
||||
graph = create_react_agent(
|
||||
model, system_message, config_schema=APP_STATE.configurable
|
||||
)
|
||||
|
||||
# Handle the case when it's the last step and the model still wants to use a tool
|
||||
if state.is_last_step and response.tool_calls:
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
id=response.id,
|
||||
content="Sorry, I could not find an answer to your question in the specified number of steps.",
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
# Return the model's response as a list to be added to existing messages
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
def route_model_output(state: State) -> Literal["__end__", "tools"]:
|
||||
"""Determine the next node based on the model's output.
|
||||
|
||||
This function checks if the model's last message contains tool calls.
|
||||
|
||||
Args:
|
||||
state (State): The current state of the conversation.
|
||||
|
||||
Returns:
|
||||
str: The name of the next node to call ("__end__" or "tools").
|
||||
"""
|
||||
last_message = state.messages[-1]
|
||||
if not isinstance(last_message, AIMessage):
|
||||
raise ValueError(
|
||||
f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
|
||||
)
|
||||
# If there is no tool call, then we finish
|
||||
if not last_message.tool_calls:
|
||||
return "__end__"
|
||||
# Otherwise we execute the requested actions
|
||||
return "tools"
|
||||
|
||||
|
||||
async def make_graph(config: RunnableConfig) -> CompiledStateGraph:
|
||||
"""Create a custom state graph for the Reasoning and Action agent."""
|
||||
tools = TOOLBOX.get_tools()
|
||||
builder = StateGraph(State, input=InputState, config_schema=APP_STATE.configurable)
|
||||
|
||||
# Define the two nodes we will cycle between
|
||||
builder.add_node(call_model)
|
||||
builder.add_node("tools", ToolNode(tools))
|
||||
|
||||
# Set the entrypoint as `call_model`
|
||||
# This means that this node is the first one called
|
||||
builder.add_edge("__start__", "call_model")
|
||||
|
||||
# Add a conditional edge to determine the next step after `call_model`
|
||||
builder.add_conditional_edges(
|
||||
"call_model",
|
||||
# After call_model finishes running, the next node(s) are scheduled
|
||||
# based on the output from route_model_output
|
||||
route_model_output,
|
||||
)
|
||||
|
||||
# Add a normal edge from `tools` to `call_model`
|
||||
# This creates a cycle: after using tools, we always return to the model
|
||||
builder.add_edge("tools", "call_model")
|
||||
|
||||
# Compile the builder into an executable graph
|
||||
# You can customize this by adding interrupt points for state updates
|
||||
graph = builder.compile()
|
||||
graph.name = "ReAct Agent" # This customizes the name in LangSmith
|
||||
return graph
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from langgraph.managed import IsLastStep
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputState:
|
||||
"""Defines the input state for the agent, representing a narrower interface to the outside world.
|
||||
|
||||
This class is used to define the initial state and structure of incoming data.
|
||||
"""
|
||||
|
||||
messages: Annotated[Sequence[AnyMessage], add_messages] = field(
|
||||
default_factory=list
|
||||
)
|
||||
"""
|
||||
Messages tracking the primary execution state of the agent.
|
||||
|
||||
Typically accumulates a pattern of:
|
||||
1. HumanMessage - user input
|
||||
2. AIMessage with .tool_calls - agent picking tool(s) to use to collect information
|
||||
3. ToolMessage(s) - the responses (or errors) from the executed tools
|
||||
4. AIMessage without .tool_calls - agent responding in unstructured format to the user
|
||||
5. HumanMessage - user responds with the next conversational turn
|
||||
|
||||
Steps 2-5 may repeat as needed.
|
||||
|
||||
The `add_messages` annotation ensures that new messages are merged with existing ones,
|
||||
updating by ID to maintain an "append-only" state unless a message with the same ID is provided.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class State(InputState):
|
||||
"""Represents the complete state of the agent, extending InputState with additional attributes.
|
||||
|
||||
This class can be used to store any information needed throughout the agent's lifecycle.
|
||||
"""
|
||||
|
||||
is_last_step: IsLastStep = field(default=False)
|
||||
"""
|
||||
Indicates whether the current step is the last one before the graph raises an error.
|
||||
|
||||
This is a 'managed' variable, controlled by the state machine rather than user code.
|
||||
It is set to 'True' when the step count reaches recursion_limit - 1.
|
||||
"""
|
||||
|
||||
# Additional attributes can be added here as needed.
|
||||
# Common examples include:
|
||||
# retrieved_documents: List[Document] = field(default_factory=list)
|
||||
# extracted_entities: Dict[str, Any] = field(default_factory=dict)
|
||||
# api_connections: Dict[str, Any] = field(default_factory=dict)
|
||||
Reference in New Issue
Block a user