feat: add bubbling of events

This commit is contained in:
Marcus Schiesser
2024-09-02 13:17:57 +07:00
parent aac9151880
commit b6201a8db9
8 changed files with 141 additions and 67 deletions
+1 -1
View File
@@ -5,7 +5,7 @@
MODEL_PROVIDER=openai
# The name of the LLM model to use.
MODEL=gpt-4o
MODEL=gpt-4o-mini
# Name of the embedding model to use.
EMBEDDING_MODEL=text-embedding-3-large
+16
View File
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "main.py",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
+48 -39
View File
@@ -1,13 +1,54 @@
import asyncio
from typing import Any, List
from llama_index.core.tools import FunctionTool
from llama_index.core.tools.types import ToolMetadata, ToolOutput
from llama_index.core.tools.utils import create_schema_from_function
from llama_index.core.workflow import Context, Workflow
from app.core.function_call import (
AgentRunResult,
ContextAwareTool,
FunctionCallingAgent,
)
from app.core.planner_agent import StructuredPlannerAgent
from app.core.prefix import PrintPrefix
from app.core.function_call import AgentRunResult, FunctionCallingAgent
import textwrap
class AgentCallTool(ContextAwareTool):
    """Tool that delegates a sub task to another agent and bubbles its events.

    The tool is exposed as ``call_<agent.name>``. While the delegated agent
    runs, every event it streams is written to the calling agent's stream via
    the ``Context`` that is passed to :meth:`acall`.

    Note that ``FunctionTool.__init__`` is not invoked; the tool metadata is
    assigned directly to ``self._metadata``.
    """

    def __init__(self, agent: Workflow) -> None:
        """Create the tool for ``agent``.

        Args:
            agent: The agent to delegate to. Its ``name`` is used for the tool
                name and description; its ``role`` (if truthy) is appended to
                the description.
        """
        self.agent = agent
        name = f"call_{agent.name}"

        async def schema_call(input: str) -> str:
            pass

        # Build the schema from a signature without the Context argument:
        # `ctx` is supplied by the calling agent, only `input` comes from
        # the tool call arguments.
        fn_schema = create_schema_from_function(name, schema_call)
        self._metadata = ToolMetadata(
            name=name,
            description=(
                f"Use this tool to delegate a sub task to the {agent.name} agent."
                + (f" The agent is an {agent.role}." if agent.role else "")
            ),
            fn_schema=fn_schema,
        )

    # overload the acall function with the ctx argument as it's needed for bubbling the events
    async def acall(self, ctx: Context, input: str) -> ToolOutput:
        """Run the delegated agent on ``input`` and return its answer.

        Args:
            ctx: Context of the calling agent; receives all events streamed by
                the delegated agent.
            input: The sub task handed to the delegated agent.

        Returns:
            A ``ToolOutput`` whose ``content`` is the stringified message
            content of the agent's response and whose ``raw_output`` is the
            unmodified message content.

        Raises:
            Whatever the delegated run or the event forwarding raises. In that
            case the delegated run is cancelled before the error propagates.
        """
        # FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
        self.agent._contexts = set()
        task = asyncio.create_task(self.agent.run(input=input))
        try:
            # bubble all events while running the agent to the calling agent
            async for ev in self.agent.stream_events():
                ctx.write_event_to_stream(ev)
            ret: AgentRunResult = await task
        except BaseException:
            # Do not leave the delegated run orphaned when forwarding fails or
            # this call is cancelled; cancel() is a no-op if it already finished.
            task.cancel()
            raise
        response = ret.response.message.content
        return ToolOutput(
            content=str(response),
            tool_name=self.metadata.name,
            raw_input={"args": input, "kwargs": {}},
            raw_output=response,
        )
