mirror of
https://github.com/run-llama/rags.git
synced 2026-07-01 20:54:00 -04:00
308 lines
10 KiB
Python
308 lines
10 KiB
Python
from llama_index.llms import OpenAI, ChatMessage
|
|
from llama_index.llms.utils import resolve_llm
|
|
from pydantic import BaseModel, Field
|
|
import os
|
|
from llama_index.tools.query_engine import QueryEngineTool
|
|
from llama_index.agent import OpenAIAgent
|
|
from llama_index import (
|
|
VectorStoreIndex,
|
|
SummaryIndex,
|
|
ServiceContext,
|
|
Document
|
|
)
|
|
from llama_index.prompts import ChatPromptTemplate
|
|
from typing import List, cast, Optional
|
|
from llama_index import SimpleDirectoryReader
|
|
from llama_index.embeddings.utils import resolve_embed_model
|
|
from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool
|
|
from typing import Dict, Tuple
|
|
import streamlit as st
|
|
|
|
|
|
####################
|
|
#### META TOOLS ####
|
|
####################
|
|
|
|
|
|
# System prompt tool
|
|
# Prompt asking the LLM to write a system prompt for the downstream RAG agent.
# The `{task}` placeholder is filled in via
# `GEN_SYS_PROMPT_TMPL.format_messages(task=...)`.
GEN_SYS_PROMPT_STR = """\
Task information is given below.

Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task:
{task} \

Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool.
- Does not reference a specific data source. The data source is implicit in any queries to the bot,
and telling the bot to analyze a specific data source might confuse it given a
user query.

"""

# Two-message conversation: a fixed system instruction followed by the
# user message that carries the task description.
gen_sys_prompt_messages = [
    ChatMessage(
        content="You are helping to build a system prompt for another bot.",
        role="system",
    ),
    ChatMessage(content=GEN_SYS_PROMPT_STR, role="user"),
]

# Chat template used by `RAGAgentBuilder.create_system_prompt`.
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
|
|
|
|
|
|
class RAGParams(BaseModel):
    """Settings that configure a RAG pipeline.

    Every field has a default, so `RAGParams()` yields a usable
    configuration. The field descriptions are part of the model schema.
    """

    include_summarization: bool = Field(
        description="Whether to include summarization in the RAG pipeline.",
        default=False,
    )
    top_k: int = Field(
        description="Number of documents to retrieve from vector store.",
        default=2,
    )
    chunk_size: int = Field(
        description="Chunk size for vector store.",
        default=1024,
    )
    embed_model: str = Field(
        description="Embedding model to use (default is OpenAI)",
        default="default",
    )
    llm: str = Field(
        description="LLM to use for summarization.",
        default="gpt-4-1106-preview",
    )
|
|
|
|
|
|
class ParamCache(BaseModel):
    """State accumulated while a RAG agent is being built.

    A typed wrapper around what would otherwise be a plain dict, so that
    individual cache entries can be given explicit types.
    """

    system_prompt: Optional[str] = Field(
        description="System prompt for RAG agent.",
        default=None,
    )
    docs: List[Document] = Field(
        description="Documents for RAG agent.",
        default_factory=list,
    )
    tools: List = Field(
        description="Additional tools for RAG agent (e.g. web)",
        default_factory=list,
    )
    rag_params: RAGParams = Field(
        description="RAG parameters for RAG agent.",
        default_factory=RAGParams,
    )
    agent: Optional[OpenAIAgent] = Field(
        description="RAG agent.",
        default=None,
    )

    class Config:
        # Some fields hold types that are not pydantic models
        # (e.g. the `OpenAIAgent` in `agent`).
        arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
