feat: add blog post multi-agent

This commit is contained in:
Marcus Schiesser
2024-08-26 16:49:40 +07:00
parent 148b6c7dba
commit 61d14631d0
13 changed files with 2487 additions and 230 deletions
+2
View File
@@ -0,0 +1,2 @@
**/__pycache__
storage
-33
View File
@@ -1,33 +0,0 @@
from llama_agents import AgentService, SimpleMessageQueue
from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import FunctionTool
from llama_index.core.settings import Settings
from app.utils import load_from_env
DEFAULT_DUMMY_AGENT_DESCRIPTION = "I'm a dummy agent which does nothing."
def dummy_function():
"""
This function does nothing.
"""
return ""
def init_dummy_agent(message_queue: SimpleMessageQueue) -> AgentService:
agent = FunctionCallingAgentWorker(
tools=[FunctionTool.from_defaults(fn=dummy_function)],
llm=Settings.llm,
prefix_messages=[],
).as_agent()
return AgentService(
service_name="dummy_agent",
agent=agent,
message_queue=message_queue.client,
description=load_from_env("AGENT_DUMMY_DESCRIPTION", throw_error=False)
or DEFAULT_DUMMY_AGENT_DESCRIPTION,
host=load_from_env("AGENT_DUMMY_HOST", throw_error=False) or "127.0.0.1",
port=int(load_from_env("AGENT_DUMMY_PORT")),
)
-52
View File
@@ -1,52 +0,0 @@
import os
from llama_agents import AgentService, SimpleMessageQueue
from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.settings import Settings
from app.engine.index import get_index
from app.utils import load_from_env
DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION = (
"Used to answer the questions using the provided context data."
)
def get_query_engine_tool() -> QueryEngineTool:
"""
Provide an agent worker that can be used to query the index.
"""
index = get_index()
if index is None:
raise ValueError("Index not found. Please create an index first.")
query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3)))
return QueryEngineTool(
query_engine=query_engine,
metadata=ToolMetadata(
name="context_data",
description="""
Provide the provided context information.
Use a detailed plain text question as input to the tool.
""",
),
)
def init_query_engine_agent(
message_queue: SimpleMessageQueue,
) -> AgentService:
"""
Initialize the agent service.
"""
agent = FunctionCallingAgentWorker(
tools=[get_query_engine_tool()], llm=Settings.llm, prefix_messages=[]
).as_agent()
return AgentService(
service_name="context_query_agent",
agent=agent,
message_queue=message_queue.client,
description=load_from_env("AGENT_QUERY_ENGINE_DESCRIPTION", throw_error=False)
or DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION,
host=load_from_env("AGENT_QUERY_ENGINE_HOST", throw_error=False) or "127.0.0.1",
port=int(load_from_env("AGENT_QUERY_ENGINE_PORT")),
)
+1
View File
@@ -0,0 +1 @@
DATA_DIR = "data"
+38
View File
@@ -0,0 +1,38 @@
from typing import Any, List
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Workflow
from app.core.prefix import PrintPrefix
from app.core.function_call import FunctionCallingAgent
def create_call_workflow_fn(agent: Workflow) -> FunctionTool:
def call_workflow_fn(input: str) -> str:
raise NotImplementedError
async def acall_workflow_fn(input: str) -> str:
print(f"Calling agent {agent.name} with input: '{input}'")
with PrintPrefix(f"[{agent.name}]"):
ret = await agent.run(input=input)
print(f"Finished calling agent {agent.name}")
return ret["response"]
return FunctionTool.from_defaults(
fn=call_workflow_fn, # not necessary with https://github.com/run-llama/llama_index/pull/15638/files
async_fn=acall_workflow_fn,
name=f"call_{agent.name}",
description=f"Use this tool to delegate a task to the agent {agent.name}",
)
class AgentCallingAgent(FunctionCallingAgent):
def __init__(
self,
*args: Any,
agents: List[Workflow] | None = None,
**kwargs: Any,
) -> None:
agents = agents or []
tools = [create_call_workflow_fn(agent) for agent in agents]
super().__init__(*args, tools=tools, **kwargs)
-19
View File
@@ -1,19 +0,0 @@
from llama_index.llms.openai import OpenAI
from llama_agents import AgentOrchestrator, ControlPlaneServer
from app.core.message_queue import message_queue
from app.utils import load_from_env
control_plane_host = (
load_from_env("CONTROL_PLANE_HOST", throw_error=False) or "127.0.0.1"
)
control_plane_port = load_from_env("CONTROL_PLANE_PORT", throw_error=False) or "8001"
# setup control plane
control_plane = ControlPlaneServer(
message_queue=message_queue,
orchestrator=AgentOrchestrator(llm=OpenAI()),
host=control_plane_host,
port=int(control_plane_port) if control_plane_port else None,
)
+136
View File
@@ -0,0 +1,136 @@
from typing import Any, List
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.core.workflow import Event
from llama_index.core.settings import Settings
class InputEvent(Event):
input: list[ChatMessage]
class ToolCallEvent(Event):
tool_calls: list[ToolSelection]
class FunctionOutputEvent(Event):
output: ToolOutput
class FunctionCallingAgent(Workflow):
def __init__(
self,
*args: Any,
llm: FunctionCallingLLM | None = None,
tools: List[BaseTool] | None = None,
system_prompt: str | None = None,
verbose: bool = True,
timeout: float = 120.0,
name: str,
**kwargs: Any,
) -> None:
super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)
self.tools = tools or []
self.name = name
if llm is None:
llm = Settings.llm
self.llm = llm
assert self.llm.metadata.is_function_calling_model
self.system_prompt = system_prompt
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
self.sources = []
@step()
async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
# clear sources
self.sources = []
# set system prompt
if self.system_prompt is not None:
system_msg = ChatMessage(role="system", content=self.system_prompt)
self.memory.put(system_msg)
# get user input
user_input = ev.input
user_msg = ChatMessage(role="user", content=user_input)
self.memory.put(user_msg)
# get chat history
chat_history = self.memory.get()
return InputEvent(input=chat_history)
@step()
async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
chat_history = ev.input
response = await self.llm.achat_with_tools(
self.tools, chat_history=chat_history
)
self.memory.put(response.message)
tool_calls = self.llm.get_tool_calls_from_response(
response, error_on_no_tool_call=False
)
if not tool_calls:
return StopEvent(result={"response": response, "sources": [*self.sources]})
else:
return ToolCallEvent(tool_calls=tool_calls)
@step()
async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:
tool_calls = ev.tool_calls
tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
tool_msgs = []
# call tools -- safely!
for tool_call in tool_calls:
tool = tools_by_name.get(tool_call.tool_name)
additional_kwargs = {
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
}
if not tool:
tool_msgs.append(
ChatMessage(
role="tool",
content=f"Tool {tool_call.tool_name} does not exist",
additional_kwargs=additional_kwargs,
)
)
continue
try:
tool_output = await tool.acall(**tool_call.tool_kwargs)
self.sources.append(tool_output)
tool_msgs.append(
ChatMessage(
role="tool",
content=tool_output.content,
additional_kwargs=additional_kwargs,
)
)
except Exception as e:
tool_msgs.append(
ChatMessage(
role="tool",
content=f"Encountered error in tool call: {e}",
additional_kwargs=additional_kwargs,
)
)
for msg in tool_msgs:
self.memory.put(msg)
chat_history = self.memory.get()
return InputEvent(input=chat_history)
-12
View File
@@ -1,12 +0,0 @@
from llama_agents import SimpleMessageQueue
from app.utils import load_from_env
message_queue_host = (
load_from_env("MESSAGE_QUEUE_HOST", throw_error=False) or "127.0.0.1"
)
message_queue_port = load_from_env("MESSAGE_QUEUE_PORT", throw_error=False) or "8000"
message_queue = SimpleMessageQueue(
host=message_queue_host,
port=int(message_queue_port) if message_queue_port else None,
)
+16
View File
@@ -0,0 +1,16 @@
import builtins
class PrintPrefix:
def __init__(self, prefix):
self.prefix = prefix
self.original_print = builtins.print
def __enter__(self):
builtins.print = self._print_with_prefix
def __exit__(self, exc_type, exc_val, exc_tb):
builtins.print = self.original_print
def _print_with_prefix(self, *args, **kwargs):
self.original_print(self.prefix, *args, **kwargs)
-88
View File
@@ -1,88 +0,0 @@
import json
from logging import getLogger
from pathlib import Path
from fastapi import FastAPI
from typing import Dict, Optional
from llama_agents import CallableMessageConsumer, QueueMessage
from llama_agents.message_queues.base import BaseMessageQueue
from llama_agents.message_consumers.base import BaseMessageQueueConsumer
from llama_agents.message_consumers.remote import RemoteMessageConsumer
from app.utils import load_from_env
from app.core.message_queue import message_queue
logger = getLogger(__name__)
class TaskResultService:
def __init__(
self,
message_queue: BaseMessageQueue,
name: str = "human",
host: str = "127.0.0.1",
port: Optional[int] = 8002,
) -> None:
self.name = name
self.host = host
self.port = port
self._message_queue = message_queue
# app
self._app = FastAPI()
self._app.add_api_route(
"/", self.home, methods=["GET"], tags=["Human Consumer"]
)
self._app.add_api_route(
"/process_message",
self.process_message,
methods=["POST"],
tags=["Human Consumer"],
)
@property
def message_queue(self) -> BaseMessageQueue:
return self._message_queue
def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
if remote:
return RemoteMessageConsumer(
url=(
f"http://{self.host}:{self.port}/process_message"
if self.port
else f"http://{self.host}/process_message"
),
message_type=self.name,
)
return CallableMessageConsumer(
message_type=self.name,
handler=self.process_message,
)
async def process_message(self, message: QueueMessage) -> None:
Path("task_results").mkdir(exist_ok=True)
with open("task_results/task_results.json", "+a") as f:
json.dump(message.model_dump(), f)
f.write("\n")
async def home(self) -> Dict[str, str]:
return {"message": "hello, human."}
async def register_to_message_queue(self) -> None:
"""Register to the message queue."""
await self.message_queue.register_consumer(self.as_consumer(remote=True))
human_consumer_host = (
load_from_env("HUMAN_CONSUMER_HOST", throw_error=False) or "127.0.0.1"
)
human_consumer_port = load_from_env("HUMAN_CONSUMER_PORT", throw_error=False) or "8002"
human_consumer_server = TaskResultService(
message_queue=message_queue,
host=human_consumer_host,
port=int(human_consumer_port) if human_consumer_port else None,
name="human",
)
+25 -17
View File
@@ -1,28 +1,36 @@
# flake8: noqa: E402
import asyncio
from dotenv import load_dotenv
from app.agents.researcher.agent import get_query_engine_tool
from app.core.agent_call import AgentCallingAgent
from app.core.function_call import FunctionCallingAgent
from app.settings import init_settings
load_dotenv()
init_settings()
from llama_agents import ServerLauncher
from app.core.message_queue import message_queue
from app.core.control_plane import control_plane
from app.core.task_result import human_consumer_server
from app.agents.query_engine.agent import init_query_engine_agent
from app.agents.dummy.agent import init_dummy_agent
agents = [
init_query_engine_agent(message_queue),
init_dummy_agent(message_queue),
]
async def main():
researcher = FunctionCallingAgent(
name="researcher",
tools=[get_query_engine_tool()],
system_prompt="You are a researcher agent. You are given a researching task. You must use your tools to complete the research.",
)
reviewer = AgentCallingAgent(
name="reviewer",
system_prompt="You are an expert in reviewing blog posts. You are given a task to write a blog post. Before starting to write the post, consult the researcher agent to get the information you need. Don't make up any information yourself.",
)
writer = AgentCallingAgent(
name="writer",
agents=[researcher, reviewer],
system_prompt="""You are an expert in writing blog posts. You are given a task to write a blog post. Before starting to write the post, consult the researcher agent to get the information you need. Don't make up any information yourself.
After creating a draft for the post, send it to the reviewer agent to receive some feedback and make sure to incorporate the feedback from the reviewer.
You can consult the reviewer and researcher multiple times. Only finish the task once the reviewer is satisfied.""",
)
ret = await writer.run(input="Write a blog post about letter standards")
print(ret["response"])
launcher = ServerLauncher(
agents,
control_plane,
message_queue,
additional_consumers=[human_consumer_server.as_consumer()],
)
if __name__ == "__main__":
launcher.launch_servers()
asyncio.run(main())
Generated
+2264
View File
File diff suppressed because it is too large Load Diff
+5 -9
View File
@@ -3,7 +3,7 @@
name = "app"
version = "0.1.0"
description = ""
authors = [ "Marcus Schiesser <mail@marcusschiesser.de>" ]
authors = ["Marcus Schiesser <mail@marcusschiesser.de>"]
readme = "README.md"
[tool.poetry.scripts]
@@ -11,16 +11,12 @@ generate = "app.engine.generate:generate_datasource"
[tool.poetry.dependencies]
python = "^3.11"
llama-agents = "^0.0.3"
llama-index-embeddings-openai = "^0.1.10"
llama-index-llms-openai = "^0.1.23"
[tool.poetry.dependencies.llama-index-agent-openai]
version = "0.2.6"
llama-index-agent-openai = ">=0.3.0,<0.4.0"
llama-index = "^0.11.0"
[tool.poetry.dependencies.docx2txt]
version = "^0.8"
[build-system]
requires = [ "poetry-core" ]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"