class AgentCallingAgent(FunctionCallingAgent):
@@ -19,7 +60,7 @@ class AgentCallingAgent(FunctionCallingAgent):
**kwargs: Any,
) -> None:
agents = agents or []
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
tools = [AgentCallTool(agent=agent) for agent in agents]
super().__init__(*args, name=name, tools=tools, **kwargs)
# call add_workflows so agents will get detected by llama agents automatically
self.add_workflows(**{agent.name: agent for agent in agents})
@@ -29,12 +70,12 @@ class AgentOrchestrator(StructuredPlannerAgent):
def __init__(
self,
*args: Any,
agents: List[FunctionCallingAgent] | None = None,
name: str = "orchestrator",
agents: List[FunctionCallingAgent] | None = None,
**kwargs: Any,
) -> None:
agents = agents or []
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
tools = [AgentCallTool(agent=agent) for agent in agents]
super().__init__(
*args,
name=name,
@@ -43,35 +84,3 @@ class AgentOrchestrator(StructuredPlannerAgent):
)
# call add_workflows so agents will get detected by llama agents automatically
self.add_workflows(**{agent.name: agent for agent in agents})
def _create_call_workflow_fn(
    caller: FunctionCallingAgent, agent: FunctionCallingAgent
) -> FunctionTool:
    """Build the ``call_<agent.name>`` tool that lets ``caller`` delegate to ``agent``.

    The returned tool runs ``agent`` on the given input and returns the message
    content of its response. While the agent runs, its streamed events are
    written to the caller's session, unless the caller has more than one
    session, in which case a notice is printed and nothing is forwarded.
    """

    def info(prefix: str, text: str) -> None:
        # Debug printer; only referenced by the disabled calls below.
        shortened = textwrap.shorten(text, width=255, placeholder="...")
        print(f"{prefix}: '{shortened}'")

    async def acall_workflow_fn(input: str) -> str:
        # info(f"[{caller_name}->{agent.name}]", input)
        run_task = asyncio.create_task(agent.run(input=input))
        # bubble all events while running the agent to the calling agent
        if not len(caller._sessions) > 1:
            target_session = next(iter(caller._sessions))
            async for event in agent.stream_events():
                target_session.write_event_to_stream(event)
        else:
            print("XXX: Bubbling events only works with single-session agents")
        result: AgentRunResult = await run_task
        answer = result.response.message.content
        # info(f"[{caller_name}<-{agent.name}]", response)
        return answer

    tool_description = (
        f"Use this tool to delegate a sub task to the {agent.name} agent."
    )
    if agent.role:
        tool_description = tool_description + f" The agent is an {agent.role}."

    return FunctionTool.from_defaults(
        async_fn=acall_workflow_fn,
        name=f"call_{agent.name}",
        description=tool_description,
    )
