feat: add streaming for orchestrator

This commit is contained in:
Marcus Schiesser
2024-09-04 15:53:32 +07:00
committed by Marcus Schiesser
parent 6cce05fa27
commit 6d72f65f0c
3 changed files with 39 additions and 17 deletions
+1 -1
View File
@@ -32,4 +32,4 @@ TOP_K=3
EXAMPLE_TYPE=workflow
# Set it to true to start FastAPI endpoint
FAST_API=true
FAST_API=false
+36 -14
View File
@@ -1,7 +1,7 @@
import asyncio
import uuid
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from llama_index.core.agent.runner.planner import (
DEFAULT_INITIAL_PLAN_PROMPT,
@@ -37,7 +37,7 @@ class SubTaskEvent(Event):
class SubTaskResultEvent(Event):
sub_task: SubTask
result: AgentRunResult
result: AgentRunResult | AsyncGenerator
class PlanEventType(Enum):
@@ -87,9 +87,13 @@ class StructuredPlannerAgent(Workflow):
async def create_plan(
self, ctx: Context, ev: StartEvent
) -> ExecutePlanEvent | StopEvent:
plan_id, plan = await self.planner.create_plan(input=ev.input)
# set streaming
ctx.data["streaming"] = getattr(ev, "streaming", False)
ctx.data["task"] = ev.input
plan_id, plan = await self.planner.create_plan(input=ev.input)
ctx.data["act_plan_id"] = plan_id
# inform about the new plan
ctx.write_event_to_stream(
PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan)
@@ -118,11 +122,19 @@ class StructuredPlannerAgent(Workflow):
) -> SubTaskResultEvent:
if self._verbose:
print(f"=== Executing sub task: {ev.sub_task.name} ===")
task = asyncio.create_task(self.executor.run(input=ev.sub_task.input))
is_last_tasks = ctx.data["num_sub_tasks"] == self.get_remaining_subtasks(ctx)
# TODO: streaming only works without plan refining
streaming = is_last_tasks and ctx.data["streaming"] and not self.refine_plan
task = asyncio.create_task(
self.executor.run(
input=ev.sub_task.input,
streaming=streaming,
)
)
# bubble all events while running the executor to the planner
async for event in self.executor.stream_events():
ctx.write_event_to_stream(event)
result: AgentRunResult = await task
result = await task
if self._verbose:
print("=== Done executing sub task ===\n")
self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task)
@@ -138,19 +150,17 @@ class StructuredPlannerAgent(Workflow):
if results is None:
return None
# store all results for refining the plan
ctx.data["results"] = ctx.data.get("results", {})
for result in results:
ctx.data["results"][result.sub_task.name] = result.result
upcoming_sub_tasks = self.planner.state.get_next_sub_tasks(
ctx.data["act_plan_id"]
)
upcoming_sub_tasks = self.get_upcoming_sub_tasks(ctx)
# if no more tasks to do, stop workflow and send result of last step
if len(upcoming_sub_tasks) == 0:
if upcoming_sub_tasks == 0:
return StopEvent(result=results[-1].result)
if self.refine_plan:
# store all results for refining the plan
ctx.data["results"] = ctx.data.get("results", {})
for result in results:
ctx.data["results"][result.sub_task.name] = result.result
new_plan = await self.planner.refine_plan(
ctx.data["task"], ctx.data["act_plan_id"], ctx.data["results"]
)
@@ -165,6 +175,18 @@ class StructuredPlannerAgent(Workflow):
# continue executing plan
return ExecutePlanEvent()
def get_upcoming_sub_tasks(self, ctx: Context):
upcoming_sub_tasks = self.planner.state.get_next_sub_tasks(
ctx.data["act_plan_id"]
)
return len(upcoming_sub_tasks)
def get_remaining_subtasks(self, ctx: Context):
remaining_subtasks = self.planner.state.get_remaining_subtasks(
ctx.data["act_plan_id"]
)
return len(remaining_subtasks)
# Concern dealing with creating and refining a plan, extracted from https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/agent/runner/planner.py#L138
class Planner:
+2 -2
View File
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from app.agents.single import FunctionCallingAgent
from app.agents.multi import AgentOrchestrator
from app.examples.researcher import create_researcher
@@ -6,7 +6,7 @@ from app.examples.researcher import create_researcher
from llama_index.core.chat_engine.types import ChatMessage
def create_orchestrator(chat_history: List[ChatMessage]):
def create_orchestrator(chat_history: Optional[List[ChatMessage]] = None):
researcher = create_researcher(chat_history)
writer = FunctionCallingAgent(
name="writer",