class RAGAgentBuilder:
    """RAG Agent builder.

    Contains a set of functions to construct a RAG agent, including:
    - setting system prompts
    - loading data
    - adding web search
    - setting parameters (e.g. top-k)

    A cache may be passed in; when omitted, a fresh `ParamCache` is created.
    The cache is modified as the agent is built.

    NOTE(review): the bound methods of this class are wrapped with
    `FunctionTool.from_defaults` in `load_meta_agent_and_tools`, so their
    signatures and docstrings are presumably what the builder LLM sees as
    tool descriptions -- confirm before rewording them.
    """

    def __init__(self, cache: Optional[ParamCache] = None) -> None:
        """Init params.

        Args:
            cache (Optional[ParamCache]): cache to populate while building.
                Defaults to a new, empty `ParamCache`.

        """
        self._cache = cache or ParamCache()

    @property
    def cache(self) -> ParamCache:
        """Cache."""
        return self._cache

    def create_system_prompt(self, task: str) -> str:
        """Create system prompt for another agent given an input task."""
        llm = OpenAI(model="gpt-4-1106-preview")
        fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
        response = llm.chat(fmt_messages)
        # Stored so that `create_agent` can pick it up later.
        self._cache.system_prompt = response.message.content

        return f"System prompt created: {response.message.content}"

    def load_data(
        self,
        file_names: Optional[List[str]] = None,
        urls: Optional[List[str]] = None,
    ) -> str:
        """Load data for a given task.

        Only ONE of file_names or urls should be specified.

        Args:
            file_names (Optional[List[str]]): List of file names to load.
                Defaults to None.
            urls (Optional[List[str]]): List of urls to load.
                Defaults to None.

        Returns:
            str: a confirmation message once the documents are cached.

        Raises:
            ValueError: if neither, or both, of file_names and urls are given.

        """
        if file_names is None and urls is None:
            raise ValueError("Must specify either file_names or urls.")
        if file_names is not None and urls is not None:
            raise ValueError("Must specify only one of file_names or urls.")

        # Exactly one of the two arguments is set from here on.
        if file_names is not None:
            reader = SimpleDirectoryReader(input_files=file_names)
            docs = reader.load_data()
        else:
            # use simple web page reader from llamahub
            # (imported lazily so llama_hub is only needed for URL loading)
            from llama_hub.web.simple_web.base import SimpleWebPageReader

            loader = SimpleWebPageReader()
            docs = loader.load_data(urls=urls)

        self._cache.docs = docs
        return "Data loaded successfully."

    # NOTE: unused
    def add_web_tool(self) -> str:
        """Add a web tool to enable agent to solve a task."""
        # TODO: make this not hardcoded to a web tool
        # Set up Metaphor tool
        from llama_hub.tools.metaphor.base import MetaphorToolSpec

        # TODO: set metaphor API key
        # Raises KeyError when METAPHOR_API_KEY is not set in the environment.
        metaphor_tool = MetaphorToolSpec(
            api_key=os.environ["METAPHOR_API_KEY"],
        )
        metaphor_tool_list = metaphor_tool.to_tool_list()

        self._cache.tools.extend(metaphor_tool_list)
        return "Web tool added successfully."

    def get_rag_params(self) -> Dict:
        """Get parameters used to configure the RAG pipeline.

        Should be called before `set_rag_params` so that the agent is aware of the
        schema.

        """
        rag_params = self._cache.rag_params
        return rag_params.model_dump()

    def set_rag_params(self, **rag_params: Dict) -> str:
        """Set RAG parameters.

        These parameters will then be used to actually initialize the agent.
        Should call `get_rag_params` first to get the schema of the input dictionary.

        Args:
            **rag_params (Dict): dictionary of RAG parameters.

        """
        # Overlay the supplied values on the current ones, then rebuild the
        # model from the merged dict.
        new_dict = self._cache.rag_params.model_dump()
        new_dict.update(rag_params)
        rag_params_obj = RAGParams(**new_dict)
        self._cache.rag_params = rag_params_obj
        return "RAG parameters set successfully."

    def create_agent(self) -> str:
        """Create an agent.

        There are no parameters for this function because all the
        functions should have already been called to set up the agent.

        """
        # Check the precondition first: without a system prompt no agent is
        # built, so there is no point indexing the documents beforehand.
        if self._cache.system_prompt is None:
            return "System prompt not set yet. Please set system prompt first."

        rag_params = cast(RAGParams, self._cache.rag_params)
        docs = self._cache.docs

        # first resolve llm and embedding model
        embed_model = resolve_embed_model(rag_params.embed_model)
        # TODO: use OpenAI for now (rather than resolve_llm(rag_params.llm))
        llm = OpenAI(model=rag_params.llm)

        # first let's index the data with the right parameters
        service_context = ServiceContext.from_defaults(
            chunk_size=rag_params.chunk_size,
            llm=llm,
            embed_model=embed_model,
        )
        vector_index = VectorStoreIndex.from_documents(
            docs, service_context=service_context
        )
        vector_query_engine = vector_index.as_query_engine(
            similarity_top_k=rag_params.top_k
        )
        all_tools = []
        vector_tool = QueryEngineTool(
            query_engine=vector_query_engine,
            metadata=ToolMetadata(
                name="vector_tool",
                description=("Use this tool to answer any user question over any data."),
            ),
        )
        all_tools.append(vector_tool)

        # The summary tool is optional and built over the same documents.
        if rag_params.include_summarization:
            summary_index = SummaryIndex.from_documents(
                docs, service_context=service_context
            )
            summary_query_engine = summary_index.as_query_engine()
            summary_tool = QueryEngineTool(
                query_engine=summary_query_engine,
                metadata=ToolMetadata(
                    name="summary_tool",
                    description=("Use this tool for any user questions that ask for a summarization of content"),
                ),
            )
            all_tools.append(summary_tool)

        # then we add tools
        all_tools.extend(self._cache.tools)

        # build agent
        agent = OpenAIAgent.from_tools(
            tools=all_tools,
            system_prompt=self._cache.system_prompt,
            llm=llm,
            verbose=True,
        )
        self._cache.agent = agent
        return "Agent created successfully."
|
|
|
|
|
|
####################
|
|
#### META Agent ####
|
|
####################
|
|
|
|
# System prompt for the meta ("builder") agent; it is passed as the system
# prefix message in `load_meta_agent_and_tools`.
RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.

1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.

This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")

"""
|
|
|
|
|
|
# define agent
|
|
@st.cache_resource
def load_meta_agent_and_tools() -> Tuple[OpenAIAgent, RAGAgentBuilder]:
    """Build the meta ("builder") agent together with its `RAGAgentBuilder`.

    The builder's methods are wrapped as function tools and handed to an
    `OpenAIAgent` that is primed with `RAG_BUILDER_SYS_STR` as its system
    message. The function is decorated with `st.cache_resource`.

    Returns:
        Tuple[OpenAIAgent, RAGAgentBuilder]: the builder agent, and the
            `RAGAgentBuilder` instance whose methods it calls as tools.

    """
    prefix_msgs = [ChatMessage(role="system", content=RAG_BUILDER_SYS_STR)]

    # think of this as tools for the agent to use
    agent_builder = RAGAgentBuilder()

    fns = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        # add_web_tool is deliberately not exposed to the agent for now
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]
    fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]

    builder_agent = OpenAIAgent.from_tools(
        tools=fn_tools,
        # The model name is passed as `model=`, matching the other
        # `OpenAI(...)` calls in this file; the previous `llm=` keyword did
        # not select the intended model.
        llm=OpenAI(model="gpt-4-1106-preview"),
        prefix_messages=prefix_msgs,
        verbose=True,
    )
    return builder_agent, agent_builder
|
|
|