add state_deepagent notebook

This commit is contained in:
j-broekhuizen
2025-11-06 16:00:38 -08:00
parent 776fa4dc1a
commit 67c94e1de5
3 changed files with 362 additions and 79 deletions
+2 -1
View File
@@ -1,3 +1,4 @@
.env
.DS_Store
.venv/
.venv/
workbook.ipynb
+213
View File
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "409d0ca9",
"metadata": {},
"source": [
"## Setup: Agent Config and Tool Definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74fd8f7e",
"metadata": {},
"outputs": [],
"source": [
"from typing import Literal\n",
"from tavily import TavilyClient\n",
"import os\n",
"\n",
"from langchain.agents import create_agent\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.agents.middleware.types import wrap_tool_call\n",
"from langchain_core.messages import ToolMessage\n",
"from langchain_core.tools import tool\n",
"\n",
"from langgraph.checkpoint.memory import InMemorySaver\n",
"from langgraph.errors import GraphInterrupt\n",
"from langgraph.types import Command, Interrupt\n",
"\n",
"from deepagents.graph import create_deep_agent\n",
"from deepagents.middleware.subagents import CompiledSubAgent\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"\n",
"tavily_client = TavilyClient(os.getenv(\"TAVILY_API_KEY\"))\n",
"\n",
"model = ChatOpenAI(model=\"gpt-5-nano\", temperature=0)\n",
"\n",
"@tool\n",
"def get_current_weather(query: str,\n",
" max_results: int = 5,\n",
" topic: Literal[\"general\", \"news\", \"finance\"] = \"general\",\n",
" include_raw_content: bool = False,\n",
"):\n",
" \"\"\"\n",
" A tool to search the web for current weather conditions.\n",
" Args:\n",
" query: The query about weather to search for\n",
" max_results: The maximum number of results to return.\n",
" topic: The topic of the search.\n",
" include_raw_content: Whether to include the raw content of the search results.\n",
" \"\"\"\n",
" return tavily_client.search(\n",
" query,\n",
" max_results=max_results,\n",
" include_raw_content=include_raw_content,\n",
" topic=topic,\n",
" )\n",
"\n",
"@wrap_tool_call\n",
"def post_tool_approval(request, handler):\n",
" result = handler(request)\n",
" is_resuming = request.runtime.config.get('configurable', {}).get('__pregel_resuming')\n",
" if hasattr(result, 'content') and not is_resuming:\n",
" raise GraphInterrupt((Interrupt(value={\n",
" 'tool_call_id': request.tool_call['id'],\n",
" 'summary': result.content,\n",
" 'prompt': 'Approve the forecast summary before delivery.',\n",
" }),))\n",
" return result\n",
"\n",
"approval_subagent_ckpt = InMemorySaver()\n",
"weather_agent_runnable = create_agent(\n",
" model,\n",
" system_prompt='You are a weather expert. Use the get_current_weather tool to get the current weather conditions for a given location.',\n",
" tools=[get_current_weather],\n",
" middleware=[post_tool_approval],\n",
" checkpointer=approval_subagent_ckpt,\n",
")\n",
"\n",
"weather_agent = CompiledSubAgent(\n",
" name='weather-agent',\n",
" description='A weather expert that uses the get_current_weather tool to get the current weather conditions for a given location.',\n",
" runnable=weather_agent_runnable,\n",
")\n",
"\n",
"super_agent_ckpt = InMemorySaver()\n",
"super_agent = create_deep_agent(\n",
" model=model,\n",
" system_prompt='You are a weather expert. You have access to a specialist weather agent that can get the current weather conditions for a given location. Route to the weather agent when you need to get the current weather conditions for a given location.',\n",
" subagents=[weather_agent],\n",
" checkpointer=super_agent_ckpt,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "71db0c2f",
"metadata": {},
"source": [
"## Invoke Agent and Handle Interrupt\n",
"\n",
"This cell demonstrates the interrupt workflow:\n",
"\n",
"- **Thread Configuration**: Sets up a thread ID for stateful conversation tracking\n",
"- **Agent Invocation**: Invokes the super agent with a weather query\n",
"- **Interrupt Handling**: The agent execution is interrupted by the approval middleware, which raises a `GraphInterrupt` containing the tool call ID, summary, and approval prompt\n",
"- **Interrupt Inspection**: Extracts and prints the interrupt payload to show what information is available for approval\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "58ab943c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Super agent interrupt payload: {'tool_call_id': 'call_roTTUTTR8ZrLJRM5gDHn2gCM', 'summary': '{\"query\": \"current weather in San Francisco, CA\", \"follow_up_questions\": null, \"answer\": null, \"images\": [], \"results\": [{\"title\": \"Weather in San Francisco, CA\", \"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1762474659, \\'localtime\\': \\'2025-11-06 16:17\\'}, \\'current\\': {\\'last_updated_epoch\\': 1762474500, \\'last_updated\\': \\'2025-11-06 16:15\\', \\'temp_c\\': 18.9, \\'temp_f\\': 66.0, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Partly Cloudy\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/116.png\\', \\'code\\': 1003}, \\'wind_mph\\': 7.6, \\'wind_kph\\': 12.2, \\'wind_degree\\': 268, \\'wind_dir\\': \\'W\\', \\'pressure_mb\\': 1023.0, \\'pressure_in\\': 30.2, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 65, \\'cloud\\': 25, \\'feelslike_c\\': 18.9, \\'feelslike_f\\': 66.0, \\'windchill_c\\': 16.6, \\'windchill_f\\': 61.8, \\'heatindex_c\\': 16.5, \\'heatindex_f\\': 61.8, \\'dewpoint_c\\': 14.6, \\'dewpoint_f\\': 58.2, \\'vis_km\\': 16.0, \\'vis_miles\\': 9.0, \\'uv\\': 0.5, \\'gust_mph\\': 11.2, \\'gust_kph\\': 18.1}}\", \"score\": 0.9062231, \"raw_content\": null}], \"response_time\": 1.99, \"request_id\": \"8b4f7b01-dac0-48d8-bfa2-401e47cddfdc\"}', 'prompt': 'Approve the forecast summary before delivery.'}\n"
]
}
],
"source": [
"thread_config = {'configurable': {'thread_id': '1'}}\n",
"\n",
"initial_result = super_agent.invoke(\n",
" {'messages': [{'role': 'user', 'content': 'What is the current weather in San Francisco?'}]},\n",
" config=thread_config,\n",
")\n",
"interrupt = initial_result['__interrupt__'][0]\n",
"print('Super agent interrupt payload:', interrupt.value)"
]
},
{
"cell_type": "markdown",
"id": "4d9dbb49",
"metadata": {},
"source": [
"## Resume Agent After Approval\n",
"\n",
"This cell demonstrates resuming the agent after approval:\n",
"\n",
"- **Resume Command**: Creates a `Command` object that resumes the interrupted tool call with an \"Approved\" status\n",
"- **Agent Resume**: Invokes the super agent again with the resume command, continuing from where it was interrupted\n",
"- **Result Extraction**: Extracts tool messages and the final agent response to show the complete workflow result\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2e2f2b58",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tool messages after super-agent resume: ['Current weather in San Francisco, CA (local time: 16:17)\\n\\n- Condition: Partly cloudy\\n- Temperature: 18.9°C (66.0°F)\\n- Feels like: 18.9°C (66.0°F)\\n- Wind: 7.6 mph from the West (gusts up to 11.2 mph)\\n- Humidity: 65%\\n- Dew point: 14.6°C\\n- Pressure: 1023 mb\\n- Visibility: 16 km\\n- Cloud cover: ~25%\\n- Precipitation: 0.0 mm in the last hour\\n- UV index: 0.5 (low)\\n\\nNotes: Pleasantly mild for this time of day. A light jacket might be comfortable, especially in the evening when it cools off. No rain expected soon.']\n",
"Final super agent reply: Heres the current weather for San Francisco, CA (local time 16:17):\n",
"\n",
"- Condition: Partly cloudy\n",
"- Temperature: 18.9°C (66.0°F); feels like 18.9°C\n",
"- Wind: 7.6 mph from the West; gusts up to 11.2 mph\n",
"- Humidity: 65%\n",
"- Dew point: 14.6°C\n",
"- Pressure: 1023 mb\n",
"- Visibility: 16 km\n",
"- Cloud cover: ~25%\n",
"- Precipitation (last hour): 0.0 mm\n",
"- UV index: 0.5 (low)\n",
"\n",
"Notes: Pleasantly mild for this time of day. A light jacket might be comfortable, especially this evening. No rain expected soon.\n",
"\n",
"Would you like an hourly forecast or a short-term outlook for the rest of the day?\n"
]
}
],
"source": [
"resume_cmd = Command(resume={interrupt.value['tool_call_id']: 'Approved'})\n",
"resume_result = super_agent.invoke(resume_cmd, config=thread_config)\n",
"tool_messages = [m.content for m in resume_result['messages'] if isinstance(m, ToolMessage)]\n",
"print('Tool messages after super-agent resume:', tool_messages)\n",
"print('Final super agent reply:', resume_result['messages'][-1].content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+147 -78
View File
@@ -2,21 +2,10 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "0054ac84",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import os\n",
"from typing import Literal\n",
@@ -57,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "202783f9",
"metadata": {},
"outputs": [],
@@ -164,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "9b24d02d",
"metadata": {},
"outputs": [],
@@ -198,31 +187,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "000cc37c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"To provide you with the most accurate information, I need a bit more detail:\n",
"\n",
"1. **Region**: Which region would you like to see? (e.g., company-wide/all regions, North America, Europe, Asia Pacific, etc.)\n",
"\n",
"2. **Time Period**: What time frame are you interested in? (e.g., latest month, latest quarter, trailing 12 months, year-to-date, etc.)\n",
"\n",
"3. **Metric Type**: Which growth rate are you looking for?\n",
" - Revenue growth rate\n",
" - Sales growth rate\n",
" - Other specific metrics (e.g., units growth, gross margin growth)\n",
"\n",
"Alternatively, I can pull all available growth-related metrics for a general overview. What would be most helpful for you?\n"
]
}
],
"outputs": [],
"source": [
"config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
"result = sales_deep_agent.invoke({\n",
@@ -266,20 +234,10 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "7704f7b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First run (tool message): Stateful tracker has run 1 time(s).\n",
"Second run (tool message): Stateful tracker has run 2 time(s).\n",
"Persisted subagent state: {'runs': 2, 'log': ['run 1', 'run 2']}\n"
]
}
],
"outputs": [],
"source": [
"from typing import TypedDict\n",
"\n",
@@ -410,22 +368,10 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "f429544b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial tool messages: ['Subagent paused: Which region should I query?']\n",
"Pending interrupt id: None\n",
"Tool message after resume: Region captured: North America\n",
"Persisted region state: {'region': 'North America'}\n",
"Deep agent final reply: Thanks, continuing with the specified region.\n"
]
}
],
"outputs": [],
"source": [
"from typing import TypedDict\n",
"\n",
@@ -558,22 +504,10 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "38411ecc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial tool messages: ['Draft summary for North America: North America revenue forecast indicates +12% YoY growth.\\nAwaiting approval.']\n",
"Initial final reply: Draft generated; awaiting approval.\n",
"Tool messages after approval: ['Draft summary for North America: North America revenue forecast indicates +12% YoY growth.\\nAwaiting approval.', 'Approved summary for North America: North America revenue forecast indicates +12% YoY growth. (approval: Approved).']\n",
"Final orchestrator reply: Report approved and delivered to the stakeholder.\n",
"Persisted report state: {'region': 'North America', 'summary': 'North America revenue forecast indicates +12% YoY growth.', 'approval': 'Approved'}\n"
]
}
],
"outputs": [],
"source": [
"from typing import TypedDict\n",
"\n",
@@ -712,6 +646,141 @@
"print(\"Final orchestrator reply:\", second_run[\"messages\"][-1].content)\n",
"print(\"Persisted report state:\", report_graph.get_state(config={\"configurable\": {\"thread_id\": \"report-subagent\"}}).values)\n"
]
},
{
"cell_type": "markdown",
"id": "55671ea0",
"metadata": {},
"source": [
"## Example 4: Super agent handles resume\n",
"\n",
"Bubble the post-tool interrupt up to the super agent so the parent thread pauses, collects human input, and resumes the subagent when `Command(resume=...)` arrives on the main graph.\n"
]
},
{
"cell_type": "markdown",
"id": "27359897",
"metadata": {},
"source": [
"### **Expected output**: \n",
"- First invocation returns a payload describing the drafted summary and approval prompt.\n",
"- Sending `Command(resume='Approved')` to the super agent continues execution and returns an approved summary tool message plus a final orchestrator reply.\n",
"- Console output shows the interrupt payload followed by the resumed tool message and final reply.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3b37b52e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Super agent interrupt payload: {'tool_call_id': 'approval_tool_call', 'summary': 'Forecast for North America: +12% YoY growth.', 'prompt': 'Approve the forecast summary before delivery.'}\n",
"Tool messages after super-agent resume: ['Summary approved and delivered to the stakeholder.']\n",
"Final super agent reply: Approval noted. The forecast summary has been sent to the stakeholder.\n"
]
}
],
"source": [
"from typing import TypedDict\n",
"\n",
"from langchain.agents import create_agent\n",
"from langchain.agents.middleware.types import wrap_tool_call\n",
"from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel\n",
"from langchain_core.messages import AIMessage, ToolCall, ToolMessage\n",
"from langchain_core.tools import tool\n",
"\n",
"from langgraph.checkpoint.memory import InMemorySaver\n",
"from langgraph.errors import GraphInterrupt\n",
"from langgraph.types import Command, Interrupt\n",
"\n",
"from deepagents.graph import create_deep_agent\n",
"from deepagents.middleware.subagents import CompiledSubAgent\n",
"\n",
"@tool\n",
"def fetch_forecast_summary(region: str = \"North America\") -> str:\n",
" '\"Return a short revenue forecast summary for the region.\"'\n",
" return f\"Forecast for {region}: +12% YoY growth.\"\n",
"\n",
"class ApprovalSubModel(FakeMessagesListChatModel):\n",
" def bind_tools(self, tools, **kwargs):\n",
" return self\n",
"\n",
"approval_sub_model = ApprovalSubModel(responses=[\n",
" AIMessage(content='', tool_calls=[ToolCall(name='fetch_forecast_summary', args={}, id='approval_tool_call')]),\n",
" AIMessage(content='Summary approved and delivered to the stakeholder.'),\n",
"])\n",
"\n",
"@wrap_tool_call\n",
"def post_tool_approval(request, handler):\n",
" result = handler(request)\n",
" is_resuming = request.runtime.config.get('configurable', {}).get('__pregel_resuming')\n",
" if hasattr(result, 'content') and not is_resuming:\n",
" raise GraphInterrupt((Interrupt(value={\n",
" 'tool_call_id': request.tool_call['id'],\n",
" 'summary': result.content,\n",
" 'prompt': 'Approve the forecast summary before delivery.',\n",
" }),))\n",
" return result\n",
"\n",
"approval_subagent_ckpt = InMemorySaver()\n",
"approval_subagent_runnable = create_agent(\n",
" approval_sub_model,\n",
" system_prompt='Generate a revenue forecast summary by calling the fetch_forecast_summary tool.',\n",
" tools=[fetch_forecast_summary],\n",
" middleware=[post_tool_approval],\n",
" checkpointer=approval_subagent_ckpt,\n",
")\n",
"\n",
"approval_subagent = CompiledSubAgent(\n",
" name='approval-agent',\n",
" description='Produces a forecast summary and pauses until the summary is approved.',\n",
" runnable=approval_subagent_runnable,\n",
")\n",
"\n",
"class SuperAgentModel(FakeMessagesListChatModel):\n",
" def bind_tools(self, tools, **kwargs):\n",
" return self\n",
"\n",
"super_agent_model = SuperAgentModel(responses=[\n",
" AIMessage(content='', tool_calls=[ToolCall(name='task', args={'subagent_type': 'approval-agent', 'description': 'Run the forecast approval workflow'}, id='task_call_1')]),\n",
" AIMessage(content='Approval noted. The forecast summary has been sent to the stakeholder.'),\n",
"])\n",
"\n",
"super_agent_ckpt = InMemorySaver()\n",
"super_agent = create_deep_agent(\n",
" model=super_agent_model,\n",
" subagents=[approval_subagent],\n",
" checkpointer=super_agent_ckpt,\n",
")\n",
"\n",
"thread_config = {'configurable': {'thread_id': 'example-4-thread'}}\n",
"\n",
"initial_result = super_agent.invoke(\n",
" {'messages': [{'role': 'user', 'content': 'Prepare the forecast summary and wait for my approval.'}]},\n",
" config=thread_config,\n",
")\n",
"interrupt = initial_result['__interrupt__'][0]\n",
"print('Super agent interrupt payload:', interrupt.value)\n",
"\n",
"resume_cmd = Command(resume={interrupt.value['tool_call_id']: 'Approved'})\n",
"resume_result = super_agent.invoke(resume_cmd, config=thread_config)\n",
"tool_messages = [m.content for m in resume_result['messages'] if isinstance(m, ToolMessage)]\n",
"print('Tool messages after super-agent resume:', tool_messages)\n",
"print('Final super agent reply:', resume_result['messages'][-1].content)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46513550",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {