From b6201a8db93b3b3152d58b1c94da3215fe7e4bb9 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser Date: Mon, 2 Sep 2024 13:17:57 +0700 Subject: [PATCH] feat: add bubbling of events --- .env | 2 +- .vscode/launch.json | 16 +++++++ app/core/agent_call.py | 87 +++++++++++++++++++++------------------ app/core/function_call.py | 42 ++++++++++++++++--- app/core/planner_agent.py | 28 +++++++++---- main.py | 13 ++++-- poetry.lock | 18 ++++---- pyproject.toml | 2 +- 8 files changed, 141 insertions(+), 67 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.env b/.env index fd60231..63eb2da 100644 --- a/.env +++ b/.env @@ -5,7 +5,7 @@ MODEL_PROVIDER=openai # The name of LLM model to use. -MODEL=gpt-4o +MODEL=gpt-4o-mini # Name of the embedding model to use. EMBEDDING_MODEL=text-embedding-3-large diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..21fef9a --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "main.py", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/app/core/agent_call.py b/app/core/agent_call.py index 4bd356f..44f99bf 100644 --- a/app/core/agent_call.py +++ b/app/core/agent_call.py @@ -1,13 +1,54 @@ import asyncio from typing import Any, List -from llama_index.core.tools import FunctionTool +from llama_index.core.tools.types import ToolMetadata, ToolOutput +from llama_index.core.tools.utils import create_schema_from_function +from llama_index.core.workflow import Context, Workflow +from app.core.function_call import ( + AgentRunResult, + ContextAwareTool, + FunctionCallingAgent, +) from app.core.planner_agent import StructuredPlannerAgent -from app.core.prefix import PrintPrefix -from app.core.function_call import AgentRunResult, FunctionCallingAgent -import textwrap + +class AgentCallTool(ContextAwareTool): + def __init__(self, agent: Workflow) -> None: + self.agent = agent + # create the schema without the context + name = f"call_{agent.name}" + + async def schema_call(input: str) -> str: + pass + + # create the schema without the Context + fn_schema = create_schema_from_function(name, schema_call) + self._metadata = ToolMetadata( + name=name, + description=( + f"Use this tool to delegate a sub task to the {agent.name} agent." + + (f" The agent is an {agent.role}." if agent.role else "") + ), + fn_schema=fn_schema, + ) + + # overload the acall function with the ctx argument as it's needed for bubbling the events + async def acall(self, ctx: Context, input: str) -> ToolOutput: + # FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776 + self.agent._contexts = set() + task = asyncio.create_task(self.agent.run(input=input)) + # bubble all events while running the agent to the calling agent + async for ev in self.agent.stream_events(): + ctx.write_event_to_stream(ev) + ret: AgentRunResult = await task + response = ret.response.message.content + return ToolOutput( + content=str(response), + tool_name=self.metadata.name, + raw_input={"args": input, "kwargs": {}}, + raw_output=response, + ) class AgentCallingAgent(FunctionCallingAgent): @@ -19,7 +60,7 @@ class AgentCallingAgent(FunctionCallingAgent): **kwargs: Any, ) -> None: agents = agents or [] - tools = [_create_call_workflow_fn(self, agent) for agent in agents] + tools = [AgentCallTool(agent=agent) for agent in agents] super().__init__(*args, name=name, tools=tools, **kwargs) # call add_workflows so agents will get detected by llama agents automatically self.add_workflows(**{agent.name: agent for agent in agents}) @@ -29,12 +70,12 @@ class AgentOrchestrator(StructuredPlannerAgent): def __init__( self, *args: Any, - agents: List[FunctionCallingAgent] | None = None, name: str = "orchestrator", + agents: List[FunctionCallingAgent] | None = None, **kwargs: Any, ) -> None: agents = agents or [] - tools = [_create_call_workflow_fn(self, agent) for agent in agents] + tools = [AgentCallTool(agent=agent) for agent in agents] super().__init__( *args, name=name, @@ -43,35 +84,3 @@ class AgentOrchestrator(StructuredPlannerAgent): ) # call add_workflows so agents will get detected by llama agents automatically self.add_workflows(**{agent.name: agent for agent in agents}) - - -def _create_call_workflow_fn( - caller: FunctionCallingAgent, agent: FunctionCallingAgent -) -> FunctionTool: - def info(prefix: str, text: str) -> None: - truncated = textwrap.shorten(text, width=255, placeholder="...") - print(f"{prefix}: '{truncated}'") - - async def acall_workflow_fn(input: str) -> str: - # info(f"[{caller_name}->{agent.name}]", input) - task = asyncio.create_task(agent.run(input=input)) - # bubble all events while running the agent to the calling agent - if len(caller._sessions) > 1: - print("XXX: Bubbling events only works with single-session agents") - else: - session = next(iter(caller._sessions)) - async for ev in agent.stream_events(): - session.write_event_to_stream(ev) - ret: AgentRunResult = await task - response = ret.response.message.content - # info(f"[{caller_name}<-{agent.name}]", response) - return response - - return FunctionTool.from_defaults( - async_fn=acall_workflow_fn, - name=f"call_{agent.name}", - description=( - f"Use this tool to delegate a sub task to the {agent.name} agent." - + (f" The agent is an {agent.role}." if agent.role else "") - ), - ) diff --git a/app/core/function_call.py b/app/core/function_call.py index e7256ad..9c4384b 100644 --- a/app/core/function_call.py +++ b/app/core/function_call.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Any, List, Optional from llama_index.core.llms import ChatMessage, ChatResponse @@ -6,6 +7,8 @@ from llama_index.core.memory import ChatMemoryBuffer from llama_index.core.settings import Settings from llama_index.core.tools import ToolOutput, ToolSelection from llama_index.core.tools.types import BaseTool +from llama_index.core.tools import FunctionTool + from llama_index.core.workflow import ( Context, Event, @@ -25,11 +28,30 @@ class ToolCallEvent(Event): tool_calls: list[ToolSelection] +class AgentRunEvent(Event): + name: str + _msg: str + + @property + def msg(self): + return self._msg + + @msg.setter + def msg(self, value): + self._msg = value + + class AgentRunResult(BaseModel): response: ChatResponse sources: list[ToolOutput] +class ContextAwareTool(FunctionTool): + @abstractmethod + async def acall(self, ctx: Context, input: Any) -> ToolOutput: + pass + + class FunctionCallingAgent(Workflow): def __init__( self, @@ -40,6 +62,7 @@ class FunctionCallingAgent(Workflow): verbose: bool = False, timeout: float = 360.0, name: str, + write_events: bool = True, role: Optional[str] = None, **kwargs: Any, ) -> None: @@ -47,6 +70,7 @@ class FunctionCallingAgent(Workflow): self.tools = tools or [] self.name = name self.role = role + self.write_events = write_events if llm is None: llm = Settings.llm @@ -72,9 +96,10 @@ class FunctionCallingAgent(Workflow): user_input = ev.input user_msg = ChatMessage(role="user", content=user_input) self.memory.put(user_msg) - ctx.session.write_event_to_stream( - Event(msg=f"[{self.name}] Start to work on: {user_input}") - ) + if self.write_events: + ctx.write_event_to_stream( + AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}") + ) # get chat history chat_history = self.memory.get() @@ -96,7 +121,10 @@ class FunctionCallingAgent(Workflow): ) if not tool_calls: - ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task")) + if self.write_events: + ctx.write_event_to_stream( + AgentRunEvent(name=self.name, msg="Finished task") + ) return StopEvent( result=AgentRunResult(response=response, sources=[*self.sources]) ) @@ -128,7 +156,11 @@ class FunctionCallingAgent(Workflow): continue try: - tool_output = await tool.acall(**tool_call.tool_kwargs) + if isinstance(tool, ContextAwareTool): + # inject context for calling an context aware tool + tool_output = await tool.acall(ctx=ctx, **tool_call.tool_kwargs) + else: + tool_output = await tool.acall(**tool_call.tool_kwargs) self.sources.append(tool_output) tool_msgs.append( ChatMessage( diff --git a/app/core/planner_agent.py b/app/core/planner_agent.py index a1d0333..15a00c0 100644 --- a/app/core/planner_agent.py +++ b/app/core/planner_agent.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, List from llama_index.core.llms.function_calling import FunctionCallingLLM @@ -10,7 +11,7 @@ from llama_index.core.workflow import ( step, ) -from app.core.function_call import AgentRunResult, FunctionCallingAgent +from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent from app.core.planner import Planner, SubTask, Plan from llama_index.core.tools import BaseTool @@ -35,8 +36,7 @@ class PlanEventType(Enum): REFINED = "refined" -class PlanEvent(Event): - +class PlanEvent(AgentRunEvent): event_type: PlanEventType plan: Plan @@ -68,9 +68,11 @@ class StructuredPlannerAgent(Workflow): name="executor", llm=llm, tools=self.tools, + write_events=False, # it's important to instruct to just return the tool call, otherwise the executor will interpret and change the result system_prompt="You are an expert in completing given tasks by calling the right tool for the task. Just return the result of the tool call. Don't add any information yourself", ) + self.add_workflows(executor=self.executor) @step() async def create_plan( @@ -80,8 +82,8 @@ class StructuredPlannerAgent(Workflow): ctx.data["task"] = ev.input ctx.data["act_plan_id"] = plan_id # inform about the new plan - ctx.session.write_event_to_stream( - PlanEvent(event_type=PlanEventType.CREATED, plan=plan) + ctx.write_event_to_stream( + PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan) ) if self._verbose: print("=== Executing plan ===\n") @@ -97,7 +99,7 @@ class StructuredPlannerAgent(Workflow): # send an event per sub task events = [SubTaskEvent(sub_task=sub_task) for sub_task in upcoming_sub_tasks] for event in events: - ctx.session.send_event(event) + ctx.send_event(event) return None @@ -107,7 +109,13 @@ class StructuredPlannerAgent(Workflow): ) -> SubTaskResultEvent: if self._verbose: print(f"=== Executing sub task: {ev.sub_task.name} ===") - result: AgentRunResult = await self.executor.run(input=ev.sub_task.input) + # FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776 + self.executor._contexts = set() + task = asyncio.create_task(self.executor.run(input=ev.sub_task.input)) + # bubble all events while running the executor to the planner + async for event in self.executor.stream_events(): + ctx.write_event_to_stream(event) + result: AgentRunResult = await task if self._verbose: print("=== Done executing sub task ===\n") self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task) @@ -141,8 +149,10 @@ class StructuredPlannerAgent(Workflow): ) # inform about the new plan if new_plan is not None: - ctx.session.write_event_to_stream( - PlanEvent(event_type=PlanEventType.REFINED, plan=new_plan) + ctx.write_event_to_stream( + PlanEvent( + name=self.name, event_type=PlanEventType.REFINED, plan=new_plan + ) ) # continue executing plan diff --git a/main.py b/main.py index f7c5677..39f9a84 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,10 @@ # flake8: noqa: E402 import asyncio import os +import textwrap from dotenv import load_dotenv from app.core.agent_call import AgentCallingAgent, AgentOrchestrator -from app.core.function_call import AgentRunResult, FunctionCallingAgent +from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent from app.engine.index import get_index from app.settings import init_settings from llama_index.core.tools import QueryEngineTool, ToolMetadata @@ -80,6 +81,11 @@ def create_orchestrator(): ) +def info(prefix: str, text: str) -> None: + truncated = textwrap.shorten(text, width=255, placeholder="...") + print(f"[{prefix}] {truncated}") + + async def main(): # agent = create_choreography() agent = create_orchestrator() @@ -88,10 +94,11 @@ async def main(): ) async for ev in agent.stream_events(): - print(ev.msg) + if isinstance(ev, AgentRunEvent): + info(ev.name, ev.msg) ret: AgentRunResult = await task - print(ret.response.message.content) + print(f"\n\nResult:\n\n{ret.response.message.content}") if __name__ == "__main__": diff --git a/poetry.lock b/poetry.lock index 92ae1cc..22872f1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -767,19 +767,19 @@ pydantic = ">=1.10" [[package]] name = "llama-index" -version = "0.11.2" +version = "0.11.3" description = "Interface between LLMs and your data" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "llama_index-0.11.2-py3-none-any.whl", hash = "sha256:3e70d09a48d8aaf479679c3de0598fe7b3276613a6927a5612fcafb2ecef60f0"}, - {file = "llama_index-0.11.2.tar.gz", hash = "sha256:8430b589e372c2b1614da259c4a8e4c2790d9278cd82f3a3b9e19972e8c2d834"}, + {file = "llama_index-0.11.3-py3-none-any.whl", hash = "sha256:f307a29b07536bca26cc39f955ec767d10c25becf167d409075bfaba9e2c654f"}, + {file = "llama_index-0.11.3.tar.gz", hash = "sha256:4133b06931b3dc15b9a66bf127a06e2dbb19fe991e10afe58b7ef44a14085d29"}, ] [package.dependencies] llama-index-agent-openai = ">=0.3.0,<0.4.0" llama-index-cli = ">=0.3.0,<0.4.0" -llama-index-core = ">=0.11.2,<0.12.0" +llama-index-core = ">=0.11.3,<0.12.0" llama-index-embeddings-openai = ">=0.2.0,<0.3.0" llama-index-indices-managed-llama-cloud = ">=0.3.0" llama-index-legacy = ">=0.9.48,<0.10.0" @@ -825,13 +825,13 @@ llama-index-llms-openai = ">=0.2.0,<0.3.0" [[package]] name = "llama-index-core" -version = "0.11.2" +version = "0.11.3" description = "Interface between LLMs and your data" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "llama_index_core-0.11.2-py3-none-any.whl", hash = "sha256:6c55667c4943ba197199e21e9b0e4641449f5e5dca662b0c91f5306f8c114e4f"}, - {file = "llama_index_core-0.11.2.tar.gz", hash = "sha256:eec37976fe3b1baa3bb31bd3c5f6ea821555c7065ac6a55b71b5601a7e097977"}, + {file = "llama_index_core-0.11.3-py3-none-any.whl", hash = "sha256:061aaf3892707bff6a34fb264ccda30b9a920890379e04c3ae9895edddde5e22"}, + {file = "llama_index_core-0.11.3.tar.gz", hash = "sha256:6b042e531797ccd755496570d563f3b1d9b42c292d3a311f8c21c8eb6aa57706"}, ] [package.dependencies] @@ -846,7 +846,7 @@ networkx = ">=3.0" nltk = ">3.8.1" numpy = "<2.0.0" pillow = ">=9.0.0" -pydantic = ">=2.0.0,<3.0.0" +pydantic = ">=2.7.0,<3.0.0" PyYAML = ">=6.0.1" requests = ">=2.31.0" SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} @@ -2261,4 +2261,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "6335838be9094e90081212430ea065f2fe7e953904fa59dc4d6a04c1c3684fe1" +content-hash = "a9b74298a247273a330e73a485b9a50bb2f2ba19b55c74f5b437288af60f4d79" diff --git a/pyproject.toml b/pyproject.toml index 00b9ac5..70b7b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ generate = "app.engine.generate:generate_datasource" [tool.poetry.dependencies] python = "^3.11" llama-index-agent-openai = ">=0.3.0,<0.4.0" -llama-index = "^0.11.2" +llama-index = "^0.11.3" [tool.poetry.dependencies.docx2txt] version = "^0.8"