diff --git a/.gitignore b/.gitignore index c0c056d..0691cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .env .DS_Store -.venv/ \ No newline at end of file +.venv/ +workbook.ipynb \ No newline at end of file diff --git a/stateful_deepagent.ipynb b/stateful_deepagent.ipynb new file mode 100644 index 0000000..5572551 --- /dev/null +++ b/stateful_deepagent.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "409d0ca9", + "metadata": {}, + "source": [ + "## Setup: Agent Config and Tool Definition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74fd8f7e", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal\n", + "from tavily import TavilyClient\n", + "import os\n", + "\n", + "from langchain.agents import create_agent\n", + "from langchain_openai import ChatOpenAI\n", + "from langchain.agents.middleware.types import wrap_tool_call\n", + "from langchain_core.messages import ToolMessage\n", + "from langchain_core.tools import tool\n", + "\n", + "from langgraph.checkpoint.memory import InMemorySaver\n", + "from langgraph.errors import GraphInterrupt\n", + "from langgraph.types import Command, Interrupt\n", + "\n", + "from deepagents.graph import create_deep_agent\n", + "from deepagents.middleware.subagents import CompiledSubAgent\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()\n", + "\n", + "tavily_client = TavilyClient(os.getenv(\"TAVILY_API_KEY\"))\n", + "\n", + "model = ChatOpenAI(model=\"gpt-5-nano\", temperature=0)\n", + "\n", + "@tool\n", + "def get_current_weather(query: str,\n", + " max_results: int = 5,\n", + " topic: Literal[\"general\", \"news\", \"finance\"] = \"general\",\n", + " include_raw_content: bool = False,\n", + "):\n", + " \"\"\"\n", + " A tool to search the web for current weather conditions.\n", + " Args:\n", + " query: The query about weather to search for\n", + " max_results: The maximum number of results to return.\n", + " topic: The topic of the search.\n", + " include_raw_content: Whether to include the raw content of the search results.\n", + " \"\"\"\n", + " return tavily_client.search(\n", + " query,\n", + " max_results=max_results,\n", + " include_raw_content=include_raw_content,\n", + " topic=topic,\n", + " )\n", + "\n", + "@wrap_tool_call\n", + "def post_tool_approval(request, handler):\n", + " result = handler(request)\n", + " is_resuming = request.runtime.config.get('configurable', {}).get('__pregel_resuming')\n", + " if hasattr(result, 'content') and not is_resuming:\n", + " raise GraphInterrupt((Interrupt(value={\n", + " 'tool_call_id': request.tool_call['id'],\n", + " 'summary': result.content,\n", + " 'prompt': 'Approve the forecast summary before delivery.',\n", + " }),))\n", + " return result\n", + "\n", + "approval_subagent_ckpt = InMemorySaver()\n", + "weather_agent_runnable = create_agent(\n", + " model,\n", + " system_prompt='You are a weather expert. Use the get_current_weather tool to get the current weather conditions for a given location.',\n", + " tools=[get_current_weather],\n", + " middleware=[post_tool_approval],\n", + " checkpointer=approval_subagent_ckpt,\n", + ")\n", + "\n", + "weather_agent = CompiledSubAgent(\n", + " name='weather-agent',\n", + " description='A weather expert that uses the get_current_weather tool to get the current weather conditions for a given location.',\n", + " runnable=weather_agent_runnable,\n", + ")\n", + "\n", + "super_agent_ckpt = InMemorySaver()\n", + "super_agent = create_deep_agent(\n", + " model=model,\n", + " system_prompt='You are a weather expert. You have access to a specialist weather agent that can get the current weather conditions for a given location. Route to the weather agent when you need to get the current weather conditions for a given location.',\n", + " subagents=[weather_agent],\n", + " checkpointer=super_agent_ckpt,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "71db0c2f", + "metadata": {}, + "source": [ + "## Invoke Agent and Handle Interrupt\n", + "\n", + "This cell demonstrates the interrupt workflow:\n", + "\n", + "- **Thread Configuration**: Sets up a thread ID for stateful conversation tracking\n", + "- **Agent Invocation**: Invokes the super agent with a weather query\n", + "- **Interrupt Handling**: The agent execution is interrupted by the approval middleware, which raises a `GraphInterrupt` containing the tool call ID, summary, and approval prompt\n", + "- **Interrupt Inspection**: Extracts and prints the interrupt payload to show what information is available for approval\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "58ab943c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Super agent interrupt payload: {'tool_call_id': 'call_roTTUTTR8ZrLJRM5gDHn2gCM', 'summary': '{\"query\": \"current weather in San Francisco, CA\", \"follow_up_questions\": null, \"answer\": null, \"images\": [], \"results\": [{\"title\": \"Weather in San Francisco, CA\", \"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1762474659, \\'localtime\\': \\'2025-11-06 16:17\\'}, \\'current\\': {\\'last_updated_epoch\\': 1762474500, \\'last_updated\\': \\'2025-11-06 16:15\\', \\'temp_c\\': 18.9, \\'temp_f\\': 66.0, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Partly Cloudy\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/116.png\\', \\'code\\': 1003}, \\'wind_mph\\': 7.6, \\'wind_kph\\': 12.2, \\'wind_degree\\': 268, \\'wind_dir\\': \\'W\\', \\'pressure_mb\\': 1023.0, \\'pressure_in\\': 30.2, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 65, \\'cloud\\': 25, \\'feelslike_c\\': 18.9, \\'feelslike_f\\': 66.0, \\'windchill_c\\': 16.6, \\'windchill_f\\': 61.8, \\'heatindex_c\\': 16.5, \\'heatindex_f\\': 61.8, \\'dewpoint_c\\': 14.6, \\'dewpoint_f\\': 58.2, \\'vis_km\\': 16.0, \\'vis_miles\\': 9.0, \\'uv\\': 0.5, \\'gust_mph\\': 11.2, \\'gust_kph\\': 18.1}}\", \"score\": 0.9062231, \"raw_content\": null}], \"response_time\": 1.99, \"request_id\": \"8b4f7b01-dac0-48d8-bfa2-401e47cddfdc\"}', 'prompt': 'Approve the forecast summary before delivery.'}\n" + ] + } + ], + "source": [ + "thread_config = {'configurable': {'thread_id': '1'}}\n", + "\n", + "initial_result = super_agent.invoke(\n", + " {'messages': [{'role': 'user', 'content': 'What is the current weather in San Francisco?'}]},\n", + " config=thread_config,\n", + ")\n", + "interrupt = initial_result['__interrupt__'][0]\n", + "print('Super agent interrupt payload:', interrupt.value)" + ] + }, + { + "cell_type": "markdown", + "id": "4d9dbb49", + "metadata": {}, + "source": [ + "## Resume Agent After Approval\n", + "\n", + "This cell demonstrates resuming the agent after approval:\n", + "\n", + "- **Resume Command**: Creates a `Command` object that resumes the interrupted tool call with an \"Approved\" status\n", + "- **Agent Resume**: Invokes the super agent again with the resume command, continuing from where it was interrupted\n", + "- **Result Extraction**: Extracts tool messages and the final agent response to show the complete workflow result\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2e2f2b58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tool messages after super-agent resume: ['Current weather in San Francisco, CA (local time: 16:17)\\n\\n- Condition: Partly cloudy\\n- Temperature: 18.9°C (66.0°F)\\n- Feels like: 18.9°C (66.0°F)\\n- Wind: 7.6 mph from the West (gusts up to 11.2 mph)\\n- Humidity: 65%\\n- Dew point: 14.6°C\\n- Pressure: 1023 mb\\n- Visibility: 16 km\\n- Cloud cover: ~25%\\n- Precipitation: 0.0 mm in the last hour\\n- UV index: 0.5 (low)\\n\\nNotes: Pleasantly mild for this time of day. A light jacket might be comfortable, especially in the evening when it cools off. No rain expected soon.']\n", + "Final super agent reply: Here’s the current weather for San Francisco, CA (local time 16:17):\n", + "\n", + "- Condition: Partly cloudy\n", + "- Temperature: 18.9°C (66.0°F); feels like 18.9°C\n", + "- Wind: 7.6 mph from the West; gusts up to 11.2 mph\n", + "- Humidity: 65%\n", + "- Dew point: 14.6°C\n", + "- Pressure: 1023 mb\n", + "- Visibility: 16 km\n", + "- Cloud cover: ~25%\n", + "- Precipitation (last hour): 0.0 mm\n", + "- UV index: 0.5 (low)\n", + "\n", + "Notes: Pleasantly mild for this time of day. A light jacket might be comfortable, especially this evening. No rain expected soon.\n", + "\n", + "Would you like an hourly forecast or a short-term outlook for the rest of the day?\n" + ] + } + ], + "source": [ + "resume_cmd = Command(resume={interrupt.value['tool_call_id']: 'Approved'})\n", + "resume_result = super_agent.invoke(resume_cmd, config=thread_config)\n", + "tool_messages = [m.content for m in resume_result['messages'] if isinstance(m, ToolMessage)]\n", + "print('Tool messages after super-agent resume:', tool_messages)\n", + "print('Final super agent reply:', resume_result['messages'][-1].content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/workbook.ipynb b/workbook.ipynb index b2fca56..083d073 100644 --- a/workbook.ipynb +++ b/workbook.ipynb @@ -2,21 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "0054ac84", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import os\n", "from typing import Literal\n", @@ -57,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "202783f9", "metadata": {}, "outputs": [], @@ -164,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "9b24d02d", "metadata": {}, "outputs": [], @@ -198,31 +187,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "000cc37c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==================================\u001b[1m Ai Message \u001b[0m==================================\n", - "\n", - "To provide you with the most accurate information, I need a bit more detail:\n", - "\n", - "1. **Region**: Which region would you like to see? (e.g., company-wide/all regions, North America, Europe, Asia Pacific, etc.)\n", - "\n", - "2. **Time Period**: What time frame are you interested in? (e.g., latest month, latest quarter, trailing 12 months, year-to-date, etc.)\n", - "\n", - "3. **Metric Type**: Which growth rate are you looking for?\n", - " - Revenue growth rate\n", - " - Sales growth rate\n", - " - Other specific metrics (e.g., units growth, gross margin growth)\n", - "\n", - "Alternatively, I can pull all available growth-related metrics for a general overview. What would be most helpful for you?\n" - ] - } - ], + "outputs": [], "source": [ "config = {\"configurable\": {\"thread_id\": \"1\"}}\n", "result = sales_deep_agent.invoke({\n", @@ -266,20 +234,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "7704f7b0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "First run (tool message): Stateful tracker has run 1 time(s).\n", - "Second run (tool message): Stateful tracker has run 2 time(s).\n", - "Persisted subagent state: {'runs': 2, 'log': ['run 1', 'run 2']}\n" - ] - } - ], + "outputs": [], "source": [ "from typing import TypedDict\n", "\n", @@ -410,22 +368,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "f429544b", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Initial tool messages: ['Subagent paused: Which region should I query?']\n", - "Pending interrupt id: None\n", - "Tool message after resume: Region captured: North America\n", - "Persisted region state: {'region': 'North America'}\n", - "Deep agent final reply: Thanks, continuing with the specified region.\n" - ] - } - ], + "outputs": [], "source": [ "from typing import TypedDict\n", "\n", @@ -558,22 +504,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "38411ecc", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Initial tool messages: ['Draft summary for North America: North America revenue forecast indicates +12% YoY growth.\\nAwaiting approval.']\n", - "Initial final reply: Draft generated; awaiting approval.\n", - "Tool messages after approval: ['Draft summary for North America: North America revenue forecast indicates +12% YoY growth.\\nAwaiting approval.', 'Approved summary for North America: North America revenue forecast indicates +12% YoY growth. (approval: Approved).']\n", - "Final orchestrator reply: Report approved and delivered to the stakeholder.\n", - "Persisted report state: {'region': 'North America', 'summary': 'North America revenue forecast indicates +12% YoY growth.', 'approval': 'Approved'}\n" - ] - } - ], + "outputs": [], "source": [ "from typing import TypedDict\n", "\n", @@ -712,6 +646,141 @@ "print(\"Final orchestrator reply:\", second_run[\"messages\"][-1].content)\n", "print(\"Persisted report state:\", report_graph.get_state(config={\"configurable\": {\"thread_id\": \"report-subagent\"}}).values)\n" ] + }, + { + "cell_type": "markdown", + "id": "55671ea0", + "metadata": {}, + "source": [ + "## Example 4: Super agent handles resume\n", + "\n", + "Bubble the post-tool interrupt up to the super agent so the parent thread pauses, collects human input, and resumes the subagent when `Command(resume=...)` arrives on the main graph.\n" + ] + }, + { + "cell_type": "markdown", + "id": "27359897", + "metadata": {}, + "source": [ + "### **Expected output**: \n", + "- First invocation returns a payload describing the drafted summary and approval prompt.\n", + "- Sending `Command(resume='Approved')` to the super agent continues execution and returns an approved summary tool message plus a final orchestrator reply.\n", + "- Console output shows the interrupt payload followed by the resumed tool message and final reply.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3b37b52e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Super agent interrupt payload: {'tool_call_id': 'approval_tool_call', 'summary': 'Forecast for North America: +12% YoY growth.', 'prompt': 'Approve the forecast summary before delivery.'}\n", + "Tool messages after super-agent resume: ['Summary approved and delivered to the stakeholder.']\n", + "Final super agent reply: Approval noted. The forecast summary has been sent to the stakeholder.\n" + ] + } + ], + "source": [ + "from typing import TypedDict\n", + "\n", + "from langchain.agents import create_agent\n", + "from langchain.agents.middleware.types import wrap_tool_call\n", + "from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel\n", + "from langchain_core.messages import AIMessage, ToolCall, ToolMessage\n", + "from langchain_core.tools import tool\n", + "\n", + "from langgraph.checkpoint.memory import InMemorySaver\n", + "from langgraph.errors import GraphInterrupt\n", + "from langgraph.types import Command, Interrupt\n", + "\n", + "from deepagents.graph import create_deep_agent\n", + "from deepagents.middleware.subagents import CompiledSubAgent\n", + "\n", + "@tool\n", + "def fetch_forecast_summary(region: str = \"North America\") -> str:\n", + " '\"Return a short revenue forecast summary for the region.\"'\n", + " return f\"Forecast for {region}: +12% YoY growth.\"\n", + "\n", + "class ApprovalSubModel(FakeMessagesListChatModel):\n", + " def bind_tools(self, tools, **kwargs):\n", + " return self\n", + "\n", + "approval_sub_model = ApprovalSubModel(responses=[\n", + " AIMessage(content='', tool_calls=[ToolCall(name='fetch_forecast_summary', args={}, id='approval_tool_call')]),\n", + " AIMessage(content='Summary approved and delivered to the stakeholder.'),\n", + "])\n", + "\n", + "@wrap_tool_call\n", + "def post_tool_approval(request, handler):\n", + " result = handler(request)\n", + " is_resuming = request.runtime.config.get('configurable', {}).get('__pregel_resuming')\n", + " if hasattr(result, 'content') and not is_resuming:\n", + " raise GraphInterrupt((Interrupt(value={\n", + " 'tool_call_id': request.tool_call['id'],\n", + " 'summary': result.content,\n", + " 'prompt': 'Approve the forecast summary before delivery.',\n", + " }),))\n", + " return result\n", + "\n", + "approval_subagent_ckpt = InMemorySaver()\n", + "approval_subagent_runnable = create_agent(\n", + " approval_sub_model,\n", + " system_prompt='Generate a revenue forecast summary by calling the fetch_forecast_summary tool.',\n", + " tools=[fetch_forecast_summary],\n", + " middleware=[post_tool_approval],\n", + " checkpointer=approval_subagent_ckpt,\n", + ")\n", + "\n", + "approval_subagent = CompiledSubAgent(\n", + " name='approval-agent',\n", + " description='Produces a forecast summary and pauses until the summary is approved.',\n", + " runnable=approval_subagent_runnable,\n", + ")\n", + "\n", + "class SuperAgentModel(FakeMessagesListChatModel):\n", + " def bind_tools(self, tools, **kwargs):\n", + " return self\n", + "\n", + "super_agent_model = SuperAgentModel(responses=[\n", + " AIMessage(content='', tool_calls=[ToolCall(name='task', args={'subagent_type': 'approval-agent', 'description': 'Run the forecast approval workflow'}, id='task_call_1')]),\n", + " AIMessage(content='Approval noted. The forecast summary has been sent to the stakeholder.'),\n", + "])\n", + "\n", + "super_agent_ckpt = InMemorySaver()\n", + "super_agent = create_deep_agent(\n", + " model=super_agent_model,\n", + " subagents=[approval_subagent],\n", + " checkpointer=super_agent_ckpt,\n", + ")\n", + "\n", + "thread_config = {'configurable': {'thread_id': 'example-4-thread'}}\n", + "\n", + "initial_result = super_agent.invoke(\n", + " {'messages': [{'role': 'user', 'content': 'Prepare the forecast summary and wait for my approval.'}]},\n", + " config=thread_config,\n", + ")\n", + "interrupt = initial_result['__interrupt__'][0]\n", + "print('Super agent interrupt payload:', interrupt.value)\n", + "\n", + "resume_cmd = Command(resume={interrupt.value['tool_call_id']: 'Approved'})\n", + "resume_result = super_agent.invoke(resume_cmd, config=thread_config)\n", + "tool_messages = [m.content for m in resume_result['messages'] if isinstance(m, ToolMessage)]\n", + "print('Tool messages after super-agent resume:', tool_messages)\n", + "print('Final super agent reply:', resume_result['messages'][-1].content)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46513550", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {