From 6d72f65f0ccbb063e5a80e608a7ea9f9d55031b7 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser Date: Wed, 4 Sep 2024 15:53:32 +0700 Subject: [PATCH] feat: add streaming for orchestrator --- .env | 2 +- app/agents/planner.py | 50 ++++++++++++++++++++++++++---------- app/examples/orchestrator.py | 4 +-- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/.env b/.env index 24b77c0..071dfbf 100644 --- a/.env +++ b/.env @@ -32,4 +32,4 @@ TOP_K=3 EXAMPLE_TYPE=workflow # Set it to true to start FastAPI endpoint -FAST_API=true \ No newline at end of file +FAST_API=false \ No newline at end of file diff --git a/app/agents/planner.py b/app/agents/planner.py index 546957c..8a72def 100644 --- a/app/agents/planner.py +++ b/app/agents/planner.py @@ -1,7 +1,7 @@ import asyncio import uuid from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from llama_index.core.agent.runner.planner import ( DEFAULT_INITIAL_PLAN_PROMPT, @@ -37,7 +37,7 @@ class SubTaskEvent(Event): class SubTaskResultEvent(Event): sub_task: SubTask - result: AgentRunResult + result: AgentRunResult | AsyncGenerator class PlanEventType(Enum): @@ -87,9 +87,13 @@ class StructuredPlannerAgent(Workflow): async def create_plan( self, ctx: Context, ev: StartEvent ) -> ExecutePlanEvent | StopEvent: - plan_id, plan = await self.planner.create_plan(input=ev.input) + # set streaming + ctx.data["streaming"] = getattr(ev, "streaming", False) ctx.data["task"] = ev.input + + plan_id, plan = await self.planner.create_plan(input=ev.input) ctx.data["act_plan_id"] = plan_id + # inform about the new plan ctx.write_event_to_stream( PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan) @@ -118,11 +122,19 @@ class StructuredPlannerAgent(Workflow): ) -> SubTaskResultEvent: if self._verbose: print(f"=== Executing sub task: {ev.sub_task.name} ===") - task = asyncio.create_task(self.executor.run(input=ev.sub_task.input)) + is_last_tasks = ctx.data["num_sub_tasks"] == self.get_remaining_subtasks(ctx) + # TODO: streaming only works without plan refining + streaming = is_last_tasks and ctx.data["streaming"] and not self.refine_plan + task = asyncio.create_task( + self.executor.run( + input=ev.sub_task.input, + streaming=streaming, + ) + ) # bubble all events while running the executor to the planner async for event in self.executor.stream_events(): ctx.write_event_to_stream(event) - result: AgentRunResult = await task + result = await task if self._verbose: print("=== Done executing sub task ===\n") self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task) @@ -138,19 +150,17 @@ class StructuredPlannerAgent(Workflow): if results is None: return None - # store all results for refining the plan - ctx.data["results"] = ctx.data.get("results", {}) - for result in results: - ctx.data["results"][result.sub_task.name] = result.result - - upcoming_sub_tasks = self.planner.state.get_next_sub_tasks( - ctx.data["act_plan_id"] - ) + upcoming_sub_tasks = self.get_upcoming_sub_tasks(ctx) # if no more tasks to do, stop workflow and send result of last step - if len(upcoming_sub_tasks) == 0: + if upcoming_sub_tasks == 0: return StopEvent(result=results[-1].result) if self.refine_plan: + # store all results for refining the plan + ctx.data["results"] = ctx.data.get("results", {}) + for result in results: + ctx.data["results"][result.sub_task.name] = result.result + new_plan = await self.planner.refine_plan( ctx.data["task"], ctx.data["act_plan_id"], ctx.data["results"] ) @@ -165,6 +175,18 @@ class StructuredPlannerAgent(Workflow): # continue executing plan return ExecutePlanEvent() + def get_upcoming_sub_tasks(self, ctx: Context): + upcoming_sub_tasks = self.planner.state.get_next_sub_tasks( + ctx.data["act_plan_id"] + ) + return len(upcoming_sub_tasks) + + def get_remaining_subtasks(self, ctx: Context): + remaining_subtasks = self.planner.state.get_remaining_subtasks( + ctx.data["act_plan_id"] + ) + return len(remaining_subtasks) + # Concern dealing with creating and refining a plan, extracted from https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/agent/runner/planner.py#L138 class Planner: diff --git a/app/examples/orchestrator.py b/app/examples/orchestrator.py index ba346e0..9f91512 100644 --- a/app/examples/orchestrator.py +++ b/app/examples/orchestrator.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from app.agents.single import FunctionCallingAgent from app.agents.multi import AgentOrchestrator from app.examples.researcher import create_researcher @@ -6,7 +6,7 @@ from app.examples.researcher import create_researcher from llama_index.core.chat_engine.types import ChatMessage -def create_orchestrator(chat_history: List[ChatMessage]): +def create_orchestrator(chat_history: Optional[List[ChatMessage]] = None): researcher = create_researcher(chat_history) writer = FunctionCallingAgent( name="writer",