Files
multi-agents-workflow/app/core/function_call.py
T
2024-08-26 16:49:40 +07:00

137 lines
4.2 KiB
Python

from typing import Any, List
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.core.workflow import Event
from llama_index.core.settings import Settings
class InputEvent(Event):
input: list[ChatMessage]
class ToolCallEvent(Event):
tool_calls: list[ToolSelection]
class FunctionOutputEvent(Event):
output: ToolOutput
class FunctionCallingAgent(Workflow):
def __init__(
self,
*args: Any,
llm: FunctionCallingLLM | None = None,
tools: List[BaseTool] | None = None,
system_prompt: str | None = None,
verbose: bool = True,
timeout: float = 120.0,
name: str,
**kwargs: Any,
) -> None:
super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)
self.tools = tools or []
self.name = name
if llm is None:
llm = Settings.llm
self.llm = llm
assert self.llm.metadata.is_function_calling_model
self.system_prompt = system_prompt
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
self.sources = []
@step()
async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
# clear sources
self.sources = []
# set system prompt
if self.system_prompt is not None:
system_msg = ChatMessage(role="system", content=self.system_prompt)
self.memory.put(system_msg)
# get user input
user_input = ev.input
user_msg = ChatMessage(role="user", content=user_input)
self.memory.put(user_msg)
# get chat history
chat_history = self.memory.get()
return InputEvent(input=chat_history)
@step()
async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
chat_history = ev.input
response = await self.llm.achat_with_tools(
self.tools, chat_history=chat_history
)
self.memory.put(response.message)
tool_calls = self.llm.get_tool_calls_from_response(
response, error_on_no_tool_call=False
)
if not tool_calls:
return StopEvent(result={"response": response, "sources": [*self.sources]})
else:
return ToolCallEvent(tool_calls=tool_calls)
@step()
async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:
tool_calls = ev.tool_calls
tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
tool_msgs = []
# call tools -- safely!
for tool_call in tool_calls:
tool = tools_by_name.get(tool_call.tool_name)
additional_kwargs = {
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
}
if not tool:
tool_msgs.append(
ChatMessage(
role="tool",
content=f"Tool {tool_call.tool_name} does not exist",
additional_kwargs=additional_kwargs,
)
)
continue
try:
tool_output = await tool.acall(**tool_call.tool_kwargs)
self.sources.append(tool_output)
tool_msgs.append(
ChatMessage(
role="tool",
content=tool_output.content,
additional_kwargs=additional_kwargs,
)
)
except Exception as e:
tool_msgs.append(
ChatMessage(
role="tool",
content=f"Encountered error in tool call: {e}",
additional_kwargs=additional_kwargs,
)
)
for msg in tool_msgs:
self.memory.put(msg)
chat_history = self.memory.get()
return InputEvent(input=chat_history)