mirror of
https://github.com/run-llama/rags.git
synced 2026-07-01 20:54:00 -04:00
upgrade RAGs (#27)
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
name: Linting
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
# You can use PyPy versions in python-version.
|
||||
# For example, pypy-2.7 and pypy-3.8
|
||||
matrix:
|
||||
python-version: ["3.9"]
|
||||
poetry-version: [1.5.1]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Run image
|
||||
uses: abatilo/actions-poetry@v2.0.0
|
||||
with:
|
||||
poetry-version: ${{ matrix.poetry-version }}
|
||||
- name: Install deps
|
||||
run: |
|
||||
poetry install --with dev
|
||||
- name: Run Linting
|
||||
run: poetry run make lint
|
||||
+65
-26
@@ -3,6 +3,11 @@ from streamlit_pills import pills
|
||||
|
||||
from agent_utils import (
|
||||
load_meta_agent_and_tools,
|
||||
load_agent_ids_from_directory,
|
||||
)
|
||||
from st_utils import add_sidebar
|
||||
from constants import (
|
||||
AGENT_CACHE_DIR,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,23 +16,41 @@ from agent_utils import (
|
||||
####################
|
||||
|
||||
|
||||
st.set_page_config(page_title="Build a RAGs bot, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
|
||||
st.set_page_config(
|
||||
page_title="Build a RAGs bot, powered by LlamaIndex",
|
||||
page_icon="🦙",
|
||||
layout="centered",
|
||||
initial_sidebar_state="auto",
|
||||
menu_items=None,
|
||||
)
|
||||
st.title("Build a RAGs bot, powered by LlamaIndex 💬🦙")
|
||||
st.info(
|
||||
"Use this page to build your RAG bot over your data! "
|
||||
"Once the agent is finished creating, check out the `RAG Config` and `Generated RAG Agent` pages.",
|
||||
icon="ℹ️"
|
||||
"Once the agent is finished creating, check out the `RAG Config` and "
|
||||
"`Generated RAG Agent` pages.\n"
|
||||
"To build a new agent, please make sure that 'Create a new agent' is selected.",
|
||||
icon="ℹ️",
|
||||
)
|
||||
|
||||
# TODO: noodle on this
|
||||
# with st.sidebar:
|
||||
# openai_api_key_st = st.text_input("OpenAI API Key (optional, not needed if you filled in secrets.toml)", value="", type="password")
|
||||
# if st.button("Save"):
|
||||
# # save api key
|
||||
# st.session_state.openai_api_key = openai_api_key_st
|
||||
|
||||
#### load builder agent and its tool spec (the agent_builder)
|
||||
builder_agent, agent_builder = load_meta_agent_and_tools()
|
||||
add_sidebar()
|
||||
|
||||
|
||||
if (
|
||||
"selected_cache" in st.session_state.keys()
|
||||
and st.session_state.selected_cache is not None
|
||||
):
|
||||
# create builder agent / tools from selected cache
|
||||
builder_agent, agent_builder = load_meta_agent_and_tools(
|
||||
cache=st.session_state.selected_cache
|
||||
)
|
||||
else:
|
||||
# create builder agent / tools from new cache
|
||||
builder_agent, agent_builder = load_meta_agent_and_tools()
|
||||
|
||||
|
||||
st.info(f"Currently building/editing agent: {agent_builder.cache.agent_id}", icon="ℹ️")
|
||||
|
||||
|
||||
if "builder_agent" not in st.session_state.keys():
|
||||
st.session_state.builder_agent = builder_agent
|
||||
@@ -36,27 +59,34 @@ if "agent_builder" not in st.session_state.keys():
|
||||
|
||||
# add pills
|
||||
selected = pills(
|
||||
"Outline your task!",
|
||||
["I want to analyze this PDF file (data/invoices.pdf)",
|
||||
"I want to search over my CSV documents."
|
||||
], clearable=True, index=None
|
||||
"Outline your task!",
|
||||
[
|
||||
"I want to analyze this PDF file (data/invoices.pdf)",
|
||||
"I want to search over my CSV documents.",
|
||||
],
|
||||
clearable=True,
|
||||
index=None,
|
||||
)
|
||||
|
||||
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
|
||||
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
|
||||
st.session_state.messages = [
|
||||
{"role": "assistant", "content": "What RAG bot do you want to build?"}
|
||||
]
|
||||
|
||||
def add_to_message_history(role, content):
|
||||
message = {"role": role, "content": str(content)}
|
||||
st.session_state.messages.append(message) # Add response to message history
|
||||
|
||||
for message in st.session_state.messages: # Display the prior chat messages
|
||||
def add_to_message_history(role: str, content: str) -> None:
|
||||
message = {"role": role, "content": str(content)}
|
||||
st.session_state.messages.append(message) # Add response to message history
|
||||
|
||||
|
||||
for message in st.session_state.messages: # Display the prior chat messages
|
||||
with st.chat_message(message["role"]):
|
||||
st.write(message["content"])
|
||||
|
||||
# handle user input
|
||||
if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
|
||||
if prompt := st.chat_input(
|
||||
"Your question"
|
||||
): # Prompt for user input and save to chat history
|
||||
add_to_message_history("user", prompt)
|
||||
with st.chat_message("user"):
|
||||
st.write(prompt)
|
||||
@@ -67,9 +97,18 @@ if st.session_state.messages[-1]["role"] != "assistant":
|
||||
with st.spinner("Thinking..."):
|
||||
response = st.session_state.builder_agent.chat(prompt)
|
||||
st.write(str(response))
|
||||
add_to_message_history("assistant", response)
|
||||
add_to_message_history("assistant", str(response))
|
||||
|
||||
# # check cache
|
||||
print(st.session_state.agent_builder.cache)
|
||||
# if "agent" in cache:
|
||||
# st.session_state.agent = cache["agent"]
|
||||
# check agent_ids again, if it doesn't match, add to directory and refresh
|
||||
agent_ids = load_agent_ids_from_directory(str(AGENT_CACHE_DIR))
|
||||
# check diff between agent_ids and cur agent ids
|
||||
diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
|
||||
if len(diff_ids) > 0:
|
||||
# clear streamlit cache, to allow you to generate a new agent
|
||||
st.cache_resource.clear()
|
||||
|
||||
# trigger refresh
|
||||
st.rerun()
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
# All targets are commands, not files; declare every one of them phony so a
# file or directory named e.g. `test` can never shadow the target.
.PHONY: format lint test

# Rewrite source files in place with black.
format:
	black .

# Static checks: type-check, formatting check (no rewrite), then ruff lint.
lint:
	mypy .
	black --check .
	ruff check .

# Run the test suite under tests/.
test:
	pytest tests
|
||||
@@ -11,10 +11,10 @@ This project is inspired by [GPTs](https://openai.com/blog/introducing-gpts), la
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Clone this project, go into the `rags` project folder.
|
||||
Clone this project, go into the `rags` project folder. We recommend creating a virtual env for dependencies (`python3 -m venv .venv`).
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
poetry install --with dev
|
||||
```
|
||||
|
||||
By default, we use OpenAI for both the builder agent as well as the generated RAG agent.
|
||||
|
||||
+435
-133
@@ -3,14 +3,15 @@ from llama_index.llms.base import LLM
|
||||
from llama_index.llms.utils import resolve_llm
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from llama_index.tools.query_engine import QueryEngineTool
|
||||
from llama_index.agent import OpenAIAgent, ReActAgent
|
||||
from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
|
||||
from llama_index import (
|
||||
VectorStoreIndex,
|
||||
SummaryIndex,
|
||||
ServiceContext,
|
||||
Document
|
||||
StorageContext,
|
||||
Document,
|
||||
load_index_from_storage,
|
||||
)
|
||||
from llama_index.prompts import ChatPromptTemplate
|
||||
from typing import List, cast, Optional
|
||||
@@ -18,28 +19,32 @@ from llama_index import SimpleDirectoryReader
|
||||
from llama_index.embeddings.utils import resolve_embed_model
|
||||
from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool
|
||||
from llama_index.agent.types import BaseAgent
|
||||
from llama_index.chat_engine.types import BaseChatEngine
|
||||
from llama_index.agent.react.formatter import ReActChatFormatter
|
||||
from llama_index.llms.openai_utils import is_function_calling_model
|
||||
from llama_index.chat_engine import CondensePlusContextChatEngine
|
||||
from builder_config import BUILDER_LLM
|
||||
from typing import Dict, Tuple, Any
|
||||
from typing import Dict, Tuple, Any, Callable
|
||||
import streamlit as st
|
||||
from pathlib import Path
|
||||
import json
|
||||
import uuid
|
||||
from constants import AGENT_CACHE_DIR
|
||||
import shutil
|
||||
|
||||
|
||||
def _resolve_llm(llm: str) -> LLM:
|
||||
def _resolve_llm(llm_str: str) -> LLM:
|
||||
"""Resolve LLM."""
|
||||
# TODO: make this less hardcoded with if-else statements
|
||||
# see if there's a prefix
|
||||
# - if there isn't, assume it's an OpenAI model
|
||||
# - if there is, resolve it
|
||||
tokens = llm.split(":")
|
||||
tokens = llm_str.split(":")
|
||||
if len(tokens) == 1:
|
||||
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
||||
llm = OpenAI(model=llm)
|
||||
llm: LLM = OpenAI(model=llm_str)
|
||||
elif tokens[0] == "local":
|
||||
llm = resolve_llm(llm)
|
||||
llm = resolve_llm(llm_str)
|
||||
elif tokens[0] == "openai":
|
||||
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
||||
llm = OpenAI(model=tokens[1])
|
||||
@@ -50,7 +55,7 @@ def _resolve_llm(llm: str) -> LLM:
|
||||
os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key
|
||||
llm = Replicate(model=tokens[1])
|
||||
else:
|
||||
raise ValueError(f"LLM {llm} not recognized.")
|
||||
raise ValueError(f"LLM {llm_str} not recognized.")
|
||||
return llm
|
||||
|
||||
|
||||
@@ -63,14 +68,17 @@ def _resolve_llm(llm: str) -> LLM:
|
||||
GEN_SYS_PROMPT_STR = """\
|
||||
Task information is given below.
|
||||
|
||||
Given the task, please generate a system prompt for an OpenAI-powered bot to solve this task:
|
||||
Given the task, please generate a system prompt for an OpenAI-powered bot \
|
||||
to solve this task:
|
||||
{task} \
|
||||
|
||||
Make sure the system prompt obeys the following requirements:
|
||||
- Tells the bot to ALWAYS use tools given to solve the task. NEVER give an answer without using a tool.
|
||||
- Does not reference a specific data source. The data source is implicit in any queries to the bot,
|
||||
and telling the bot to analyze a specific data source might confuse it given a
|
||||
user query.
|
||||
- Tells the bot to ALWAYS use tools given to solve the task. \
|
||||
NEVER give an answer without using a tool.
|
||||
- Does not reference a specific data source. \
|
||||
The data source is implicit in any queries to the bot, \
|
||||
and telling the bot to analyze a specific data source might confuse it given a \
|
||||
user query.
|
||||
|
||||
"""
|
||||
|
||||
@@ -85,49 +93,196 @@ gen_sys_prompt_messages = [
|
||||
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
|
||||
|
||||
|
||||
class RAGParams(BaseModel):
|
||||
"""RAG parameters.
|
||||
|
||||
Parameters used to configure a RAG pipeline.
|
||||
|
||||
"""
|
||||
|
||||
include_summarization: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to include summarization in the RAG pipeline. (only for GPT-4)"
|
||||
),
|
||||
)
|
||||
top_k: int = Field(
|
||||
default=2, description="Number of documents to retrieve from vector store."
|
||||
)
|
||||
chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
|
||||
embed_model: str = Field(
|
||||
default="default", description="Embedding model to use (default is OpenAI)"
|
||||
)
|
||||
llm: str = Field(
|
||||
default="gpt-4-1106-preview", description="LLM to use for summarization."
|
||||
)
|
||||
|
||||
|
||||
def load_data(
    file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
) -> List[Document]:
    """Load documents from local files or from web pages.

    Exactly one of ``file_names`` / ``urls`` must be non-empty; ``None`` and
    an empty list are treated the same way.

    Args:
        file_names: Paths of files to read with ``SimpleDirectoryReader``.
        urls: Web page URLs to read with ``SimpleWebPageReader``.

    Returns:
        The documents returned by the selected reader.

    Raises:
        ValueError: If neither source or both sources are specified.
    """
    file_names = file_names or []
    urls = urls or []

    # Guard clauses: after these two checks exactly one source is non-empty.
    if not file_names and not urls:
        raise ValueError("Must specify either file_names or urls.")
    if file_names and urls:
        raise ValueError("Must specify only one of file_names or urls.")

    if file_names:
        return SimpleDirectoryReader(input_files=file_names).load_data()

    # Imported here so llama_hub is only needed when URLs are actually loaded.
    from llama_hub.web.simple_web.base import SimpleWebPageReader

    # use simple web page reader from llamahub
    return SimpleWebPageReader().load_data(urls=urls)
|
||||
|
||||
|
||||
def load_agent(
|
||||
tools: List,
|
||||
llm: LLM,
|
||||
tools: List,
|
||||
llm: LLM,
|
||||
system_prompt: str,
|
||||
extra_kwargs: Optional[Dict] = None,
|
||||
**kwargs: Any
|
||||
) -> BaseAgent:
|
||||
**kwargs: Any,
|
||||
) -> BaseChatEngine:
|
||||
"""Load agent."""
|
||||
extra_kwargs = extra_kwargs or {}
|
||||
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
|
||||
# get OpenAI Agent
|
||||
agent = OpenAIAgent.from_tools(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
system_prompt=system_prompt,
|
||||
**kwargs
|
||||
agent: BaseChatEngine = OpenAIAgent.from_tools(
|
||||
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
|
||||
)
|
||||
else:
|
||||
if "vector_index" not in extra_kwargs:
|
||||
raise ValueError("Must pass in vector index for CondensePlusContextChatEngine.")
|
||||
raise ValueError(
|
||||
"Must pass in vector index for CondensePlusContextChatEngine."
|
||||
)
|
||||
vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"])
|
||||
rag_params = cast(RAGParams, extra_kwargs["rag_params"])
|
||||
# use condense + context chat engine
|
||||
agent = CondensePlusContextChatEngine.from_defaults(
|
||||
vector_index.as_retriever(similarity_top_k=rag_params.top_k),
|
||||
)
|
||||
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
class RAGParams(BaseModel):
|
||||
"""RAG parameters.
|
||||
def load_meta_agent(
|
||||
tools: List,
|
||||
llm: LLM,
|
||||
system_prompt: str,
|
||||
extra_kwargs: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseAgent:
|
||||
"""Load meta agent.
|
||||
|
||||
TODO: consolidate with load_agent.
|
||||
|
||||
The meta-agent *has* to perform tool-use.
|
||||
|
||||
Parameters used to configure a RAG pipeline.
|
||||
|
||||
"""
|
||||
include_summarization: bool = Field(default=False, description="Whether to include summarization in the RAG pipeline. (only for GPT-4)")
|
||||
top_k: int = Field(default=2, description="Number of documents to retrieve from vector store.")
|
||||
chunk_size: int = Field(default=1024, description="Chunk size for vector store.")
|
||||
embed_model: str = Field(
|
||||
default="default", description="Embedding model to use (default is OpenAI)"
|
||||
extra_kwargs = extra_kwargs or {}
|
||||
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
|
||||
# get OpenAI Agent
|
||||
agent: BaseAgent = OpenAIAgent.from_tools(
|
||||
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
|
||||
)
|
||||
else:
|
||||
agent = ReActAgent.from_tools(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
react_chat_formatter=ReActChatFormatter(
|
||||
system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def construct_agent(
    system_prompt: str,
    rag_params: RAGParams,
    docs: List[Document],
    vector_index: Optional[VectorStoreIndex] = None,
    additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
    """Construct agent from docs / parameters / indices.

    Resolves the embedding model and LLM named in ``rag_params``, builds a
    vector index over ``docs`` unless one is supplied, exposes it (plus an
    optional summary index) as query-engine tools, and passes the tools to
    ``load_agent``.

    Args:
        system_prompt: System prompt for the generated agent.
        rag_params: RAG parameters (chunk size, top-k, models, summarization).
        docs: Documents to index.
        vector_index: Existing vector index to reuse instead of re-indexing
            ``docs``.
        additional_tools: Extra tools appended after the built-in ones.

    Returns:
        ``(agent, extra_info)`` where ``extra_info["vector_index"]`` is the
        vector index the agent was built on.

    Raises:
        ValueError: If ``system_prompt`` is None.
    """
    # Fail fast, before any index is built. Returning a message string here
    # (as before) broke callers that unpack the ``(agent, extra_info)`` tuple.
    if system_prompt is None:
        raise ValueError("System prompt not set yet. Please set system prompt first.")

    extra_info: Dict[str, Any] = {}
    additional_tools = additional_tools or []

    # first resolve llm and embedding model
    embed_model = resolve_embed_model(rag_params.embed_model)
    llm = _resolve_llm(rag_params.llm)

    # first let's index the data with the right parameters
    service_context = ServiceContext.from_defaults(
        chunk_size=rag_params.chunk_size,
        llm=llm,
        embed_model=embed_model,
    )

    # Only index the documents when the caller did not hand us an index.
    if vector_index is None:
        vector_index = VectorStoreIndex.from_documents(
            docs, service_context=service_context
        )
    extra_info["vector_index"] = vector_index

    vector_query_engine = vector_index.as_query_engine(
        similarity_top_k=rag_params.top_k
    )
    all_tools: List = [
        QueryEngineTool(
            query_engine=vector_query_engine,
            metadata=ToolMetadata(
                name="vector_tool",
                description=("Use this tool to answer any user question over any data."),
            ),
        )
    ]

    if rag_params.include_summarization:
        summary_index = SummaryIndex.from_documents(
            docs, service_context=service_context
        )
        summary_query_engine = summary_index.as_query_engine()
        all_tools.append(
            QueryEngineTool(
                query_engine=summary_query_engine,
                metadata=ToolMetadata(
                    name="summary_tool",
                    description=(
                        "Use this tool for any user questions that ask "
                        "for a summarization of content"
                    ),
                ),
            )
        )

    # then we add tools
    all_tools.extend(additional_tools)

    # build agent
    agent = load_agent(
        all_tools,
        llm=llm,
        system_prompt=system_prompt,
        verbose=True,
        extra_kwargs={"vector_index": vector_index, "rag_params": rag_params},
    )
    return agent, extra_info
|
||||
|
||||
|
||||
class ParamCache(BaseModel):
|
||||
@@ -135,20 +290,163 @@ class ParamCache(BaseModel):
|
||||
|
||||
Created a wrapper class around a dict in case we wanted to more explicitly
|
||||
type different items in the cache.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# arbitrary types
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
system_prompt: Optional[str] = Field(default=None, description="System prompt for RAG agent.")
|
||||
file_paths: List[str] = Field(default_factory=list, description="File paths for RAG agent.")
|
||||
docs: List[Document] = Field(default_factory=list, description="Documents for RAG agent.")
|
||||
tools: List = Field(default_factory=list, description="Additional tools for RAG agent (e.g. web)")
|
||||
rag_params: RAGParams = Field(default_factory=RAGParams, description="RAG parameters for RAG agent.")
|
||||
agent: Optional[OpenAIAgent] = Field(default=None, description="RAG agent.")
|
||||
# system prompt
|
||||
system_prompt: Optional[str] = Field(
|
||||
default=None, description="System prompt for RAG agent."
|
||||
)
|
||||
# data
|
||||
file_names: List[str] = Field(
|
||||
default_factory=list, description="File names as data source (if specified)"
|
||||
)
|
||||
urls: List[str] = Field(
|
||||
default_factory=list, description="URLs as data source (if specified)"
|
||||
)
|
||||
docs: List = Field(default_factory=list, description="Documents for RAG agent.")
|
||||
# tools
|
||||
tools: List = Field(
|
||||
default_factory=list, description="Additional tools for RAG agent (e.g. web)"
|
||||
)
|
||||
# RAG params
|
||||
rag_params: RAGParams = Field(
|
||||
default_factory=RAGParams, description="RAG parameters for RAG agent."
|
||||
)
|
||||
|
||||
# agent params
|
||||
vector_index: Optional[VectorStoreIndex] = Field(
|
||||
default=None, description="Vector index for RAG agent."
|
||||
)
|
||||
agent_id: str = Field(
|
||||
default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
|
||||
description="Agent ID for RAG agent.",
|
||||
)
|
||||
agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")
|
||||
|
||||
def save_to_disk(self, save_dir: str) -> None:
    """Save cache to disk.

    Writes the vector index's storage context to ``<save_dir>/storage`` and
    the JSON-serializable fields to ``<save_dir>/cache.json``.

    Args:
        save_dir: Directory to write into; created (with parents) if needed.

    Raises:
        ValueError: If no vector index has been set on the cache.
    """
    # Validate before touching the filesystem so a failed save leaves
    # nothing behind.
    if self.vector_index is None:
        raise ValueError("Must specify vector index in order to save.")

    # Create the target directory up front; exist_ok avoids the
    # check-then-create race of `if not exists(): mkdir()`.
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    # NOTE: more complex than just calling dict() because we want to
    # only store serializable fields and be space-efficient
    dict_to_serialize = {
        "system_prompt": self.system_prompt,
        "file_names": self.file_names,
        "urls": self.urls,
        # TODO: figure out tools
        # "tools": [],
        "rag_params": self.rag_params.dict(),
        "agent_id": self.agent_id,
    }

    # store the vector store within the agent (str path, matching the
    # persist_dir that load_from_disk reads back)
    self.vector_index.storage_context.persist(str(save_path / "storage"))

    with open(save_path / "cache.json", "w") as f:
        json.dump(dict_to_serialize, f)
|
||||
|
||||
@classmethod
|
||||
def load_from_disk(
|
||||
cls,
|
||||
save_dir: str,
|
||||
) -> "ParamCache":
|
||||
"""Load cache from disk."""
|
||||
storage_context = StorageContext.from_defaults(
|
||||
persist_dir=str(Path(save_dir) / "storage")
|
||||
)
|
||||
vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
|
||||
|
||||
with open(Path(save_dir) / "cache.json", "r") as f:
|
||||
cache_dict = json.load(f)
|
||||
|
||||
# replace rag params with RAGParams object
|
||||
cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
|
||||
|
||||
# add in the missing fields
|
||||
# load docs
|
||||
cache_dict["docs"] = load_data(
|
||||
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
|
||||
)
|
||||
# load agent from index
|
||||
agent, _ = construct_agent(
|
||||
cache_dict["system_prompt"],
|
||||
cache_dict["rag_params"],
|
||||
cache_dict["docs"],
|
||||
vector_index=vector_index,
|
||||
# TODO: figure out tools
|
||||
)
|
||||
cache_dict["vector_index"] = vector_index
|
||||
cache_dict["agent"] = agent
|
||||
|
||||
return cls(**cache_dict)
|
||||
|
||||
|
||||
def add_agent_id_to_directory(dir: str, agent_id: str) -> None:
    """Save agent id to directory.

    Appends ``agent_id`` to the ``agent_ids`` list stored in
    ``<dir>/agent_ids.json``, creating the directory and file if needed.

    Args:
        dir: Directory holding ``agent_ids.json``.
        agent_id: ID to register.

    Raises:
        ValueError: If ``agent_id`` is already registered.
    """
    full_path = Path(dir) / "agent_ids.json"

    if full_path.exists():
        with open(full_path, "r") as f:
            agent_ids = json.load(f)["agent_ids"]
        if agent_id in agent_ids:
            raise ValueError(f"Agent id {agent_id} already exists.")
    else:
        # First registration: make sure the directory exists before writing.
        full_path.parent.mkdir(parents=True, exist_ok=True)
        agent_ids = []

    # dict.fromkeys de-duplicates like the previous set() round-trip did, but
    # keeps insertion order so the saved list is stable between runs.
    updated_ids = list(dict.fromkeys([*agent_ids, agent_id]))
    with open(full_path, "w") as f:
        json.dump({"agent_ids": updated_ids}, f)
|
||||
|
||||
|
||||
def load_agent_ids_from_directory(dir: str) -> List[str]:
|
||||
"""Load agent ids file."""
|
||||
full_path = Path(dir) / "agent_ids.json"
|
||||
if not full_path.exists():
|
||||
return []
|
||||
with open(full_path, "r") as f:
|
||||
agent_ids = json.load(f)["agent_ids"]
|
||||
|
||||
return agent_ids
|
||||
|
||||
|
||||
def load_cache_from_directory(
|
||||
dir: str,
|
||||
agent_id: str,
|
||||
) -> ParamCache:
|
||||
"""Load cache from directory."""
|
||||
full_path = Path(dir) / f"{agent_id}"
|
||||
if not full_path.exists():
|
||||
raise ValueError(f"Cache for agent {agent_id} does not exist.")
|
||||
cache = ParamCache.load_from_disk(str(full_path))
|
||||
return cache
|
||||
|
||||
|
||||
def remove_agent_from_directory(
    dir: str,
    agent_id: str,
) -> None:
    """Remove agent from directory.

    Drops ``agent_id`` from ``<dir>/agent_ids.json`` (when that file exists)
    and recursively deletes the agent's cache folder ``<dir>/<agent_id>``
    (when that folder exists). Removing an unknown agent is a no-op.

    Args:
        dir: Directory holding ``agent_ids.json`` and the per-agent folders.
        agent_id: ID of the agent to remove.
    """
    cache_root = Path(dir)
    ids_path = cache_root / "agent_ids.json"

    # modify / resave agent_ids -- only when a registry file is present, so a
    # missing cache directory neither raises nor gets a file created in it.
    if ids_path.exists():
        remaining_ids = [
            existing_id
            for existing_id in load_agent_ids_from_directory(dir)
            if existing_id != agent_id
        ]
        with open(ids_path, "w") as f:
            json.dump({"agent_ids": remaining_ids}, f)

    # remove agent cache
    agent_cache_path = cache_root / f"{agent_id}"
    if agent_cache_path.exists():
        # recursive delete
        shutil.rmtree(agent_cache_path)
|
||||
|
||||
|
||||
class RAGAgentBuilder:
|
||||
@@ -161,11 +459,15 @@ class RAGAgentBuilder:
|
||||
- setting parameters (e.g. top-k)
|
||||
|
||||
Must pass in a cache. This cache will be modified as the agent is built.
|
||||
|
||||
|
||||
"""
|
||||
def __init__(self, cache: Optional[ParamCache] = None) -> None:
|
||||
|
||||
def __init__(
|
||||
self, cache: Optional[ParamCache] = None, cache_dir: Optional[str] = None
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
self._cache = cache or ParamCache()
|
||||
self._cache_dir = cache_dir or AGENT_CACHE_DIR
|
||||
|
||||
@property
|
||||
def cache(self) -> ParamCache:
|
||||
@@ -181,11 +483,8 @@ class RAGAgentBuilder:
|
||||
|
||||
return f"System prompt created: {response.message.content}"
|
||||
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file_names: Optional[List[str]] = None,
|
||||
urls: Optional[List[str]] = None
|
||||
self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""Load data for a given task.
|
||||
|
||||
@@ -196,32 +495,18 @@ class RAGAgentBuilder:
|
||||
Defaults to None.
|
||||
urls (Optional[List[str]]): List of urls to load.
|
||||
Defaults to None.
|
||||
|
||||
|
||||
"""
|
||||
if file_names is None and urls is None:
|
||||
raise ValueError("Must specify either file_names or urls.")
|
||||
elif file_names is not None and urls is not None:
|
||||
raise ValueError("Must specify only one of file_names or urls.")
|
||||
elif file_names is not None:
|
||||
reader = SimpleDirectoryReader(input_files=file_names)
|
||||
docs = reader.load_data()
|
||||
file_paths = file_names
|
||||
elif urls is not None:
|
||||
from llama_hub.web.simple_web.base import SimpleWebPageReader
|
||||
# use simple web page reader from llamahub
|
||||
loader = SimpleWebPageReader()
|
||||
docs = loader.load_data(urls=urls)
|
||||
file_paths = urls
|
||||
else:
|
||||
raise ValueError("Must specify either file_names or urls.")
|
||||
|
||||
file_names = file_names or []
|
||||
urls = urls or []
|
||||
docs = load_data(file_names=file_names, urls=urls)
|
||||
self._cache.docs = docs
|
||||
self._cache.file_paths = file_paths
|
||||
self._cache.file_names = file_names
|
||||
self._cache.urls = urls
|
||||
return "Data loaded successfully."
|
||||
|
||||
|
||||
# NOTE: unused
|
||||
def add_web_tool(self) -> None:
|
||||
def add_web_tool(self) -> str:
|
||||
"""Add a web tool to enable agent to solve a task."""
|
||||
# TODO: make this not hardcoded to a web tool
|
||||
# Set up Metaphor tool
|
||||
@@ -241,21 +526,20 @@ class RAGAgentBuilder:
|
||||
|
||||
Should be called before `set_rag_params` so that the agent is aware of the
|
||||
schema.
|
||||
|
||||
|
||||
"""
|
||||
rag_params = self._cache.rag_params
|
||||
return rag_params.dict()
|
||||
|
||||
|
||||
def set_rag_params(self, **rag_params: Dict):
|
||||
def set_rag_params(self, **rag_params: Dict) -> str:
|
||||
"""Set RAG parameters.
|
||||
|
||||
These parameters will then be used to actually initialize the agent.
|
||||
Should call `get_rag_params` first to get the schema of the input dictionary.
|
||||
|
||||
Args:
|
||||
**rag_params (Dict): dictionary of RAG parameters.
|
||||
|
||||
**rag_params (Dict): dictionary of RAG parameters.
|
||||
|
||||
"""
|
||||
new_dict = self._cache.rag_params.dict()
|
||||
new_dict.update(rag_params)
|
||||
@@ -263,69 +547,85 @@ class RAGAgentBuilder:
|
||||
self._cache.rag_params = rag_params_obj
|
||||
return "RAG parameters set successfully."
|
||||
|
||||
|
||||
def create_agent(self) -> None:
|
||||
def create_agent(self, agent_id: Optional[str] = None) -> str:
|
||||
"""Create an agent.
|
||||
|
||||
There are no parameters for this function because all the
|
||||
functions should have already been called to set up the agent.
|
||||
|
||||
|
||||
"""
|
||||
rag_params = cast(RAGParams, self._cache.rag_params)
|
||||
docs = self._cache.docs
|
||||
|
||||
# first resolve llm and embedding model
|
||||
embed_model = resolve_embed_model(rag_params.embed_model)
|
||||
# llm = resolve_llm(rag_params.llm)
|
||||
# TODO: use OpenAI for now
|
||||
# llm = OpenAI(model=rag_params.llm)
|
||||
llm = _resolve_llm(rag_params.llm)
|
||||
|
||||
# first let's index the data with the right parameters
|
||||
service_context = ServiceContext.from_defaults(
|
||||
chunk_size=rag_params.chunk_size,
|
||||
llm=llm,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
vector_index = VectorStoreIndex.from_documents(docs, service_context=service_context)
|
||||
vector_query_engine = vector_index.as_query_engine(similarity_top_k=rag_params.top_k)
|
||||
all_tools = []
|
||||
vector_tool = QueryEngineTool(
|
||||
query_engine=vector_query_engine,
|
||||
metadata=ToolMetadata(
|
||||
name="vector_tool",
|
||||
description=("Use this tool to answer any user question over any data."),
|
||||
),
|
||||
)
|
||||
all_tools.append(vector_tool)
|
||||
if rag_params.include_summarization:
|
||||
summary_index = SummaryIndex.from_documents(docs, service_context=service_context)
|
||||
summary_query_engine = summary_index.as_query_engine()
|
||||
summary_tool = QueryEngineTool(
|
||||
query_engine=summary_query_engine,
|
||||
metadata=ToolMetadata(
|
||||
name="summary_tool",
|
||||
description=("Use this tool for any user questions that ask for a summarization of content"),
|
||||
),
|
||||
)
|
||||
all_tools.append(summary_tool)
|
||||
|
||||
|
||||
# then we add tools
|
||||
all_tools.extend(self._cache.tools)
|
||||
|
||||
# build agent
|
||||
if self._cache.system_prompt is None:
|
||||
return "System prompt not set yet. Please set system prompt first."
|
||||
raise ValueError("Must set system prompt before creating agent.")
|
||||
|
||||
agent = load_agent(
|
||||
all_tools, llm=llm, system_prompt=self._cache.system_prompt, verbose=True,
|
||||
extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}
|
||||
agent, extra_info = construct_agent(
|
||||
cast(str, self._cache.system_prompt),
|
||||
cast(RAGParams, self._cache.rag_params),
|
||||
self._cache.docs,
|
||||
additional_tools=self._cache.tools,
|
||||
)
|
||||
|
||||
# if agent_id not specified, randomly generate one
|
||||
agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
|
||||
self._cache.vector_index = extra_info["vector_index"]
|
||||
self._cache.agent_id = agent_id
|
||||
self._cache.agent = agent
|
||||
|
||||
# save the cache to disk
|
||||
agent_cache_path = f"{self._cache_dir}/{agent_id}"
|
||||
self._cache.save_to_disk(agent_cache_path)
|
||||
# save to agent ids
|
||||
add_agent_id_to_directory(str(self._cache_dir), agent_id)
|
||||
|
||||
return "Agent created successfully."
|
||||
|
||||
def update_agent(
    self,
    agent_id: str,
    system_prompt: Optional[str] = None,
    include_summarization: Optional[bool] = None,
    top_k: Optional[int] = None,
    chunk_size: Optional[int] = None,
    embed_model: Optional[str] = None,
    llm: Optional[str] = None,
) -> None:
    """Update agent.

    Delete old agent by ID and create a new one.
    Optionally update the system prompt and RAG parameters.

    NOTE: Currently is manually called, not meant for agent use.

    Args:
        agent_id: ID the re-created agent is saved under.
        system_prompt: New system prompt; the cached one is kept when None.
        include_summarization: New value for the RAG param, if not None.
        top_k: New value for the RAG param, if not None.
        chunk_size: New value for the RAG param, if not None.
        embed_model: New value for the RAG param, if not None.
        llm: New value for the RAG param, if not None.
    """
    # remove saved agent from directory, since we'll be re-saving.
    # Use this builder's own cache dir (not the global default) so removal
    # targets the same location create_agent() saves to.
    remove_agent_from_directory(str(self._cache_dir), self.cache.agent_id)

    # set agent id
    self.cache.agent_id = agent_id

    # set system prompt
    if system_prompt is not None:
        self.cache.system_prompt = system_prompt

    # We call set_rag_params and create_agent, which will update the cache.
    # TODO: decouple functions from tool functions exposed to the agent
    requested_params: Dict[str, Any] = {
        "include_summarization": include_summarization,
        "top_k": top_k,
        "chunk_size": chunk_size,
        "embed_model": embed_model,
        "llm": llm,
    }
    # Only forward the parameters the caller actually supplied.
    rag_params_dict: Dict[str, Any] = {
        key: value for key, value in requested_params.items() if value is not None
    }

    self.set_rag_params(**rag_params_dict)
    # this will update the agent in the cache
    self.create_agent()
|
||||
|
||||
|
||||
####################
|
||||
#### META Agent ####
|
||||
@@ -339,6 +639,7 @@ You should generally use the tools in this rough order to build the agent.
|
||||
2) Load in user-specified data (based on file paths they specify).
|
||||
3) Decide whether or not to add additional tools.
|
||||
4) Set parameters for the RAG pipeline.
|
||||
5) Build the agent
|
||||
|
||||
This will be a back and forth conversation with the user. You should
|
||||
continue asking users if there's anything else they want to do until
|
||||
@@ -355,25 +656,26 @@ have available (e.g. "Do you want to set the number of documents to retrieve?")
|
||||
|
||||
|
||||
# define agent
|
||||
# @st.cache_resource
def load_meta_agent_and_tools(
    cache: Optional[ParamCache] = None,
) -> Tuple[BaseAgent, RAGAgentBuilder]:
    """Build the meta (builder) agent together with the builder it drives.

    Args:
        cache: Optional parameter cache handed to `RAGAgentBuilder`; when
            None the builder is constructed without a preloaded cache.

    Returns:
        A `(builder_agent, agent_builder)` tuple: the agent created by
        `load_meta_agent` and the `RAGAgentBuilder` whose bound methods are
        exposed to it as tools.
    """
    # think of this as tools for the agent to use
    agent_builder = RAGAgentBuilder(cache)

    fns: List[Callable] = [
        agent_builder.create_system_prompt,
        agent_builder.load_data,
        # add_web_tool,
        agent_builder.get_rag_params,
        agent_builder.set_rag_params,
        agent_builder.create_agent,
    ]
    fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]

    builder_agent = load_meta_agent(
        fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
    )

    return builder_agent, agent_builder
|
||||
|
||||
+2
-1
@@ -7,6 +7,7 @@ import os
|
||||
|
||||
## OpenAI
|
||||
from llama_index.llms import OpenAI
|
||||
|
||||
# set OpenAI Key - use Streamlit secrets
|
||||
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
||||
# load LLM
|
||||
@@ -16,4 +17,4 @@ BUILDER_LLM = OpenAI(model="gpt-4-1106-preview")
|
||||
# from llama_index.llms import Anthropic
|
||||
# # set Anthropic key
|
||||
# os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
|
||||
# BUILDER_LLM = Anthropic()
|
||||
# BUILDER_LLM = Anthropic()
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from pathlib import Path
|
||||
|
||||
AGENT_CACHE_DIR = Path(__file__).parent / "cache" / "agents"
|
||||
MESSAGES_CACHE_DIR = Path(__file__).parent / "cache" / "messages"
|
||||
+103
-37
@@ -1,13 +1,16 @@
|
||||
"""Streamlit page showing builder config."""
|
||||
import streamlit as st
|
||||
import openai
|
||||
from streamlit_pills import pills
|
||||
from typing import cast
|
||||
from typing import cast, Optional
|
||||
|
||||
from agent_utils import (
|
||||
RAGParams,
|
||||
RAGAgentBuilder,
|
||||
ParamCache,
|
||||
remove_agent_from_directory,
|
||||
)
|
||||
from st_utils import update_selected_agent_with_id
|
||||
from constants import AGENT_CACHE_DIR
|
||||
from st_utils import add_sidebar
|
||||
|
||||
|
||||
####################
|
||||
@@ -15,53 +18,116 @@ from agent_utils import (
|
||||
####################
|
||||
|
||||
|
||||
def update_agent() -> None:
    """Update the agent being configured from the current widget values.

    Used as the `on_click` callback of the "Update Agent" button. Reads the
    values stored in session state under the widget keys (`agent_id_st`,
    `sys_prompt_st`, `include_summarization_st`, `top_k_st`, `chunk_size_st`,
    `embed_model_st`, `llm_st`) and passes them to the builder kept in
    `st.session_state.config_agent_builder`.

    Raises:
        ValueError: If no agent builder is stored in session state.
    """
    if (
        "config_agent_builder" not in st.session_state.keys()
        or st.session_state.config_agent_builder is None
    ):
        raise ValueError("Agent builder is None. Cannot update agent.")

    agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)

    ### Update the agent
    agent_builder.update_agent(
        st.session_state.agent_id_st,
        system_prompt=st.session_state.sys_prompt_st,
        include_summarization=st.session_state.include_summarization_st,
        top_k=st.session_state.top_k_st,
        chunk_size=st.session_state.chunk_size_st,
        embed_model=st.session_state.embed_model_st,
        llm=st.session_state.llm_st,
    )

    # Update Radio Buttons: update selected agent to the new id
    update_selected_agent_with_id(agent_builder.cache.agent_id)
|
||||
|
||||
|
||||
def delete_agent() -> None:
    """Delete the agent currently shown on the config page.

    Used as the `on_click` callback of the "Delete Agent" button. Removes the
    agent identified by the configured builder's cache from the agent cache
    directory, then clears the current agent selection.

    Raises:
        ValueError: If no agent builder is stored in session state.
    """
    builder_missing = (
        "config_agent_builder" not in st.session_state.keys()
        or st.session_state.config_agent_builder is None
    )
    if builder_missing:
        raise ValueError("Agent builder is None. Cannot delete agent.")

    builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)

    # drop the saved agent from the cache directory
    remove_agent_from_directory(str(AGENT_CACHE_DIR), builder.cache.agent_id)

    # reset the selection: passing None clears both the selected id and cache
    update_selected_agent_with_id(None)
|
||||
|
||||
|
||||
st.set_page_config(
|
||||
page_title="RAG Pipeline Config",
|
||||
page_icon="🦙",
|
||||
layout="centered",
|
||||
initial_sidebar_state="auto",
|
||||
menu_items=None,
|
||||
)
|
||||
st.title("RAG Pipeline Config")
|
||||
add_sidebar()
|
||||
|
||||
if "agent_builder" in st.session_state.keys():
|
||||
# first, pick the cache: this is preloaded from an existing agent,
|
||||
# or is part of the current one being created
|
||||
if (
|
||||
"selected_cache" in st.session_state.keys()
|
||||
and st.session_state.selected_cache is not None
|
||||
):
|
||||
cache = cast(ParamCache, st.session_state.selected_cache)
|
||||
agent_builder: Optional[RAGAgentBuilder] = RAGAgentBuilder(cache)
|
||||
elif "agent_builder" in st.session_state.keys():
|
||||
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
|
||||
else:
|
||||
agent_builder = None
|
||||
|
||||
# set as session state
|
||||
st.session_state.config_agent_builder = agent_builder
|
||||
|
||||
if agent_builder is not None:
|
||||
|
||||
st.info(f"Viewing config for agent: {agent_builder.cache.agent_id}", icon="ℹ️")
|
||||
|
||||
agent_id_st = st.text_input(
|
||||
"Agent ID", value=agent_builder.cache.agent_id, key="agent_id_st"
|
||||
)
|
||||
|
||||
if agent_builder.cache.system_prompt is None:
|
||||
system_prompt = ""
|
||||
else:
|
||||
system_prompt = agent_builder.cache.system_prompt
|
||||
sys_prompt_st = st.text_area("System Prompt", value=system_prompt)
|
||||
sys_prompt_st = st.text_area(
|
||||
"System Prompt", value=system_prompt, key="sys_prompt_st"
|
||||
)
|
||||
|
||||
rag_params = cast(RAGParams, agent_builder.cache.rag_params)
|
||||
file_paths = st.text_input(
|
||||
"File/URL paths (not editable)",
|
||||
value=",".join(agent_builder.cache.file_paths),
|
||||
disabled=True
|
||||
file_names = st.text_input(
|
||||
"File names (not editable)",
|
||||
value=",".join(agent_builder.cache.file_names),
|
||||
disabled=True,
|
||||
)
|
||||
include_summarization_st = st.checkbox("Include Summarization (only works for GPT-4)", value=rag_params.include_summarization)
|
||||
top_k_st = st.number_input("Top K", value=rag_params.top_k)
|
||||
chunk_size_st = st.number_input("Chunk Size", value=rag_params.chunk_size)
|
||||
embed_model_st = st.text_input("Embed Model", value=rag_params.embed_model)
|
||||
llm_st = st.text_input("LLM", value=rag_params.llm)
|
||||
urls = st.text_input(
|
||||
"URLs (not editable)", value=",".join(agent_builder.cache.urls), disabled=True
|
||||
)
|
||||
include_summarization_st = st.checkbox(
|
||||
"Include Summarization (only works for GPT-4)",
|
||||
value=rag_params.include_summarization,
|
||||
key="include_summarization_st",
|
||||
)
|
||||
top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
|
||||
chunk_size_st = st.number_input(
|
||||
"Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"
|
||||
)
|
||||
embed_model_st = st.text_input(
|
||||
"Embed Model", value=rag_params.embed_model, key="embed_model_st"
|
||||
)
|
||||
llm_st = st.text_input("LLM", value=rag_params.llm, key="llm_st")
|
||||
if agent_builder.cache.agent is not None:
|
||||
if st.button("Update Agent"):
|
||||
# update the agent
|
||||
agent_builder.cache.system_prompt = sys_prompt_st
|
||||
# get agent_builder
|
||||
# We call set_rag_params and create_agent, which will
|
||||
# update the cache
|
||||
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
|
||||
|
||||
# TODO: decouple functions from tool functions exposed to the agent
|
||||
agent_builder.set_rag_params(
|
||||
include_summarization=include_summarization_st,
|
||||
top_k=top_k_st,
|
||||
chunk_size=chunk_size_st,
|
||||
embed_model=embed_model_st,
|
||||
llm=llm_st,
|
||||
)
|
||||
# this will update the agent in the cache
|
||||
agent_builder.create_agent()
|
||||
st.button("Update Agent", on_click=update_agent)
|
||||
st.button(":red[Delete Agent]", on_click=delete_agent)
|
||||
else:
|
||||
# show text saying "agent not created"
|
||||
st.info("Agent not created. Please create an agent in the above section.")
|
||||
|
||||
else:
|
||||
st.info("agent builder not created yet. Please describe your task in the above section.")
|
||||
st.info("No agent builder found. Please create an agent in the above section.")
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
"""Streamlit page showing builder config."""
|
||||
import streamlit as st
|
||||
from typing import cast
|
||||
from agent_utils import (
|
||||
RAGAgentBuilder,
|
||||
)
|
||||
|
||||
from streamlit_pills import pills
|
||||
from typing import cast, Optional
|
||||
from agent_utils import RAGAgentBuilder, ParamCache
|
||||
from st_utils import add_sidebar
|
||||
|
||||
|
||||
####################
|
||||
@@ -13,46 +10,67 @@ from streamlit_pills import pills
|
||||
####################
|
||||
|
||||
|
||||
|
||||
st.set_page_config(page_title="Generated RAG Agent", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
|
||||
st.title("Generated RAG Agent")
|
||||
st.info(
|
||||
"This is generated by the builder in the above section.", icon="ℹ️"
|
||||
st.set_page_config(
|
||||
page_title="Generated RAG Agent",
|
||||
page_icon="🦙",
|
||||
layout="centered",
|
||||
initial_sidebar_state="auto",
|
||||
menu_items=None,
|
||||
)
|
||||
st.title("Generated RAG Agent")
|
||||
|
||||
if "agent_messages" not in st.session_state.keys(): # Initialize the chat messages history
|
||||
add_sidebar()
|
||||
|
||||
if (
|
||||
"agent_messages" not in st.session_state.keys()
|
||||
): # Initialize the chat messages history
|
||||
st.session_state.agent_messages = [
|
||||
{"role": "assistant", "content": "Ask me a question!"}
|
||||
]
|
||||
|
||||
def add_to_message_history(role: str, content: str) -> None:
    """Append one chat message to `st.session_state.agent_messages`.

    Args:
        role: Speaker of the message (this page uses "user" and "assistant").
        content: Message body; coerced with `str()` before being stored.
    """
    message = {"role": role, "content": str(content)}
    st.session_state.agent_messages.append(message)  # Add response to message history
|
||||
|
||||
|
||||
# first, pick the cache: this is preloaded from an existing agent,
|
||||
# or is part of the current one being created
|
||||
agent = None
|
||||
if "agent_builder" in st.session_state.keys():
|
||||
if (
|
||||
"selected_cache" in st.session_state.keys()
|
||||
and st.session_state.selected_cache is not None
|
||||
):
|
||||
cache: Optional[ParamCache] = cast(ParamCache, st.session_state.selected_cache)
|
||||
elif "agent_builder" in st.session_state.keys():
|
||||
agent_builder = cast(RAGAgentBuilder, st.session_state.agent_builder)
|
||||
if agent_builder.cache.agent is not None:
|
||||
agent = agent_builder.cache.agent
|
||||
for message in st.session_state.agent_messages: # Display the prior chat messages
|
||||
with st.chat_message(message["role"]):
|
||||
st.write(message["content"])
|
||||
cache = agent_builder.cache
|
||||
else:
|
||||
cache = None
|
||||
st.info("Agent not created. Please create an agent in the above section.")
|
||||
|
||||
# don't process selected for now
|
||||
if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
|
||||
add_to_message_history("user", prompt)
|
||||
with st.chat_message("user"):
|
||||
st.write(prompt)
|
||||
# if agent is created, then we can chat with it
|
||||
if cache is not None and cache.agent is not None:
|
||||
st.info(f"Viewing config for agent: {cache.agent_id}", icon="ℹ️")
|
||||
agent = cache.agent
|
||||
for message in st.session_state.agent_messages: # Display the prior chat messages
|
||||
with st.chat_message(message["role"]):
|
||||
st.write(message["content"])
|
||||
|
||||
# If last message is not from assistant, generate a new response
|
||||
if st.session_state.agent_messages[-1]["role"] != "assistant":
|
||||
with st.chat_message("assistant"):
|
||||
with st.spinner("Thinking..."):
|
||||
response = agent.chat(prompt)
|
||||
st.write(str(response))
|
||||
add_to_message_history("assistant", response)
|
||||
else:
|
||||
st.info("Agent not created. Please create an agent in the above section.")
|
||||
# don't process selected for now
|
||||
if prompt := st.chat_input(
|
||||
"Your question"
|
||||
): # Prompt for user input and save to chat history
|
||||
add_to_message_history("user", prompt)
|
||||
with st.chat_message("user"):
|
||||
st.write(prompt)
|
||||
|
||||
# If last message is not from assistant, generate a new response
|
||||
if st.session_state.agent_messages[-1]["role"] != "assistant":
|
||||
with st.chat_message("assistant"):
|
||||
with st.spinner("Thinking..."):
|
||||
response = agent.chat(str(prompt))
|
||||
st.write(str(response))
|
||||
add_to_message_history("assistant", str(response))
|
||||
else:
|
||||
st.info("Agent not created. Please create an agent in the above section.")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
[tool.poetry]
|
||||
name = "rags"
|
||||
version = "0.0.2"
|
||||
description = "Build RAG with natural language."
|
||||
authors = ["Jerry Liu"]
|
||||
# New attributes
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
homepage = "https://docs.llamaindex.ai/en/latest/"
|
||||
repository = "https://github.com/run-llama/rags"
|
||||
keywords = ["llama-index", "rags"]
|
||||
include = [
|
||||
"LICENSE",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
streamlit = "1.28.0"
|
||||
streamlit-pills = "0.3.0"
|
||||
llama-index = "0.9.7"
|
||||
llama-hub = "0.0.44"
|
||||
# NOTE: this is due to a trivial dependency in the web tool, will refactor
|
||||
langchain = "0.0.305"
|
||||
pypdf = "3.17.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
# pytest = "7.2.1"
|
||||
# pytest-dotenv = "0.5.2"
|
||||
# pytest_httpserver = "1.0.8"
|
||||
# pytest-mock = "3.11.1"
|
||||
typing-inspect = "0.8.0"
|
||||
typing_extensions = "^4.5.0"
|
||||
types-requests = "2.28.11.8"
|
||||
black = "22.12.0"
|
||||
isort = "5.11.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
ruff = "0.0.285"
|
||||
mypy = "0.991"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry>=0.12", "poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
exclude = ["notebooks", "build", "examples"]
|
||||
|
||||
[tool.ruff]
|
||||
# Allow lines to be as long as 88 characters.
|
||||
# TODO: it should be removed, but we need to fix the entire code first.
|
||||
line-length = 88
|
||||
exclude = [
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".ipynb_checkpoints",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
"examples",
|
||||
"notebooks",
|
||||
".git"
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"base.py" = ["E402", "F811", "E501"]
|
||||
+58
@@ -0,0 +1,58 @@
|
||||
"""Streamlit utils."""
|
||||
from agent_utils import (
|
||||
load_agent_ids_from_directory,
|
||||
load_cache_from_directory,
|
||||
)
|
||||
from constants import (
|
||||
AGENT_CACHE_DIR,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
    """Record the selected agent in session state and load its cache.

    The sidebar's "Create a new agent" entry is treated the same as None:
    both clear `selected_id` and `selected_cache`. Any other id is stored in
    `selected_id`, and the cache loaded for it from the agent cache directory
    is stored in `selected_cache`.
    """
    # normalize the placeholder radio entry to "nothing selected"
    if selected_id == "Create a new agent":
        selected_id = None
    st.session_state.selected_id = selected_id

    if selected_id is None:
        st.session_state.selected_cache = None
        return

    # load agent from directory
    st.session_state.selected_cache = load_cache_from_directory(
        str(AGENT_CACHE_DIR), selected_id
    )
|
||||
|
||||
|
||||
## handler for sidebar specifically
|
||||
def update_selected_agent() -> None:
    """Apply the sidebar radio's current value as the selected agent.

    Reads the value stored under the `agent_selector` widget key and hands it
    to `update_selected_agent_with_id`.
    """
    update_selected_agent_with_id(st.session_state.agent_selector)
|
||||
|
||||
|
||||
def add_sidebar() -> None:
    """Render the sidebar radio used to pick an agent.

    Loads the saved agent ids from the agent cache directory, stores them in
    `st.session_state.cur_agent_ids`, and shows them below a
    "Create a new agent" entry. The radio is preselected to
    `st.session_state.selected_id` when that id is among the loaded choices;
    otherwise the first entry is preselected.
    """
    with st.sidebar:
        st.session_state.cur_agent_ids = load_agent_ids_from_directory(
            str(AGENT_CACHE_DIR)
        )
        choices = ["Create a new agent"] + st.session_state.cur_agent_ids

        # by default, set index to 0. if value is in selected_id, set index to that
        index = 0
        if "selected_id" in st.session_state.keys():
            selected_id = st.session_state.selected_id
            # A stored id may no longer be among the ids on disk; fall back to
            # the default entry instead of letting list.index raise ValueError.
            if selected_id is not None and selected_id in choices:
                index = choices.index(selected_id)

        # display buttons
        st.radio(
            "Agents",
            choices,
            index=index,
            on_change=update_selected_agent,
            key="agent_selector",
        )
|
||||
Reference in New Issue
Block a user