upgrade RAGs (#27)

This commit is contained in:
Jerry Liu
2023-11-26 01:27:44 -08:00
committed by GitHub
parent 55c95f1142
commit 49b867909c
12 changed files with 828 additions and 233 deletions
+32
View File
@@ -0,0 +1,32 @@
name: Linting
on:
push:
branches:
- main
pull_request:
jobs:
build:
runs-on: ubuntu-latest
strategy:
# You can use PyPy versions in python-version.
# For example, pypy-2.7 and pypy-3.8
matrix:
python-version: ["3.9"]
poetry-version: [1.5.1]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}
- name: Install deps
run: |
poetry install --with dev
- name: Run Linting
run: poetry run make lint
+65 -26
View File
@@ -3,6 +3,11 @@ from streamlit_pills import pills
from agent_utils import (
load_meta_agent_and_tools,
load_agent_ids_from_directory,
)
from st_utils import add_sidebar
from constants import (
AGENT_CACHE_DIR,
)
@@ -11,23 +16,41 @@ from agent_utils import (
####################
st.set_page_config(page_title="Build a RAGs bot, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
st.set_page_config(
page_title="Build a RAGs bot, powered by LlamaIndex",
page_icon="🦙",
layout="centered",
initial_sidebar_state="auto",
menu_items=None,
)
st.title("Build a RAGs bot, powered by LlamaIndex 💬🦙")
st.info(
"Use this page to build your RAG bot over your data! "
"Once the agent is finished creating, check out the `RAG Config` and `Generated RAG Agent` pages.",
icon=""
"Once the agent is finished creating, check out the `RAG Config` and "
"`Generated RAG Agent` pages.\n"
"To build a new agent, please make sure that 'Create a new agent' is selected.",
icon="",
)
# TODO: noodle on this
# with st.sidebar:
# openai_api_key_st = st.text_input("OpenAI API Key (optional, not needed if you filled in secrets.toml)", value="", type="password")
# if st.button("Save"):
# # save api key
# st.session_state.openai_api_key = openai_api_key_st
#### load builder agent and its tool spec (the agent_builder)
builder_agent, agent_builder = load_meta_agent_and_tools()
add_sidebar()
if (
"selected_cache" in st.session_state.keys()
and st.session_state.selected_cache is not None
):
# create builder agent / tools from selected cache
builder_agent, agent_builder = load_meta_agent_and_tools(
cache=st.session_state.selected_cache
)
else:
# create builder agent / tools from new cache
builder_agent, agent_builder = load_meta_agent_and_tools()
st.info(f"Currently building/editing agent: {agent_builder.cache.agent_id}", icon="")
if "builder_agent" not in st.session_state.keys():
st.session_state.builder_agent = builder_agent
@@ -36,27 +59,34 @@ if "agent_builder" not in st.session_state.keys():
# add pills
selected = pills(
"Outline your task!",
["I want to analyze this PDF file (data/invoices.pdf)",
"I want to search over my CSV documents."
], clearable=True, index=None
"Outline your task!",
[
"I want to analyze this PDF file (data/invoices.pdf)",
"I want to search over my CSV documents.",
],
clearable=True,
index=None,
)
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
st.session_state.messages = [
{"role": "assistant", "content": "What RAG bot do you want to build?"}
]
def add_to_message_history(role, content):
message = {"role": role, "content": str(content)}
st.session_state.messages.append(message) # Add response to message history
for message in st.session_state.messages: # Display the prior chat messages
def add_to_message_history(role: str, content: str) -> None:
message = {"role": role, "content": str(content)}
st.session_state.messages.append(message) # Add response to message history
for message in st.session_state.messages: # Display the prior chat messages
with st.chat_message(message["role"]):
st.write(message["content"])
# handle user input
if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
if prompt := st.chat_input(
"Your question"
): # Prompt for user input and save to chat history
add_to_message_history("user", prompt)
with st.chat_message("user"):
st.write(prompt)
@@ -67,9 +97,18 @@ if st.session_state.messages[-1]["role"] != "assistant":
with st.spinner("Thinking..."):
response = st.session_state.builder_agent.chat(prompt)
st.write(str(response))
add_to_message_history("assistant", response)
add_to_message_history("assistant", str(response))
# # check cache
print(st.session_state.agent_builder.cache)
# if "agent" in cache:
# st.session_state.agent = cache["agent"]
# check agent_ids again, if it doesn't match, add to directory and refresh
agent_ids = load_agent_ids_from_directory(str(AGENT_CACHE_DIR))
# check diff between agent_ids and cur agent ids
diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
if len(diff_ids) > 0:
# clear streamlit cache, to allow you to generate a new agent
st.cache_resource.clear()
# trigger refresh
st.rerun()
else:
pass
+10
View File
@@ -0,0 +1,10 @@
.PHONY: format lint
format:
black .
lint:
mypy .
black --check .
ruff check .
test:
pytest tests
+2 -2
View File
@@ -11,10 +11,10 @@ This project is inspired by [GPTs](https://openai.com/blog/introducing-gpts), la
## Installation and Setup
Clone this project, go into the `rags` project folder.
Clone this project, go into the `rags` project folder. We recommend creating a virtual env for dependencies (`python3 -m venv .venv`).
```
pip install -r requirements.txt
poetry install --with dev
```
By default, we use OpenAI for both the builder agent as well as the generated RAG agent.
+435 -133
View File
@@ -3,14 +3,15 @@ from llama_index.llms.base import LLM
from llama_index.llms.utils import resolve_llm
from pydantic import BaseModel, Field
import os
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.agent import OpenAIAgent, ReActAgent
from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
from llama_index import (
VectorStoreIndex,
SummaryIndex,
ServiceContext,
Document
StorageContext,
Document,
load_index_from_storage,
)
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
@@ -18,28 +19,32 @@ from llama_index import SimpleDirectoryReader
from llama_index.embeddings.utils import resolve_embed_model
from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool
from llama_index.agent.types import BaseAgent
from llama_index.chat_engine.types import BaseChatEngine
from llama_index.agent.react.formatter import ReActChatFormatter
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.chat_engine import CondensePlusContextChatEngine
from builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Callable
import streamlit as st
from pathlib import Path
import json
import uuid
from constants import AGENT_CACHE_DIR
import shutil
def _resolve_llm(llm: str) -> LLM:
def _resolve_llm(llm_str: str) -> LLM:
"""Resolve LLM."""
# TODO: make this less hardcoded with if-else statements
# see if there's a prefix
# - if there isn't, assume it's an OpenAI model
# - if there is, resolve it
tokens = llm.split(":")
tokens = llm_str.split(":")
if len(tokens) == 1:
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
llm = OpenAI(model=llm)
llm: LLM = OpenAI(model=llm_str)
elif tokens[0] == "local":
llm = resolve_llm(llm)
llm = resolve_llm(llm_str)
elif tokens[0] == "openai":
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
llm = OpenAI(model=tokens[1])
@@ -50,7 +55,7 @@ def _resolve_llm(llm: str) -> LLM:
os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key
llm = Replicate(model=tokens[1])
else:
raise ValueError(f"LLM {llm} not recognized.")
raise ValueError(f"LLM {llm_str} not recognized.")
return llm
@@ -63,14 +68,17 @@ def _resolve_llm(llm: str) -> LLM:
GEN_SYS_PROMPT_STR = """\
Task information is given below.
Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task:
Given the task, please generate a system prompt for an OpenAI-powered bot \
to solve this task:
{task} \
Make sure the system prompt obeys the following requirements:
- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool.
- Does not reference a specific data source. The data source is implicit in any queries to the bot,
and telling the bot to analyze a specific data source might confuse it given a
user query.
- Tells the bot to ALWAYS use tools given to solve the task. \
NEVER give an answer without using a tool.
- Does not reference a specific data source. \
The data source is implicit in any queries to the bot, \
and telling the bot to analyze a specific data source might confuse it given a \
user query.
"""
@@ -85,49 +93,196 @@ gen_sys_prompt_messages = [
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
class RAGParams(BaseModel):
"""RAG parameters.
Parameters used to configure a RAG pipeline.
"""
include_summarization: bool = Field(
default=False,
description=(
"Whether to include summarization in the RAG pipeline. (only for GPT-4)"
),
)
top_k: int = Field(
default=2, description="Number of documents to retrieve from vector store."
)
chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
embed_model: str = Field(
default="default", description="Embedding model to use (default is OpenAI)"
)
llm: str = Field(
default="gpt-4-1106-preview", description="LLM to use for summarization."
)
def load_data(
file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
) -> List[Document]:
"""Load data."""
file_names = file_names or []
urls = urls or []
if not file_names and not urls:
raise ValueError("Must specify either file_names or urls.")
elif file_names and urls:
raise ValueError("Must specify only one of file_names or urls.")
elif file_names:
reader = SimpleDirectoryReader(input_files=file_names)
docs = reader.load_data()
elif urls:
from llama_hub.web.simple_web.base import SimpleWebPageReader
# use simple web page reader from llamahub
loader = SimpleWebPageReader()
docs = loader.load_data(urls=urls)
else:
raise ValueError("Must specify either file_names or urls.")
return docs
def load_agent(
tools: List,
llm: LLM,
tools: List,
llm: LLM,
system_prompt: str,
extra_kwargs: Optional[Dict] = None,
**kwargs: Any
) -> BaseAgent:
**kwargs: Any,
) -> BaseChatEngine:
"""Load agent."""
extra_kwargs = extra_kwargs or {}
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
# get OpenAI Agent
agent = OpenAIAgent.from_tools(
tools=tools,
llm=llm,
system_prompt=system_prompt,
**kwargs
agent: BaseChatEngine = OpenAIAgent.from_tools(
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
)
else:
if "vector_index" not in extra_kwargs:
raise ValueError("Must pass in vector index for CondensePlusContextChatEngine.")
raise ValueError(
"Must pass in vector index for CondensePlusContextChatEngine."
)
vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"])
rag_params = cast(RAGParams, extra_kwargs["rag_params"])
# use condense + context chat engine
agent = CondensePlusContextChatEngine.from_defaults(
vector_index.as_retriever(similarity_top_k=rag_params.top_k),
)
return agent
class RAGParams(BaseModel):
"""RAG parameters.
def load_meta_agent(
tools: List,
llm: LLM,
system_prompt: str,
extra_kwargs: Optional[Dict] = None,
**kwargs: Any,
) -> BaseAgent:
"""Load meta agent.
TODO: consolidate with load_agent.
The meta-agent *has* to perform tool-use.
Parameters used to configure a RAG pipeline.
"""
include_summarization: bool = Field(default=False, description="Whether to include summarization in the RAG pipeline. (only for GPT-4)")
top_k: int = Field(default=2, description="Number of documents to retrieve from vector store.")
chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
embed_model: str = Field(
default="default", description="Embedding model to use (default is OpenAI)"
extra_kwargs = extra_kwargs or {}
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
# get OpenAI Agent
agent: BaseAgent = OpenAIAgent.from_tools(
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
)
else:
agent = ReActAgent.from_tools(
tools=tools,
llm=llm,
react_chat_formatter=ReActChatFormatter(
system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER,
),
**kwargs,
)
return agent
def construct_agent(
system_prompt: str,
rag_params: RAGParams,
docs: List[Document],
vector_index: Optional[VectorStoreIndex] = None,
additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
"""Construct agent from docs / parameters / indices."""
extra_info = {}
additional_tools = additional_tools or []
# first resolve llm and embedding model
embed_model = resolve_embed_model(rag_params.embed_model)
# llm = resolve_llm(rag_params.llm)
# TODO: use OpenAI for now
# llm = OpenAI(model=rag_params.llm)
llm = _resolve_llm(rag_params.llm)
# first let's index the data with the right parameters
service_context = ServiceContext.from_defaults(
chunk_size=rag_params.chunk_size,
llm=llm,
embed_model=embed_model,
)
llm: str = Field(default="gpt-4-1106-preview", description="LLM to use for summarization.")
if vector_index is None:
vector_index = VectorStoreIndex.from_documents(
docs, service_context=service_context
)
else:
pass
extra_info["vector_index"] = vector_index
vector_query_engine = vector_index.as_query_engine(
similarity_top_k=rag_params.top_k
)
all_tools = []
vector_tool = QueryEngineTool(
query_engine=vector_query_engine,
metadata=ToolMetadata(
name="vector_tool",
description=("Use this tool to answer any user question over any data."),
),
)
all_tools.append(vector_tool)
if rag_params.include_summarization:
summary_index = SummaryIndex.from_documents(
docs, service_context=service_context
)
summary_query_engine = summary_index.as_query_engine()
summary_tool = QueryEngineTool(
query_engine=summary_query_engine,
metadata=ToolMetadata(
name="summary_tool",
description=(
"Use this tool for any user questions that ask "
"for a summarization of content"
),
),
)
all_tools.append(summary_tool)
# then we add tools
all_tools.extend(additional_tools)
# build agent
if system_prompt is None:
return "System prompt not set yet. Please set system prompt first."
agent = load_agent(
all_tools,
llm=llm,
system_prompt=system_prompt,
verbose=True,
extra_kwargs={"vector_index": vector_index, "rag_params": rag_params},
)
return agent, extra_info
class ParamCache(BaseModel):
@@ -135,20 +290,163 @@ class ParamCache(BaseModel):
Created a wrapper class around a dict in case we wanted to more explicitly
type different items in the cache.
"""
# arbitrary types
class Config:
arbitrary_types_allowed = True
system_prompt: Optional[str] = Field(default=None, description="System prompt for RAG agent.")
file_paths: List[str] = Field(default_factory=list, description="File paths for RAG agent.")
docs: List[Document] = Field(default_factory=list, description="Documents for RAG agent.")
tools: List = Field(default_factory=list, description="Additional tools for RAG agent (e.g. web)")
rag_params: RAGParams = Field(default_factory=RAGParams, description="RAG parameters for RAG agent.")
agent: Optional[OpenAIAgent] = Field(default=None, description="RAG agent.")
# system prompt
system_prompt: Optional[str] = Field(
default=None, description="System prompt for RAG agent."
)
# data
file_names: List[str] = Field(
default_factory=list, description="File names as data source (if specified)"
)
urls: List[str] = Field(
default_factory=list, description="URLs as data source (if specified)"
)
docs: List = Field(default_factory=list, description="Documents for RAG agent.")
# tools
tools: List = Field(
default_factory=list, description="Additional tools for RAG agent (e.g. web)"
)
# RAG params
rag_params: RAGParams = Field(
default_factory=RAGParams, description="RAG parameters for RAG agent."
)
# agent params
vector_index: Optional[VectorStoreIndex] = Field(
default=None, description="Vector index for RAG agent."
)
agent_id: str = Field(
default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
description="Agent ID for RAG agent.",
)
agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")
def save_to_disk(self, save_dir: str) -> None:
"""Save cache to disk."""
# NOTE: more complex than just calling dict() because we want to
# only store serializable fields and be space-efficient
dict_to_serialize = {
"system_prompt": self.system_prompt,
"file_names": self.file_names,
"urls": self.urls,
# TODO: figure out tools
# "tools": [],
"rag_params": self.rag_params.dict(),
"agent_id": self.agent_id,
}
# store the vector store within the agent
if self.vector_index is None:
raise ValueError("Must specify vector index in order to save.")
self.vector_index.storage_context.persist(Path(save_dir) / "storage")
# if save_path directories don't exist, create it
if not Path(save_dir).exists():
Path(save_dir).mkdir(parents=True)
with open(Path(save_dir) / "cache.json", "w") as f:
json.dump(dict_to_serialize, f)
@classmethod
def load_from_disk(
cls,
save_dir: str,
) -> "ParamCache":
"""Load cache from disk."""
storage_context = StorageContext.from_defaults(
persist_dir=str(Path(save_dir) / "storage")
)
vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
with open(Path(save_dir) / "cache.json", "r") as f:
cache_dict = json.load(f)
# replace rag params with RAGParams object
cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
# add in the missing fields
# load docs
cache_dict["docs"] = load_data(
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
)
# load agent from index
agent, _ = construct_agent(
cache_dict["system_prompt"],
cache_dict["rag_params"],
cache_dict["docs"],
vector_index=vector_index,
# TODO: figure out tools
)
cache_dict["vector_index"] = vector_index
cache_dict["agent"] = agent
return cls(**cache_dict)
def add_agent_id_to_directory(dir: str, agent_id: str) -> None:
"""Save agent id to directory."""
full_path = Path(dir) / "agent_ids.json"
if not full_path.exists():
with open(full_path, "w") as f:
json.dump({"agent_ids": [agent_id]}, f)
else:
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
if agent_id in agent_ids:
raise ValueError(f"Agent id {agent_id} already exists.")
agent_ids_set = set(agent_ids)
agent_ids_set.add(agent_id)
with open(full_path, "w") as f:
json.dump({"agent_ids": list(agent_ids_set)}, f)
def load_agent_ids_from_directory(dir: str) -> List[str]:
"""Load agent ids file."""
full_path = Path(dir) / "agent_ids.json"
if not full_path.exists():
return []
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
return agent_ids
def load_cache_from_directory(
dir: str,
agent_id: str,
) -> ParamCache:
"""Load cache from directory."""
full_path = Path(dir) / f"{agent_id}"
if not full_path.exists():
raise ValueError(f"Cache for agent {agent_id} does not exist.")
cache = ParamCache.load_from_disk(str(full_path))
return cache
def remove_agent_from_directory(
dir: str,
agent_id: str,
) -> None:
"""Remove agent from directory."""
# modify / resave agent_ids
agent_ids = load_agent_ids_from_directory(dir)
new_agent_ids = [id for id in agent_ids if id != agent_id]
full_path = Path(dir) / "agent_ids.json"
with open(full_path, "w") as f:
json.dump({"agent_ids": new_agent_ids}, f)
# remove agent cache
full_path = Path(dir) / f"{agent_id}"
if full_path.exists():
# recursive delete
shutil.rmtree(full_path)
class RAGAgentBuilder:
@@ -161,11 +459,15 @@ class RAGAgentBuilder:
- setting parameters (e.g. top-k)
Must pass in a cache. This cache will be modified as the agent is built.
"""
def __init__(self, cache: Optional[ParamCache] = None) -> None:
def __init__(
self, cache: Optional[ParamCache] = None, cache_dir: Optional[str] = None
) -> None:
"""Init params."""
self._cache = cache or ParamCache()
self._cache_dir = cache_dir or AGENT_CACHE_DIR
@property
def cache(self) -> ParamCache:
@@ -181,11 +483,8 @@ class RAGAgentBuilder:
return f"System prompt created: {response.message.content}"
def load_data(
self,
file_names: Optional[List[str]] = None,
urls: Optional[List[str]] = None
self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
) -> str:
"""Load data for a given task.
@@ -196,32 +495,18 @@ class RAGAgentBuilder:
Defaults to None.
urls (Optional[List[str]]): List of urls to load.
Defaults to None.
"""
if file_names is None and urls is None:
raise ValueError("Must specify either file_names or urls.")
elif file_names is not None and urls is not None:
raise ValueError("Must specify only one of file_names or urls.")
elif file_names is not None:
reader = SimpleDirectoryReader(input_files=file_names)
docs = reader.load_data()
file_paths = file_names
elif urls is not None:
from llama_hub.web.simple_web.base import SimpleWebPageReader
# use simple web page reader from llamahub
loader = SimpleWebPageReader()
docs = loader.load_data(urls=urls)
file_paths = urls
else:
raise ValueError("Must specify either file_names or urls.")
file_names = file_names or []
urls = urls or []
docs = load_data(file_names=file_names, urls=urls)
self._cache.docs = docs
self._cache.file_paths = file_paths
self._cache.file_names = file_names
self._cache.urls = urls
return "Data loaded successfully."
# NOTE: unused
def add_web_tool(self) -> None:
def add_web_tool(self) -> str:
"""Add a web tool to enable agent to solve a task."""
# TODO: make this not hardcoded to a web tool
# Set up Metaphor tool
@@ -241,21 +526,20 @@ class RAGAgentBuilder:
Should be called before `set_rag_params` so that the agent is aware of the
schema.
"""
rag_params = self._cache.rag_params
return rag_params.dict()
def set_rag_params(self, **rag_params: Dict):
def set_rag_params(self, **rag_params: Dict) -> str:
"""Set RAG parameters.
These parameters will then be used to actually initialize the agent.
Should call `get_rag_params` first to get the schema of the input dictionary.
Args:
**rag_params (Dict): dictionary of RAG parameters.
**rag_params (Dict): dictionary of RAG parameters.
"""
new_dict = self._cache.rag_params.dict()
new_dict.update(rag_params)
@@ -263,69 +547,85 @@ class RAGAgentBuilder:
self._cache.rag_params = rag_params_obj
return "RAG parameters set successfully."
def create_agent(self) -> None:
def create_agent(self, agent_id: Optional[str] = None) -> str:
"""Create an agent.
There are no parameters for this function because all the
functions should have already been called to set up the agent.
"""
rag_params = cast(RAGParams, self._cache.rag_params)
docs = self._cache.docs
# first resolve llm and embedding model
embed_model = resolve_embed_model(rag_params.embed_model)
# llm = resolve_llm(rag_params.llm)
# TODO: use OpenAI for now
# llm = OpenAI(model=rag_params.llm)
llm = _resolve_llm(rag_params.llm)
# first let's index the data with the right parameters
service_context = ServiceContext.from_defaults(
chunk_size=rag_params.chunk_size,
llm=llm,
embed_model=embed_model,
)
vector_index = VectorStoreIndex.from_documents(docs, service_context=service_context)
vector_query_engine = vector_index.as_query_engine(similarity_top_k=rag_params.top_k)
all_tools = []
vector_tool = QueryEngineTool(
query_engine=vector_query_engine,
metadata=ToolMetadata(
name="vector_tool",
description=("Use this tool to answer any user question over any data."),
),
)
all_tools.append(vector_tool)
if rag_params.include_summarization:
summary_index = SummaryIndex.from_documents(docs, service_context=service_context)
summary_query_engine = summary_index.as_query_engine()
summary_tool = QueryEngineTool(
query_engine=summary_query_engine,
metadata=ToolMetadata(
name="summary_tool",
description=("Use this tool for any user questions that ask for a summarization of content"),
),
)
all_tools.append(summary_tool)
# then we add tools
all_tools.extend(self._cache.tools)
# build agent
if self._cache.system_prompt is None:
return "System prompt not set yet. Please set system prompt first."
raise ValueError("Must set system prompt before creating agent.")
agent = load_agent(
all_tools, llm=llm, system_prompt=self._cache.system_prompt, verbose=True,
extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}
agent, extra_info = construct_agent(
cast(str, self._cache.system_prompt),
cast(RAGParams, self._cache.rag_params),
self._cache.docs,
additional_tools=self._cache.tools,
)
# if agent_id not specified, randomly generate one
agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
self._cache.vector_index = extra_info["vector_index"]
self._cache.agent_id = agent_id
self._cache.agent = agent
# save the cache to disk
agent_cache_path = f"{self._cache_dir}/{agent_id}"
self._cache.save_to_disk(agent_cache_path)
# save to agent ids
add_agent_id_to_directory(str(self._cache_dir), agent_id)
return "Agent created successfully."
def update_agent(
self,
agent_id: str,
system_prompt: Optional[str] = None,
include_summarization: Optional[bool] = None,
top_k: Optional[int] = None,
chunk_size: Optional[int] = None,
embed_model: Optional[str] = None,
llm: Optional[str] = None,
) -> None:
"""Update agent.
Delete old agent by ID and create a new one.
Optionally update the system prompt and RAG parameters.
NOTE: Currently is manually called, not meant for agent use.
"""
# remove saved agent from directory, since we'll be re-saving
remove_agent_from_directory(str(AGENT_CACHE_DIR), self.cache.agent_id)
# set agent id
self.cache.agent_id = agent_id
# set system prompt
if system_prompt is not None:
self.cache.system_prompt = system_prompt
# get agent_builder
# We call set_rag_params and create_agent, which will
# update the cache
# TODO: decouple functions from tool functions exposed to the agent
rag_params_dict: Dict[str, Any] = {}
if include_summarization is not None:
rag_params_dict["include_summarization"] = include_summarization
if top_k is not None:
rag_params_dict["top_k"] = top_k
if chunk_size is not None:
rag_params_dict["chunk_size"] = chunk_size
if embed_model is not None:
rag_params_dict["embed_model"] = embed_model
if llm is not None:
rag_params_dict["llm"] = llm
self.set_rag_params(**rag_params_dict)
# this will update the agent in the cache
self.create_agent()
####################
#### META Agent ####
@@ -339,6 +639,7 @@ You should generally use the tools in this rough order to build the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
@@ -355,25 +656,26 @@ have available (e.g. "Do you want to set the number of documents to retrieve?")
# define agent
@st.cache_resource
def load_meta_agent_and_tools() -> Tuple[OpenAIAgent, RAGAgentBuilder]:
# @st.cache_resource
def load_meta_agent_and_tools(
cache: Optional[ParamCache] = None,
) -> Tuple[BaseAgent, RAGAgentBuilder]:
# think of this as tools for the agent to use
agent_builder = RAGAgentBuilder()
agent_builder = RAGAgentBuilder(cache)
fns = [
agent_builder.create_system_prompt,
agent_builder.load_data,
fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
# add_web_tool,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent
agent_builder.create_agent,
]
fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]
builder_agent = load_agent(
builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
)
return builder_agent, agent_builder
+2 -1
View File
@@ -7,6 +7,7 @@ import os
## OpenAI
from llama_index.llms import OpenAI
# set OpenAI Key - use Streamlit secrets
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
# load LLM
@@ -16,4 +17,4 @@ BUILDER_LLM = OpenAI(model="gpt-4-1106-preview")
# from llama_index.llms import Anthropic
# # set Anthropic key
# os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
# BUILDER_LLM = Anthropic()
# BUILDER_LLM = Anthropic()
+4
View File
@@ -0,0 +1,4 @@
from pathlib import Path
AGENT_CACHE_DIR = Path(__file__).parent / "cache" / "agents"
MESSAGES_CACHE_DIR = Path(__file__).parent / "cache" / "messages"
+103 -37
View File
@@ -1,13 +1,16 @@
"""Streamlit page showing builder config."""
import streamlit as st
import openai
from streamlit_pills import pills
from typing import cast
from typing import cast, Optional
from agent_utils import (
RAGParams,
RAGAgentBuilder,
ParamCache,
remove_agent_from_directory,
)
from st_utils import update_selected_agent_with_id
from constants import AGENT_CACHE_DIR
from st_utils import add_sidebar
####################
@@ -15,53 +18,116 @@ from agent_utils import (
####################
def update_agent() -> None:
"""Update agent."""
if (
"config_agent_builder" in st.session_state.keys()
and st.session_state.config_agent_builder is not None
):
agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
### Update the agent
agent_builder.update_agent(
st.session_state.agent_id_st,
system_prompt=st.session_state.sys_prompt_st,
include_summarization=st.session_state.include_summarization_st,
top_k=st.session_state.top_k_st,
chunk_size=st.session_state.chunk_size_st,
embed_model=st.session_state.embed_model_st,
llm=st.session_state.llm_st,
)
st.set_page_config(page_title="RAG Pipeline Config", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
st.title("RAG Pipeline Config")
st.info(
"This is generated by the builder in the above section.", icon=""
# Update Radio Buttons: update selected agent to the new id
update_selected_agent_with_id(agent_builder.cache.agent_id)
else:
raise ValueError("Agent builder is None. Cannot update agent.")
def delete_agent() -> None:
"""Delete agent."""
if (
"config_agent_builder" in st.session_state.keys()
and st.session_state.config_agent_builder is not None
):
agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
### Delete agent
# remove saved agent from directory
remove_agent_from_directory(str(AGENT_CACHE_DIR), agent_builder.cache.agent_id)
# Update Radio Buttons: update selected agent to the new id
update_selected_agent_with_id(None)
else:
raise ValueError("Agent builder is None. Cannot delete agent.")
st.set_page_config(
page_title="RAG Pipeline Config",
page_icon="🦙",
layout="centered",
initial_sidebar_state="auto",
menu_items=None,
)
st.title("RAG Pipeline Config")
add_sidebar()
if "agent_builder" in st.session_state.keys():
# first, pick the cache: this is preloaded from an existing agent,
# or is part of the current one being created
if (
"selected_cache" in st.session_state.keys()
and st.session_state.selected_cache is not None
):
cache = cast(ParamCache, st.session_state.selected_cache)
agent_builder: Optional[RAGAgentBuilder] = RAGAgentBuilder(cache)
elif "agent_builder" in st.session_state.keys():
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
else:
agent_builder = None
# set as session state
st.session_state.config_agent_builder = agent_builder
if agent_builder is not None:
st.info(f"Viewing config for agent: {agent_builder.cache.agent_id}", icon="")
agent_id_st = st.text_input(
"Agent ID", value=agent_builder.cache.agent_id, key="agent_id_st"
)
if agent_builder.cache.system_prompt is None:
system_prompt = ""
else:
system_prompt = agent_builder.cache.system_prompt
sys_prompt_st = st.text_area("System Prompt", value=system_prompt)
sys_prompt_st = st.text_area(
"System Prompt", value=system_prompt, key="sys_prompt_st"
)
rag_params = cast(RAGParams, agent_builder.cache.rag_params)
file_paths = st.text_input(
"File/URL paths (not editable)",
value=",".join(agent_builder.cache.file_paths),
disabled=True
file_names = st.text_input(
"File names (not editable)",
value=",".join(agent_builder.cache.file_names),
disabled=True,
)
include_summarization_st = st.checkbox("Include Summarization (only works for GPT-4)", value=rag_params.include_summarization)
top_k_st = st.number_input("Top K", value=rag_params.top_k)
chunk_size_st = st.number_input("Chunk Size", value=rag_params.chunk_size)
embed_model_st = st.text_input("Embed Model", value=rag_params.embed_model)
llm_st = st.text_input("LLM", value=rag_params.llm)
urls = st.text_input(
"URLs (not editable)", value=",".join(agent_builder.cache.urls), disabled=True
)
include_summarization_st = st.checkbox(
"Include Summarization (only works for GPT-4)",
value=rag_params.include_summarization,
key="include_summarization_st",
)
top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
chunk_size_st = st.number_input(
"Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"
)
embed_model_st = st.text_input(
"Embed Model", value=rag_params.embed_model, key="embed_model_st"
)
llm_st = st.text_input("LLM", value=rag_params.llm, key="llm_st")
if agent_builder.cache.agent is not None:
if st.button("Update Agent"):
# update the agent
agent_builder.cache.system_prompt = sys_prompt_st
# get agent_builder
# We call set_rag_params and create_agent, which will
# update the cache
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
# TODO: decouple functions from tool functions exposed to the agent
agent_builder.set_rag_params(
include_summarization=include_summarization_st,
top_k=top_k_st,
chunk_size=chunk_size_st,
embed_model=embed_model_st,
llm=llm_st,
)
# this will update the agent in the cache
agent_builder.create_agent()
st.button("Update Agent", on_click=update_agent)
st.button(":red[Delete Agent]", on_click=delete_agent)
else:
# show text saying "agent not created"
st.info("Agent not created. Please create an agent in the above section.")
else:
st.info("agent builder not created yet. Please describe your task in the above section.")
st.info("No agent builder found. Please create an agent in the above section.")
+52 -34
View File
@@ -1,11 +1,8 @@
"""Streamlit page showing builder config."""
import streamlit as st
from typing import cast
from agent_utils import (
RAGAgentBuilder,
)
from streamlit_pills import pills
from typing import cast, Optional
from agent_utils import RAGAgentBuilder, ParamCache
from st_utils import add_sidebar
####################
@@ -13,46 +10,67 @@ from streamlit_pills import pills
####################
st.set_page_config(page_title="Generated RAG Agent", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
st.title("Generated RAG Agent")
st.info(
"This is generated by the builder in the above section.", icon=""
st.set_page_config(
page_title="Generated RAG Agent",
page_icon="🦙",
layout="centered",
initial_sidebar_state="auto",
menu_items=None,
)
st.title("Generated RAG Agent")
if "agent_messages" not in st.session_state.keys(): # Initialize the chat messages history
add_sidebar()
if (
"agent_messages" not in st.session_state.keys()
): # Initialize the chat messages history
st.session_state.agent_messages = [
{"role": "assistant", "content": "Ask me a question!"}
]
def add_to_message_history(role, content):
def add_to_message_history(role: str, content: str) -> None:
message = {"role": role, "content": str(content)}
st.session_state.agent_messages.append(message) # Add response to message history
st.session_state.agent_messages.append(message) # Add response to message history
# first, pick the cache: this is preloaded from an existing agent,
# or is part of the current one being created
agent = None
if "agent_builder" in st.session_state.keys():
if (
"selected_cache" in st.session_state.keys()
and st.session_state.selected_cache is not None
):
cache: Optional[ParamCache] = cast(ParamCache, st.session_state.selected_cache)
elif "agent_builder" in st.session_state.keys():
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
if agent_builder.cache.agent is not None:
agent = agent_builder.cache.agent
for message in st.session_state.agent_messages: # Display the prior chat messages
with st.chat_message(message["role"]):
st.write(message["content"])
cache = agent_builder.cache
else:
cache = None
st.info("Agent not created. Please create an agent in the above section.")
# don't process selected for now
if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
add_to_message_history("user", prompt)
with st.chat_message("user"):
st.write(prompt)
# if agent is created, then we can chat with it
if cache is not None and cache.agent is not None:
st.info(f"Viewing config for agent: {cache.agent_id}", icon="")
agent = cache.agent
for message in st.session_state.agent_messages: # Display the prior chat messages
with st.chat_message(message["role"]):
st.write(message["content"])
# If last message is not from assistant, generate a new response
if st.session_state.agent_messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
response = agent.chat(prompt)
st.write(str(response))
add_to_message_history("assistant", response)
else:
st.info("Agent not created. Please create an agent in the above section.")
# don't process selected for now
if prompt := st.chat_input(
"Your question"
): # Prompt for user input and save to chat history
add_to_message_history("user", prompt)
with st.chat_message("user"):
st.write(prompt)
# If last message is not from assistant, generate a new response
if st.session_state.agent_messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
response = agent.chat(str(prompt))
st.write(str(response))
add_to_message_history("assistant", str(response))
else:
st.info("Agent not created. Please create an agent in the above section.")
+65
View File
@@ -0,0 +1,65 @@
[tool.poetry]
name = "rags"
version = "0.0.2"
description = "Build RAG with natural language."
authors = ["Jerry Liu"]
# New attributes
license = "MIT"
readme = "README.md"
homepage = "https://docs.llamaindex.ai/en/latest/"
repository = "https://github.com/run-llama/rags"
keywords = ["llama-index", "rags"]
include = [
"LICENSE",
]
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7"
streamlit = "1.28.0"
streamlit-pills = "0.3.0"
llama-index = "0.9.7"
llama-hub = "0.0.44"
# NOTE: this is due to a trivial dependency in the web tool, will refactor
langchain = "0.0.305"
pypdf = "3.17.1"
[tool.poetry.dev-dependencies]
# pytest = "7.2.1"
# pytest-dotenv = "0.5.2"
# pytest_httpserver = "1.0.8"
# pytest-mock = "3.11.1"
typing-inspect = "0.8.0"
typing_extensions = "^4.5.0"
types-requests = "2.28.11.8"
black = "22.12.0"
isort = "5.11.4"
pytest-asyncio = "^0.21.1"
ruff = "0.0.285"
mypy = "0.991"
[build-system]
requires = ["poetry>=0.12", "poetry-core>=1.0.0"]
build-backend = "poetry.masonry.api"
[tool.mypy]
disallow_untyped_defs = true
ignore_missing_imports = true
exclude = ["notebooks", "build", "examples"]
[tool.ruff]
# Allow lines to be as long as 80 characters.
# TODO: it should be removed, but we need to fix the entire code first.
line-length = 88
exclude = [
".venv",
"__pycache__",
".ipynb_checkpoints",
".mypy_cache",
".ruff_cache",
"examples",
"notebooks",
".git"
]
[tool.ruff.per-file-ignores]
"base.py" = ["E402", "F811", "E501"]
+58
View File
@@ -0,0 +1,58 @@
"""Streamlit utils."""
from agent_utils import (
load_agent_ids_from_directory,
load_cache_from_directory,
)
from constants import (
AGENT_CACHE_DIR,
)
from typing import Optional
import streamlit as st
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
"""Update selected agent with id."""
# set session state
st.session_state.selected_id = (
selected_id if selected_id != "Create a new agent" else None
)
if st.session_state.selected_id is None:
st.session_state.selected_cache = None
else:
# load agent from directory
agent_cache = load_cache_from_directory(
str(AGENT_CACHE_DIR), st.session_state.selected_id
)
st.session_state.selected_cache = agent_cache
## handler for sidebar specifically
def update_selected_agent() -> None:
"""Update selected agent."""
selected_id = st.session_state.agent_selector
update_selected_agent_with_id(selected_id)
def add_sidebar() -> None:
"""Add sidebar."""
with st.sidebar:
st.session_state.cur_agent_ids = load_agent_ids_from_directory(
str(AGENT_CACHE_DIR)
)
choices = ["Create a new agent"] + st.session_state.cur_agent_ids
# by default, set index to 0. if value is in selected_id, set index to that
index = 0
if "selected_id" in st.session_state.keys():
if st.session_state.selected_id is not None:
index = choices.index(st.session_state.selected_id)
# display buttons
st.radio(
"Agents",
choices,
index=index,
on_change=update_selected_agent,
key="agent_selector",
)
View File