+37 -5
View File
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import Any, List, Optional
from llama_index.core.llms import ChatMessage, ChatResponse
@@ -6,6 +7,8 @@ from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.settings import Settings
from llama_index.core.tools import ToolOutput, ToolSelection
from llama_index.core.tools.types import BaseTool
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import (
Context,
Event,
@@ -25,11 +28,30 @@ class ToolCallEvent(Event):
tool_calls: list[ToolSelection]
# Event an agent writes to the event stream while it runs: `name` identifies
# the emitting agent and `msg` carries the status text shown to the user.
class AgentRunEvent(Event):
# Name of the agent that emitted the event.
name: str
# Backing attribute read and written by the `msg` property below.
_msg: str
# NOTE(review): callers in this change construct the event with `msg=...`;
# how that keyword becomes readable through this property depends on the
# `Event` base class, which is not visible here -- confirm.
@property
def msg(self) -> str:
return self._msg
@msg.setter
def msg(self, value: str) -> None:
self._msg = value
# Result of an agent run; the agent returns it as the `result` of its StopEvent.
class AgentRunResult(BaseModel):
# Chat response that ended the run; callers read `response.message.content`.
response: ChatResponse
# Tool outputs collected by the agent while it ran.
sources: list[ToolOutput]
# Base class for tools whose `acall` additionally receives the workflow
# `Context` of the agent that invokes them.
class ContextAwareTool(FunctionTool):
# To be implemented by subclasses. The calling agent passes the context by
# keyword: `tool.acall(ctx=ctx, **tool_call.tool_kwargs)`.
@abstractmethod
async def acall(self, ctx: Context, input: Any) -> ToolOutput:
pass
class FunctionCallingAgent(Workflow):
def __init__(
self,
@@ -40,6 +62,7 @@ class FunctionCallingAgent(Workflow):
verbose: bool = False,
timeout: float = 360.0,
name: str,
write_events: bool = True,
role: Optional[str] = None,
**kwargs: Any,
) -> None:
@@ -47,6 +70,7 @@ class FunctionCallingAgent(Workflow):
self.tools = tools or []
self.name = name
self.role = role
self.write_events = write_events
if llm is None:
llm = Settings.llm
@@ -72,9 +96,10 @@ class FunctionCallingAgent(Workflow):
user_input = ev.input
user_msg = ChatMessage(role="user", content=user_input)
self.memory.put(user_msg)
ctx.session.write_event_to_stream(
Event(msg=f"[{self.name}] Start to work on: {user_input}")
)
if self.write_events:
ctx.write_event_to_stream(
AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
)
# get chat history
chat_history = self.memory.get()
@@ -96,7 +121,10 @@ class FunctionCallingAgent(Workflow):
)
if not tool_calls:
ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task"))
if self.write_events:
ctx.write_event_to_stream(
AgentRunEvent(name=self.name, msg="Finished task")
)
return StopEvent(
result=AgentRunResult(response=response, sources=[*self.sources])
)
@@ -128,7 +156,11 @@ class FunctionCallingAgent(Workflow):
continue
try:
tool_output = await tool.acall(**tool_call.tool_kwargs)
if isinstance(tool, ContextAwareTool):
# inject context for calling an context aware tool
tool_output = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
else:
tool_output = await tool.acall(**tool_call.tool_kwargs)
self.sources.append(tool_output)
tool_msgs.append(
ChatMessage(
+19 -9
View File
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, List
from llama_index.core.llms.function_calling import FunctionCallingLLM
@@ -10,7 +11,7 @@ from llama_index.core.workflow import (
step,
)
from app.core.function_call import AgentRunResult, FunctionCallingAgent
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
from app.core.planner import Planner, SubTask, Plan
from llama_index.core.tools import BaseTool
@@ -35,8 +36,7 @@ class PlanEventType(Enum):
REFINED = "refined"
class PlanEvent(Event):
class PlanEvent(AgentRunEvent):
event_type: PlanEventType
plan: Plan
@@ -68,9 +68,11 @@ class StructuredPlannerAgent(Workflow):
name="executor",
llm=llm,
tools=self.tools,
write_events=False,
# it's important to instruct to just return the tool call, otherwise the executor will interpret and change the result
system_prompt="You are an expert in completing given tasks by calling the right tool for the task. Just return the result of the tool call. Don't add any information yourself",
)
self.add_workflows(executor=self.executor)
@step()
async def create_plan(
@@ -80,8 +82,8 @@ class StructuredPlannerAgent(Workflow):
ctx.data["task"] = ev.input
ctx.data["act_plan_id"] = plan_id
# inform about the new plan
ctx.session.write_event_to_stream(
PlanEvent(event_type=PlanEventType.CREATED, plan=plan)
ctx.write_event_to_stream(
PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan)
)
if self._verbose:
print("=== Executing plan ===\n")
@@ -97,7 +99,7 @@ class StructuredPlannerAgent(Workflow):
# send an event per sub task
events = [SubTaskEvent(sub_task=sub_task) for sub_task in upcoming_sub_tasks]
for event in events:
ctx.session.send_event(event)
ctx.send_event(event)
return None
@@ -107,7 +109,13 @@ class StructuredPlannerAgent(Workflow):
) -> SubTaskResultEvent:
if self._verbose:
print(f"=== Executing sub task: {ev.sub_task.name} ===")
result: AgentRunResult = await self.executor.run(input=ev.sub_task.input)
# FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
self.executor._contexts = set()
task = asyncio.create_task(self.executor.run(input=ev.sub_task.input))
# bubble all events while running the executor to the planner
async for event in self.executor.stream_events():
ctx.write_event_to_stream(event)
result: AgentRunResult = await task
if self._verbose:
print("=== Done executing sub task ===\n")
self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task)
@@ -141,8 +149,10 @@ class StructuredPlannerAgent(Workflow):
)
# inform about the new plan
if new_plan is not None:
ctx.session.write_event_to_stream(
PlanEvent(event_type=PlanEventType.REFINED, plan=new_plan)
ctx.write_event_to_stream(
PlanEvent(
name=self.name, event_type=PlanEventType.REFINED, plan=new_plan
)
)
# continue executing plan
+10 -3
View File
@@ -1,9 +1,10 @@
# flake8: noqa: E402
import asyncio
import os
import textwrap
from dotenv import load_dotenv
from app.core.agent_call import AgentCallingAgent, AgentOrchestrator
from app.core.function_call import AgentRunResult, FunctionCallingAgent
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
from app.engine.index import get_index
from app.settings import init_settings
from llama_index.core.tools import QueryEngineTool, ToolMetadata
@@ -80,6 +81,11 @@ def create_orchestrator():
)
def info(prefix: str, text: str) -> None:
    """Print ``text`` on one line, tagged with ``[prefix]``.

    The text is passed through ``textwrap.shorten`` with a width of 255, so
    whitespace is collapsed and overlong text ends in ``...``.
    """
    print("[{}] {}".format(prefix, textwrap.shorten(text, width=255, placeholder="...")))
