mirror of
https://github.com/run-llama/multi-agents-workflow.git
synced 2026-06-30 21:27:55 -04:00
feat: add bubbling of events
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
MODEL_PROVIDER=openai
|
||||
|
||||
# The name of LLM model to use.
|
||||
MODEL=gpt-4o
|
||||
MODEL=gpt-4o-mini
|
||||
|
||||
# Name of the embedding model to use.
|
||||
EMBEDDING_MODEL=text-embedding-3-large
|
||||
|
||||
Vendored
+16
@@ -0,0 +1,16 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
}
|
||||
]
|
||||
}
|
||||
+48
-39
@@ -1,13 +1,54 @@
|
||||
import asyncio
|
||||
from typing import Any, List
|
||||
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from llama_index.core.tools.types import ToolMetadata, ToolOutput
|
||||
from llama_index.core.tools.utils import create_schema_from_function
|
||||
from llama_index.core.workflow import Context, Workflow
|
||||
|
||||
from app.core.function_call import (
|
||||
AgentRunResult,
|
||||
ContextAwareTool,
|
||||
FunctionCallingAgent,
|
||||
)
|
||||
from app.core.planner_agent import StructuredPlannerAgent
|
||||
from app.core.prefix import PrintPrefix
|
||||
from app.core.function_call import AgentRunResult, FunctionCallingAgent
|
||||
|
||||
import textwrap
|
||||
|
||||
class AgentCallTool(ContextAwareTool):
|
||||
def __init__(self, agent: Workflow) -> None:
|
||||
self.agent = agent
|
||||
# create the schema without the context
|
||||
name = f"call_{agent.name}"
|
||||
|
||||
async def schema_call(input: str) -> str:
|
||||
pass
|
||||
|
||||
# create the schema without the Context
|
||||
fn_schema = create_schema_from_function(name, schema_call)
|
||||
self._metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=(
|
||||
f"Use this tool to delegate a sub task to the {agent.name} agent."
|
||||
+ (f" The agent is an {agent.role}." if agent.role else "")
|
||||
),
|
||||
fn_schema=fn_schema,
|
||||
)
|
||||
|
||||
# overload the acall function with the ctx argument as it's needed for bubbling the events
|
||||
async def acall(self, ctx: Context, input: str) -> ToolOutput:
|
||||
# FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
|
||||
self.agent._contexts = set()
|
||||
task = asyncio.create_task(self.agent.run(input=input))
|
||||
# bubble all events while running the agent to the calling agent
|
||||
async for ev in self.agent.stream_events():
|
||||
ctx.write_event_to_stream(ev)
|
||||
ret: AgentRunResult = await task
|
||||
response = ret.response.message.content
|
||||
return ToolOutput(
|
||||
content=str(response),
|
||||
tool_name=self.metadata.name,
|
||||
raw_input={"args": input, "kwargs": {}},
|
||||
raw_output=response,
|
||||
)
|
||||
|
||||
|
||||
class AgentCallingAgent(FunctionCallingAgent):
|
||||
@@ -19,7 +60,7 @@ class AgentCallingAgent(FunctionCallingAgent):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
agents = agents or []
|
||||
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
|
||||
tools = [AgentCallTool(agent=agent) for agent in agents]
|
||||
super().__init__(*args, name=name, tools=tools, **kwargs)
|
||||
# call add_workflows so agents will get detected by llama agents automatically
|
||||
self.add_workflows(**{agent.name: agent for agent in agents})
|
||||
@@ -29,12 +70,12 @@ class AgentOrchestrator(StructuredPlannerAgent):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
agents: List[FunctionCallingAgent] | None = None,
|
||||
name: str = "orchestrator",
|
||||
agents: List[FunctionCallingAgent] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
agents = agents or []
|
||||
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
|
||||
tools = [AgentCallTool(agent=agent) for agent in agents]
|
||||
super().__init__(
|
||||
*args,
|
||||
name=name,
|
||||
@@ -43,35 +84,3 @@ class AgentOrchestrator(StructuredPlannerAgent):
|
||||
)
|
||||
# call add_workflows so agents will get detected by llama agents automatically
|
||||
self.add_workflows(**{agent.name: agent for agent in agents})
|
||||
|
||||
|
||||
def _create_call_workflow_fn(
|
||||
caller: FunctionCallingAgent, agent: FunctionCallingAgent
|
||||
) -> FunctionTool:
|
||||
def info(prefix: str, text: str) -> None:
|
||||
truncated = textwrap.shorten(text, width=255, placeholder="...")
|
||||
print(f"{prefix}: '{truncated}'")
|
||||
|
||||
async def acall_workflow_fn(input: str) -> str:
|
||||
# info(f"[{caller_name}->{agent.name}]", input)
|
||||
task = asyncio.create_task(agent.run(input=input))
|
||||
# bubble all events while running the agent to the calling agent
|
||||
if len(caller._sessions) > 1:
|
||||
print("XXX: Bubbling events only works with single-session agents")
|
||||
else:
|
||||
session = next(iter(caller._sessions))
|
||||
async for ev in agent.stream_events():
|
||||
session.write_event_to_stream(ev)
|
||||
ret: AgentRunResult = await task
|
||||
response = ret.response.message.content
|
||||
# info(f"[{caller_name}<-{agent.name}]", response)
|
||||
return response
|
||||
|
||||
return FunctionTool.from_defaults(
|
||||
async_fn=acall_workflow_fn,
|
||||
name=f"call_{agent.name}",
|
||||
description=(
|
||||
f"Use this tool to delegate a sub task to the {agent.name} agent."
|
||||
+ (f" The agent is an {agent.role}." if agent.role else "")
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core.llms import ChatMessage, ChatResponse
|
||||
@@ -6,6 +7,8 @@ from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.tools import ToolOutput, ToolSelection
|
||||
from llama_index.core.tools.types import BaseTool
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
from llama_index.core.workflow import (
|
||||
Context,
|
||||
Event,
|
||||
@@ -25,11 +28,30 @@ class ToolCallEvent(Event):
|
||||
tool_calls: list[ToolSelection]
|
||||
|
||||
|
||||
class AgentRunEvent(Event):
|
||||
name: str
|
||||
_msg: str
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return self._msg
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value):
|
||||
self._msg = value
|
||||
|
||||
|
||||
class AgentRunResult(BaseModel):
|
||||
response: ChatResponse
|
||||
sources: list[ToolOutput]
|
||||
|
||||
|
||||
class ContextAwareTool(FunctionTool):
|
||||
@abstractmethod
|
||||
async def acall(self, ctx: Context, input: Any) -> ToolOutput:
|
||||
pass
|
||||
|
||||
|
||||
class FunctionCallingAgent(Workflow):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -40,6 +62,7 @@ class FunctionCallingAgent(Workflow):
|
||||
verbose: bool = False,
|
||||
timeout: float = 360.0,
|
||||
name: str,
|
||||
write_events: bool = True,
|
||||
role: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -47,6 +70,7 @@ class FunctionCallingAgent(Workflow):
|
||||
self.tools = tools or []
|
||||
self.name = name
|
||||
self.role = role
|
||||
self.write_events = write_events
|
||||
|
||||
if llm is None:
|
||||
llm = Settings.llm
|
||||
@@ -72,9 +96,10 @@ class FunctionCallingAgent(Workflow):
|
||||
user_input = ev.input
|
||||
user_msg = ChatMessage(role="user", content=user_input)
|
||||
self.memory.put(user_msg)
|
||||
ctx.session.write_event_to_stream(
|
||||
Event(msg=f"[{self.name}] Start to work on: {user_input}")
|
||||
)
|
||||
if self.write_events:
|
||||
ctx.write_event_to_stream(
|
||||
AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
|
||||
)
|
||||
|
||||
# get chat history
|
||||
chat_history = self.memory.get()
|
||||
@@ -96,7 +121,10 @@ class FunctionCallingAgent(Workflow):
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task"))
|
||||
if self.write_events:
|
||||
ctx.write_event_to_stream(
|
||||
AgentRunEvent(name=self.name, msg="Finished task")
|
||||
)
|
||||
return StopEvent(
|
||||
result=AgentRunResult(response=response, sources=[*self.sources])
|
||||
)
|
||||
@@ -128,7 +156,11 @@ class FunctionCallingAgent(Workflow):
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_output = await tool.acall(**tool_call.tool_kwargs)
|
||||
if isinstance(tool, ContextAwareTool):
|
||||
# inject context for calling an context aware tool
|
||||
tool_output = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
|
||||
else:
|
||||
tool_output = await tool.acall(**tool_call.tool_kwargs)
|
||||
self.sources.append(tool_output)
|
||||
tool_msgs.append(
|
||||
ChatMessage(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, List
|
||||
|
||||
from llama_index.core.llms.function_calling import FunctionCallingLLM
|
||||
@@ -10,7 +11,7 @@ from llama_index.core.workflow import (
|
||||
step,
|
||||
)
|
||||
|
||||
from app.core.function_call import AgentRunResult, FunctionCallingAgent
|
||||
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
|
||||
from app.core.planner import Planner, SubTask, Plan
|
||||
from llama_index.core.tools import BaseTool
|
||||
|
||||
@@ -35,8 +36,7 @@ class PlanEventType(Enum):
|
||||
REFINED = "refined"
|
||||
|
||||
|
||||
class PlanEvent(Event):
|
||||
|
||||
class PlanEvent(AgentRunEvent):
|
||||
event_type: PlanEventType
|
||||
plan: Plan
|
||||
|
||||
@@ -68,9 +68,11 @@ class StructuredPlannerAgent(Workflow):
|
||||
name="executor",
|
||||
llm=llm,
|
||||
tools=self.tools,
|
||||
write_events=False,
|
||||
# it's important to instruct to just return the tool call, otherwise the executor will interpret and change the result
|
||||
system_prompt="You are an expert in completing given tasks by calling the right tool for the task. Just return the result of the tool call. Don't add any information yourself",
|
||||
)
|
||||
self.add_workflows(executor=self.executor)
|
||||
|
||||
@step()
|
||||
async def create_plan(
|
||||
@@ -80,8 +82,8 @@ class StructuredPlannerAgent(Workflow):
|
||||
ctx.data["task"] = ev.input
|
||||
ctx.data["act_plan_id"] = plan_id
|
||||
# inform about the new plan
|
||||
ctx.session.write_event_to_stream(
|
||||
PlanEvent(event_type=PlanEventType.CREATED, plan=plan)
|
||||
ctx.write_event_to_stream(
|
||||
PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan)
|
||||
)
|
||||
if self._verbose:
|
||||
print("=== Executing plan ===\n")
|
||||
@@ -97,7 +99,7 @@ class StructuredPlannerAgent(Workflow):
|
||||
# send an event per sub task
|
||||
events = [SubTaskEvent(sub_task=sub_task) for sub_task in upcoming_sub_tasks]
|
||||
for event in events:
|
||||
ctx.session.send_event(event)
|
||||
ctx.send_event(event)
|
||||
|
||||
return None
|
||||
|
||||
@@ -107,7 +109,13 @@ class StructuredPlannerAgent(Workflow):
|
||||
) -> SubTaskResultEvent:
|
||||
if self._verbose:
|
||||
print(f"=== Executing sub task: {ev.sub_task.name} ===")
|
||||
result: AgentRunResult = await self.executor.run(input=ev.sub_task.input)
|
||||
# FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
|
||||
self.executor._contexts = set()
|
||||
task = asyncio.create_task(self.executor.run(input=ev.sub_task.input))
|
||||
# bubble all events while running the executor to the planner
|
||||
async for event in self.executor.stream_events():
|
||||
ctx.write_event_to_stream(event)
|
||||
result: AgentRunResult = await task
|
||||
if self._verbose:
|
||||
print("=== Done executing sub task ===\n")
|
||||
self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task)
|
||||
@@ -141,8 +149,10 @@ class StructuredPlannerAgent(Workflow):
|
||||
)
|
||||
# inform about the new plan
|
||||
if new_plan is not None:
|
||||
ctx.session.write_event_to_stream(
|
||||
PlanEvent(event_type=PlanEventType.REFINED, plan=new_plan)
|
||||
ctx.write_event_to_stream(
|
||||
PlanEvent(
|
||||
name=self.name, event_type=PlanEventType.REFINED, plan=new_plan
|
||||
)
|
||||
)
|
||||
|
||||
# continue executing plan
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# flake8: noqa: E402
|
||||
import asyncio
|
||||
import os
|
||||
import textwrap
|
||||
from dotenv import load_dotenv
|
||||
from app.core.agent_call import AgentCallingAgent, AgentOrchestrator
|
||||
from app.core.function_call import AgentRunResult, FunctionCallingAgent
|
||||
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
|
||||
from app.engine.index import get_index
|
||||
from app.settings import init_settings
|
||||
from llama_index.core.tools import QueryEngineTool, ToolMetadata
|
||||
@@ -80,6 +81,11 @@ def create_orchestrator():
|
||||
)
|
||||
|
||||
|
||||
def info(prefix: str, text: str) -> None:
|
||||
truncated = textwrap.shorten(text, width=255, placeholder="...")
|
||||
print(f"[{prefix}] {truncated}")
|
||||
|
||||
|
||||
async def main():
|
||||
# agent = create_choreography()
|
||||
agent = create_orchestrator()
|
||||
@@ -88,10 +94,11 @@ async def main():
|
||||
)
|
||||
|
||||
async for ev in agent.stream_events():
|
||||
print(ev.msg)
|
||||
if isinstance(ev, AgentRunEvent):
|
||||
info(ev.name, ev.msg)
|
||||
|
||||
ret: AgentRunResult = await task
|
||||
print(ret.response.message.content)
|
||||
print(f"\n\nResult:\n\n{ret.response.message.content}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Generated
+9
-9
@@ -767,19 +767,19 @@ pydantic = ">=1.10"
|
||||
|
||||
[[package]]
|
||||
name = "llama-index"
|
||||
version = "0.11.2"
|
||||
version = "0.11.3"
|
||||
description = "Interface between LLMs and your data"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "llama_index-0.11.2-py3-none-any.whl", hash = "sha256:3e70d09a48d8aaf479679c3de0598fe7b3276613a6927a5612fcafb2ecef60f0"},
|
||||
{file = "llama_index-0.11.2.tar.gz", hash = "sha256:8430b589e372c2b1614da259c4a8e4c2790d9278cd82f3a3b9e19972e8c2d834"},
|
||||
{file = "llama_index-0.11.3-py3-none-any.whl", hash = "sha256:f307a29b07536bca26cc39f955ec767d10c25becf167d409075bfaba9e2c654f"},
|
||||
{file = "llama_index-0.11.3.tar.gz", hash = "sha256:4133b06931b3dc15b9a66bf127a06e2dbb19fe991e10afe58b7ef44a14085d29"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
llama-index-agent-openai = ">=0.3.0,<0.4.0"
|
||||
llama-index-cli = ">=0.3.0,<0.4.0"
|
||||
llama-index-core = ">=0.11.2,<0.12.0"
|
||||
llama-index-core = ">=0.11.3,<0.12.0"
|
||||
llama-index-embeddings-openai = ">=0.2.0,<0.3.0"
|
||||
llama-index-indices-managed-llama-cloud = ">=0.3.0"
|
||||
llama-index-legacy = ">=0.9.48,<0.10.0"
|
||||
@@ -825,13 +825,13 @@ llama-index-llms-openai = ">=0.2.0,<0.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "llama-index-core"
|
||||
version = "0.11.2"
|
||||
version = "0.11.3"
|
||||
description = "Interface between LLMs and your data"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "llama_index_core-0.11.2-py3-none-any.whl", hash = "sha256:6c55667c4943ba197199e21e9b0e4641449f5e5dca662b0c91f5306f8c114e4f"},
|
||||
{file = "llama_index_core-0.11.2.tar.gz", hash = "sha256:eec37976fe3b1baa3bb31bd3c5f6ea821555c7065ac6a55b71b5601a7e097977"},
|
||||
{file = "llama_index_core-0.11.3-py3-none-any.whl", hash = "sha256:061aaf3892707bff6a34fb264ccda30b9a920890379e04c3ae9895edddde5e22"},
|
||||
{file = "llama_index_core-0.11.3.tar.gz", hash = "sha256:6b042e531797ccd755496570d563f3b1d9b42c292d3a311f8c21c8eb6aa57706"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -846,7 +846,7 @@ networkx = ">=3.0"
|
||||
nltk = ">3.8.1"
|
||||
numpy = "<2.0.0"
|
||||
pillow = ">=9.0.0"
|
||||
pydantic = ">=2.0.0,<3.0.0"
|
||||
pydantic = ">=2.7.0,<3.0.0"
|
||||
PyYAML = ">=6.0.1"
|
||||
requests = ">=2.31.0"
|
||||
SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
|
||||
@@ -2261,4 +2261,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "6335838be9094e90081212430ea065f2fe7e953904fa59dc4d6a04c1c3684fe1"
|
||||
content-hash = "a9b74298a247273a330e73a485b9a50bb2f2ba19b55c74f5b437288af60f4d79"
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@ generate = "app.engine.generate:generate_datasource"
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
llama-index-agent-openai = ">=0.3.0,<0.4.0"
|
||||
llama-index = "^0.11.2"
|
||||
llama-index = "^0.11.3"
|
||||
|
||||
[tool.poetry.dependencies.docx2txt]
|
||||
version = "^0.8"
|
||||
|
||||
Reference in New Issue
Block a user