mirror of
https://github.com/run-llama/multi-agents-workflow.git
synced 2026-06-30 21:27:55 -04:00
feat: add blog post multi-agent
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
**/__pycache__
|
||||
storage
|
||||
@@ -1,33 +0,0 @@
|
||||
from llama_agents import AgentService, SimpleMessageQueue
|
||||
from llama_index.core.agent import FunctionCallingAgentWorker
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from llama_index.core.settings import Settings
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
DEFAULT_DUMMY_AGENT_DESCRIPTION = "I'm a dummy agent which does nothing."
|
||||
|
||||
|
||||
def dummy_function():
|
||||
"""
|
||||
This function does nothing.
|
||||
"""
|
||||
return ""
|
||||
|
||||
|
||||
def init_dummy_agent(message_queue: SimpleMessageQueue) -> AgentService:
|
||||
agent = FunctionCallingAgentWorker(
|
||||
tools=[FunctionTool.from_defaults(fn=dummy_function)],
|
||||
llm=Settings.llm,
|
||||
prefix_messages=[],
|
||||
).as_agent()
|
||||
|
||||
return AgentService(
|
||||
service_name="dummy_agent",
|
||||
agent=agent,
|
||||
message_queue=message_queue.client,
|
||||
description=load_from_env("AGENT_DUMMY_DESCRIPTION", throw_error=False)
|
||||
or DEFAULT_DUMMY_AGENT_DESCRIPTION,
|
||||
host=load_from_env("AGENT_DUMMY_HOST", throw_error=False) or "127.0.0.1",
|
||||
port=int(load_from_env("AGENT_DUMMY_PORT")),
|
||||
)
|
||||
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
from llama_agents import AgentService, SimpleMessageQueue
|
||||
from llama_index.core.agent import FunctionCallingAgentWorker
|
||||
from llama_index.core.tools import QueryEngineTool, ToolMetadata
|
||||
from llama_index.core.settings import Settings
|
||||
from app.engine.index import get_index
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION = (
|
||||
"Used to answer the questions using the provided context data."
|
||||
)
|
||||
|
||||
|
||||
def get_query_engine_tool() -> QueryEngineTool:
|
||||
"""
|
||||
Provide an agent worker that can be used to query the index.
|
||||
"""
|
||||
index = get_index()
|
||||
if index is None:
|
||||
raise ValueError("Index not found. Please create an index first.")
|
||||
query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3)))
|
||||
return QueryEngineTool(
|
||||
query_engine=query_engine,
|
||||
metadata=ToolMetadata(
|
||||
name="context_data",
|
||||
description="""
|
||||
Provide the provided context information.
|
||||
Use a detailed plain text question as input to the tool.
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def init_query_engine_agent(
|
||||
message_queue: SimpleMessageQueue,
|
||||
) -> AgentService:
|
||||
"""
|
||||
Initialize the agent service.
|
||||
"""
|
||||
agent = FunctionCallingAgentWorker(
|
||||
tools=[get_query_engine_tool()], llm=Settings.llm, prefix_messages=[]
|
||||
).as_agent()
|
||||
return AgentService(
|
||||
service_name="context_query_agent",
|
||||
agent=agent,
|
||||
message_queue=message_queue.client,
|
||||
description=load_from_env("AGENT_QUERY_ENGINE_DESCRIPTION", throw_error=False)
|
||||
or DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION,
|
||||
host=load_from_env("AGENT_QUERY_ENGINE_HOST", throw_error=False) or "127.0.0.1",
|
||||
port=int(load_from_env("AGENT_QUERY_ENGINE_PORT")),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
DATA_DIR = "data"
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import Any, List
|
||||
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from llama_index.core.workflow import Workflow
|
||||
|
||||
from app.core.prefix import PrintPrefix
|
||||
from app.core.function_call import FunctionCallingAgent
|
||||
|
||||
|
||||
def create_call_workflow_fn(agent: Workflow) -> FunctionTool:
|
||||
def call_workflow_fn(input: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def acall_workflow_fn(input: str) -> str:
|
||||
print(f"Calling agent {agent.name} with input: '{input}'")
|
||||
with PrintPrefix(f"[{agent.name}]"):
|
||||
ret = await agent.run(input=input)
|
||||
print(f"Finished calling agent {agent.name}")
|
||||
return ret["response"]
|
||||
|
||||
return FunctionTool.from_defaults(
|
||||
fn=call_workflow_fn, # not necessary with https://github.com/run-llama/llama_index/pull/15638/files
|
||||
async_fn=acall_workflow_fn,
|
||||
name=f"call_{agent.name}",
|
||||
description=f"Use this tool to delegate a task to the agent {agent.name}",
|
||||
)
|
||||
|
||||
|
||||
class AgentCallingAgent(FunctionCallingAgent):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
agents: List[Workflow] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
agents = agents or []
|
||||
tools = [create_call_workflow_fn(agent) for agent in agents]
|
||||
super().__init__(*args, tools=tools, **kwargs)
|
||||
@@ -1,19 +0,0 @@
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_agents import AgentOrchestrator, ControlPlaneServer
|
||||
from app.core.message_queue import message_queue
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
control_plane_host = (
|
||||
load_from_env("CONTROL_PLANE_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
control_plane_port = load_from_env("CONTROL_PLANE_PORT", throw_error=False) or "8001"
|
||||
|
||||
|
||||
# setup control plane
|
||||
control_plane = ControlPlaneServer(
|
||||
message_queue=message_queue,
|
||||
orchestrator=AgentOrchestrator(llm=OpenAI()),
|
||||
host=control_plane_host,
|
||||
port=int(control_plane_port) if control_plane_port else None,
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
from typing import Any, List
|
||||
|
||||
from llama_index.core.llms.function_calling import FunctionCallingLLM
|
||||
from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.tools.types import BaseTool
|
||||
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.tools import ToolSelection, ToolOutput
|
||||
from llama_index.core.workflow import Event
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
class InputEvent(Event):
|
||||
input: list[ChatMessage]
|
||||
|
||||
|
||||
class ToolCallEvent(Event):
|
||||
tool_calls: list[ToolSelection]
|
||||
|
||||
|
||||
class FunctionOutputEvent(Event):
|
||||
output: ToolOutput
|
||||
|
||||
|
||||
class FunctionCallingAgent(Workflow):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
llm: FunctionCallingLLM | None = None,
|
||||
tools: List[BaseTool] | None = None,
|
||||
system_prompt: str | None = None,
|
||||
verbose: bool = True,
|
||||
timeout: float = 120.0,
|
||||
name: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)
|
||||
self.tools = tools or []
|
||||
self.name = name
|
||||
|
||||
if llm is None:
|
||||
llm = Settings.llm
|
||||
self.llm = llm
|
||||
assert self.llm.metadata.is_function_calling_model
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
|
||||
self.sources = []
|
||||
|
||||
@step()
|
||||
async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
|
||||
# clear sources
|
||||
self.sources = []
|
||||
|
||||
# set system prompt
|
||||
if self.system_prompt is not None:
|
||||
system_msg = ChatMessage(role="system", content=self.system_prompt)
|
||||
self.memory.put(system_msg)
|
||||
|
||||
# get user input
|
||||
user_input = ev.input
|
||||
user_msg = ChatMessage(role="user", content=user_input)
|
||||
self.memory.put(user_msg)
|
||||
|
||||
# get chat history
|
||||
chat_history = self.memory.get()
|
||||
return InputEvent(input=chat_history)
|
||||
|
||||
@step()
|
||||
async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
|
||||
chat_history = ev.input
|
||||
|
||||
response = await self.llm.achat_with_tools(
|
||||
self.tools, chat_history=chat_history
|
||||
)
|
||||
self.memory.put(response.message)
|
||||
|
||||
tool_calls = self.llm.get_tool_calls_from_response(
|
||||
response, error_on_no_tool_call=False
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return StopEvent(result={"response": response, "sources": [*self.sources]})
|
||||
else:
|
||||
return ToolCallEvent(tool_calls=tool_calls)
|
||||
|
||||
@step()
|
||||
async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:
|
||||
tool_calls = ev.tool_calls
|
||||
tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
|
||||
|
||||
tool_msgs = []
|
||||
|
||||
# call tools -- safely!
|
||||
for tool_call in tool_calls:
|
||||
tool = tools_by_name.get(tool_call.tool_name)
|
||||
additional_kwargs = {
|
||||
"tool_call_id": tool_call.tool_id,
|
||||
"name": tool.metadata.get_name(),
|
||||
}
|
||||
if not tool:
|
||||
tool_msgs.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=f"Tool {tool_call.tool_name} does not exist",
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_output = await tool.acall(**tool_call.tool_kwargs)
|
||||
self.sources.append(tool_output)
|
||||
tool_msgs.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=tool_output.content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
tool_msgs.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=f"Encountered error in tool call: {e}",
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
for msg in tool_msgs:
|
||||
self.memory.put(msg)
|
||||
|
||||
chat_history = self.memory.get()
|
||||
return InputEvent(input=chat_history)
|
||||
@@ -1,12 +0,0 @@
|
||||
from llama_agents import SimpleMessageQueue
|
||||
from app.utils import load_from_env
|
||||
|
||||
message_queue_host = (
|
||||
load_from_env("MESSAGE_QUEUE_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
message_queue_port = load_from_env("MESSAGE_QUEUE_PORT", throw_error=False) or "8000"
|
||||
|
||||
message_queue = SimpleMessageQueue(
|
||||
host=message_queue_host,
|
||||
port=int(message_queue_port) if message_queue_port else None,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
import builtins
|
||||
|
||||
|
||||
class PrintPrefix:
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
self.original_print = builtins.print
|
||||
|
||||
def __enter__(self):
|
||||
builtins.print = self._print_with_prefix
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
builtins.print = self.original_print
|
||||
|
||||
def _print_with_prefix(self, *args, **kwargs):
|
||||
self.original_print(self.prefix, *args, **kwargs)
|
||||
@@ -1,88 +0,0 @@
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from typing import Dict, Optional
|
||||
from llama_agents import CallableMessageConsumer, QueueMessage
|
||||
from llama_agents.message_queues.base import BaseMessageQueue
|
||||
from llama_agents.message_consumers.base import BaseMessageQueueConsumer
|
||||
from llama_agents.message_consumers.remote import RemoteMessageConsumer
|
||||
from app.utils import load_from_env
|
||||
from app.core.message_queue import message_queue
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class TaskResultService:
|
||||
def __init__(
|
||||
self,
|
||||
message_queue: BaseMessageQueue,
|
||||
name: str = "human",
|
||||
host: str = "127.0.0.1",
|
||||
port: Optional[int] = 8002,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self._message_queue = message_queue
|
||||
|
||||
# app
|
||||
self._app = FastAPI()
|
||||
self._app.add_api_route(
|
||||
"/", self.home, methods=["GET"], tags=["Human Consumer"]
|
||||
)
|
||||
self._app.add_api_route(
|
||||
"/process_message",
|
||||
self.process_message,
|
||||
methods=["POST"],
|
||||
tags=["Human Consumer"],
|
||||
)
|
||||
|
||||
@property
|
||||
def message_queue(self) -> BaseMessageQueue:
|
||||
return self._message_queue
|
||||
|
||||
def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
|
||||
if remote:
|
||||
return RemoteMessageConsumer(
|
||||
url=(
|
||||
f"http://{self.host}:{self.port}/process_message"
|
||||
if self.port
|
||||
else f"http://{self.host}/process_message"
|
||||
),
|
||||
message_type=self.name,
|
||||
)
|
||||
|
||||
return CallableMessageConsumer(
|
||||
message_type=self.name,
|
||||
handler=self.process_message,
|
||||
)
|
||||
|
||||
async def process_message(self, message: QueueMessage) -> None:
|
||||
Path("task_results").mkdir(exist_ok=True)
|
||||
with open("task_results/task_results.json", "+a") as f:
|
||||
json.dump(message.model_dump(), f)
|
||||
f.write("\n")
|
||||
|
||||
async def home(self) -> Dict[str, str]:
|
||||
return {"message": "hello, human."}
|
||||
|
||||
async def register_to_message_queue(self) -> None:
|
||||
"""Register to the message queue."""
|
||||
await self.message_queue.register_consumer(self.as_consumer(remote=True))
|
||||
|
||||
|
||||
human_consumer_host = (
|
||||
load_from_env("HUMAN_CONSUMER_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
human_consumer_port = load_from_env("HUMAN_CONSUMER_PORT", throw_error=False) or "8002"
|
||||
|
||||
|
||||
human_consumer_server = TaskResultService(
|
||||
message_queue=message_queue,
|
||||
host=human_consumer_host,
|
||||
port=int(human_consumer_port) if human_consumer_port else None,
|
||||
name="human",
|
||||
)
|
||||
@@ -1,28 +1,36 @@
|
||||
# flake8: noqa: E402
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
from app.agents.researcher.agent import get_query_engine_tool
|
||||
from app.core.agent_call import AgentCallingAgent
|
||||
from app.core.function_call import FunctionCallingAgent
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
load_dotenv()
|
||||
init_settings()
|
||||
|
||||
from llama_agents import ServerLauncher
|
||||
from app.core.message_queue import message_queue
|
||||
from app.core.control_plane import control_plane
|
||||
from app.core.task_result import human_consumer_server
|
||||
from app.agents.query_engine.agent import init_query_engine_agent
|
||||
from app.agents.dummy.agent import init_dummy_agent
|
||||
|
||||
agents = [
|
||||
init_query_engine_agent(message_queue),
|
||||
init_dummy_agent(message_queue),
|
||||
]
|
||||
async def main():
|
||||
researcher = FunctionCallingAgent(
|
||||
name="researcher",
|
||||
tools=[get_query_engine_tool()],
|
||||
system_prompt="You are a researcher agent. You are given a researching task. You must use your tools to complete the research.",
|
||||
)
|
||||
reviewer = AgentCallingAgent(
|
||||
name="reviewer",
|
||||
system_prompt="You are an expert in reviewing blog posts. You are given a task to write a blog post. Before starting to write the post, consult the researcher agent to get the information you need. Don't make up any information yourself.",
|
||||
)
|
||||
writer = AgentCallingAgent(
|
||||
name="writer",
|
||||
agents=[researcher, reviewer],
|
||||
system_prompt="""You are an expert in writing blog posts. You are given a task to write a blog post. Before starting to write the post, consult the researcher agent to get the information you need. Don't make up any information yourself.
|
||||
After creating a draft for the post, send it to the reviewer agent to receive some feedback and make sure to incorporate the feedback from the reviewer.
|
||||
You can consult the reviewer and researcher multiple times. Only finish the task once the reviewer is satisfied.""",
|
||||
)
|
||||
ret = await writer.run(input="Write a blog post about letter standards")
|
||||
print(ret["response"])
|
||||
|
||||
launcher = ServerLauncher(
|
||||
agents,
|
||||
control_plane,
|
||||
message_queue,
|
||||
additional_consumers=[human_consumer_server.as_consumer()],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
launcher.launch_servers()
|
||||
asyncio.run(main())
|
||||
|
||||
Generated
+2264
File diff suppressed because it is too large
Load Diff
+5
-9
@@ -3,7 +3,7 @@
|
||||
name = "app"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = [ "Marcus Schiesser <mail@marcusschiesser.de>" ]
|
||||
authors = ["Marcus Schiesser <mail@marcusschiesser.de>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
@@ -11,16 +11,12 @@ generate = "app.engine.generate:generate_datasource"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
llama-agents = "^0.0.3"
|
||||
llama-index-embeddings-openai = "^0.1.10"
|
||||
llama-index-llms-openai = "^0.1.23"
|
||||
|
||||
[tool.poetry.dependencies.llama-index-agent-openai]
|
||||
version = "0.2.6"
|
||||
llama-index-agent-openai = ">=0.3.0,<0.4.0"
|
||||
llama-index = "^0.11.0"
|
||||
|
||||
[tool.poetry.dependencies.docx2txt]
|
||||
version = "^0.8"
|
||||
|
||||
[build-system]
|
||||
requires = [ "poetry-core" ]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
Reference in New Issue
Block a user