async def main():
# agent = create_choreography()
agent = create_orchestrator()
@@ -88,10 +94,11 @@ async def main():
)
async for ev in agent.stream_events():
print(ev.msg)
if isinstance(ev, AgentRunEvent):
info(ev.name, ev.msg)
ret: AgentRunResult = await task
print(ret.response.message.content)
print(f"\n\nResult:\n\n{ret.response.message.content}")
if __name__ == "__main__":
Generated
+9 -9
View File
@@ -767,19 +767,19 @@ pydantic = ">=1.10"
[[package]]
name = "llama-index"
version = "0.11.2"
version = "0.11.3"
description = "Interface between LLMs and your data"
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "llama_index-0.11.2-py3-none-any.whl", hash = "sha256:3e70d09a48d8aaf479679c3de0598fe7b3276613a6927a5612fcafb2ecef60f0"},
{file = "llama_index-0.11.2.tar.gz", hash = "sha256:8430b589e372c2b1614da259c4a8e4c2790d9278cd82f3a3b9e19972e8c2d834"},
{file = "llama_index-0.11.3-py3-none-any.whl", hash = "sha256:f307a29b07536bca26cc39f955ec767d10c25becf167d409075bfaba9e2c654f"},
{file = "llama_index-0.11.3.tar.gz", hash = "sha256:4133b06931b3dc15b9a66bf127a06e2dbb19fe991e10afe58b7ef44a14085d29"},
]
[package.dependencies]
llama-index-agent-openai = ">=0.3.0,<0.4.0"
llama-index-cli = ">=0.3.0,<0.4.0"
llama-index-core = ">=0.11.2,<0.12.0"
llama-index-core = ">=0.11.3,<0.12.0"
llama-index-embeddings-openai = ">=0.2.0,<0.3.0"
llama-index-indices-managed-llama-cloud = ">=0.3.0"
llama-index-legacy = ">=0.9.48,<0.10.0"
@@ -825,13 +825,13 @@ llama-index-llms-openai = ">=0.2.0,<0.3.0"
[[package]]
name = "llama-index-core"
version = "0.11.2"
version = "0.11.3"
description = "Interface between LLMs and your data"
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "llama_index_core-0.11.2-py3-none-any.whl", hash = "sha256:6c55667c4943ba197199e21e9b0e4641449f5e5dca662b0c91f5306f8c114e4f"},
{file = "llama_index_core-0.11.2.tar.gz", hash = "sha256:eec37976fe3b1baa3bb31bd3c5f6ea821555c7065ac6a55b71b5601a7e097977"},
{file = "llama_index_core-0.11.3-py3-none-any.whl", hash = "sha256:061aaf3892707bff6a34fb264ccda30b9a920890379e04c3ae9895edddde5e22"},
{file = "llama_index_core-0.11.3.tar.gz", hash = "sha256:6b042e531797ccd755496570d563f3b1d9b42c292d3a311f8c21c8eb6aa57706"},
]
[package.dependencies]
@@ -846,7 +846,7 @@ networkx = ">=3.0"
nltk = ">3.8.1"
numpy = "<2.0.0"
pillow = ">=9.0.0"
pydantic = ">=2.0.0,<3.0.0"
pydantic = ">=2.7.0,<3.0.0"
PyYAML = ">=6.0.1"
requests = ">=2.31.0"
SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
@@ -2261,4 +2261,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "6335838be9094e90081212430ea065f2fe7e953904fa59dc4d6a04c1c3684fe1"
content-hash = "a9b74298a247273a330e73a485b9a50bb2f2ba19b55c74f5b437288af60f4d79"
+1 -1
View File
@@ -12,7 +12,7 @@ generate = "app.engine.generate:generate_datasource"
[tool.poetry.dependencies]
python = "^3.11"
llama-index-agent-openai = ">=0.3.0,<0.4.0"
llama-index = "^0.11.2"
llama-index = "^0.11.3"
[tool.poetry.dependencies.docx2txt]
version = "^0.8"