upgrade v3 (#49)

support multi-modality
allow loading an entire directory (and all the files in the directory)
add sources in the response
This commit is contained in:
Jerry Liu
2023-12-04 17:01:24 -08:00
committed by GitHub
parent 76bf409458
commit 4bec270239
13 changed files with 819 additions and 228 deletions
+7 -36
View File
@@ -1,8 +1,13 @@
import streamlit as st
from streamlit_pills import pills
from st_utils import add_sidebar, get_current_state
from st_utils import (
add_builder_config,
add_sidebar,
get_current_state,
)
current_state = get_current_state()
####################
#### STREAMLIT #####
@@ -28,7 +33,7 @@ if "metaphor_key" in st.secrets:
st.info("**NOTE**: The ability to add web search is enabled.")
current_state = get_current_state()
add_builder_config()
add_sidebar()
@@ -60,40 +65,6 @@ for message in st.session_state.messages: # Display the prior chat messages
with st.chat_message(message["role"]):
st.write(message["content"])
# def handle_user_input() -> None:
# """Handle user input."""
# prompt = st.session_state.user_question_st
# print(f"USER PROMPT: {prompt}")
# add_to_message_history("user", prompt)
# with st.chat_message("user"):
# st.write(prompt)
# # If last message is not from assistant, generate a new response
# if st.session_state.messages[-1]["role"] != "assistant":
# with st.chat_message("assistant"):
# with st.spinner("Thinking..."):
# response = current_state.builder_agent.chat(prompt)
# st.write(str(response))
# add_to_message_history("assistant", str(response))
# else:
# pass
# # check agent_ids again
# # if it doesn't match, add to directory and refresh
# agent_ids = current_state.agent_registry.get_agent_ids()
# # check diff between agent_ids and cur agent ids
# diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
# if len(diff_ids) > 0:
# # # clear streamlit cache, to allow you to generate a new agent
# # st.cache_resource.clear()
# st.rerun()
# handle user input
# st.chat_input(
# "Your question", key="user_question_st", on_submit=handle_user_input
# ) # Prompt for user input and save to chat history
# TODO: this is really hacky, only because st.rerun is jank
if prompt := st.chat_input(
"Your question",
+2
View File
@@ -40,6 +40,8 @@ streamlit run 1_🏠_Home.py
```
**NOTE**: If you've upgraded the version of RAGs, and you're running into issues on launch, you may need to delete the `cache` folder in your home directory (we may have introduced breaking changes in the stored data structure between versions).
## Detailed Overview
The app contains the following sections, corresponding to the steps listed above.
View File
@@ -3,93 +3,19 @@
from llama_index.llms import ChatMessage
from llama_index.prompts import ChatPromptTemplate
from typing import List, cast, Optional
from llama_index.tools import FunctionTool
from llama_index.agent.types import BaseAgent
from core.builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any, Callable, Union
import streamlit as st
from pathlib import Path
import json
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR
import shutil
from abc import ABC, abstractmethod
from core.param_cache import ParamCache, RAGParams
from core.utils import (
load_data,
get_tool_objects,
construct_agent,
load_meta_agent,
)
class AgentCacheRegistry:
"""Registry for agent caches, in disk.
Can register new agent caches, load agent caches, delete agent caches, etc.
"""
def __init__(self, dir: Union[str, Path]) -> None:
"""Init params."""
self._dir = dir
def _add_agent_id_to_directory(self, agent_id: str) -> None:
"""Save agent id to directory."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
with open(full_path, "w") as f:
json.dump({"agent_ids": [agent_id]}, f)
else:
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
if agent_id in agent_ids:
raise ValueError(f"Agent id {agent_id} already exists.")
agent_ids_set = set(agent_ids)
agent_ids_set.add(agent_id)
with open(full_path, "w") as f:
json.dump({"agent_ids": list(agent_ids_set)}, f)
def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
"""Register agent."""
# save the cache to disk
agent_cache_path = f"{self._dir}/{agent_id}"
cache.save_to_disk(agent_cache_path)
# save to agent ids
self._add_agent_id_to_directory(agent_id)
def get_agent_ids(self) -> List[str]:
"""Get agent ids."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
return []
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
return agent_ids
def get_agent_cache(self, agent_id: str) -> ParamCache:
"""Get agent cache."""
full_path = Path(self._dir) / f"{agent_id}"
if not full_path.exists():
raise ValueError(f"Cache for agent {agent_id} does not exist.")
cache = ParamCache.load_from_disk(str(full_path))
return cache
def delete_agent_cache(self, agent_id: str) -> None:
"""Delete agent cache."""
# modify / resave agent_ids
agent_ids = self.get_agent_ids()
new_agent_ids = [id for id in agent_ids if id != agent_id]
full_path = Path(self._dir) / "agent_ids.json"
with open(full_path, "w") as f:
json.dump({"agent_ids": new_agent_ids}, f)
# remove agent cache
full_path = Path(self._dir) / f"{agent_id}"
if full_path.exists():
# recursive delete
shutil.rmtree(full_path)
from core.agent_builder.registry import AgentCacheRegistry
# System prompt tool
@@ -121,7 +47,21 @@ gen_sys_prompt_messages = [
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
class RAGAgentBuilder:
class BaseRAGAgentBuilder(ABC):
"""Base RAG Agent builder class."""
@property
@abstractmethod
def cache(self) -> ParamCache:
"""Cache."""
@property
@abstractmethod
def agent_registry(self) -> AgentCacheRegistry:
"""Agent registry."""
class RAGAgentBuilder(BaseRAGAgentBuilder):
"""RAG Agent builder.
Contains a set of functions to construct a RAG agent, including:
@@ -165,25 +105,31 @@ class RAGAgentBuilder:
return f"System prompt created: {response.message.content}"
def load_data(
self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
self,
file_names: Optional[List[str]] = None,
directory: Optional[str] = None,
urls: Optional[List[str]] = None,
) -> str:
"""Load data for a given task.
Only ONE of file_names or urls should be specified.
Only ONE of file_names or directory or urls should be specified.
Args:
file_names (Optional[List[str]]): List of file names to load.
Defaults to None.
directory (Optional[str]): Directory to load files from.
urls (Optional[List[str]]): List of urls to load.
Defaults to None.
"""
file_names = file_names or []
urls = urls or []
docs = load_data(file_names=file_names, urls=urls)
directory = directory or ""
docs = load_data(file_names=file_names, directory=directory, urls=urls)
self._cache.docs = docs
self._cache.file_names = file_names
self._cache.urls = urls
self._cache.directory = directory
return "Data loaded successfully."
def add_web_tool(self) -> str:
@@ -302,77 +248,3 @@ class RAGAgentBuilder:
# this will update the agent in the cache
self.create_agent()
####################
#### META Agent ####
####################
RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.
1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")
"""
### DEFINE Agent ####
# NOTE: here we define a function that is dependent on the LLM,
# please make sure to update the LLM above if you change the function below
def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
"""Get list of builder agent tools to pass to the builder agent."""
# see if metaphor api key is set, otherwise don't add web tool
# TODO: refactor this later
if "metaphor_key" in st.secrets:
fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.add_web_tool,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
else:
fns = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
return fn_tools
# define agent
# @st.cache_resource
def load_meta_agent_and_tools(
cache: Optional[ParamCache] = None,
agent_registry: Optional[AgentCacheRegistry] = None,
) -> Tuple[BaseAgent, RAGAgentBuilder]:
# think of this as tools for the agent to use
agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
fn_tools = _get_builder_agent_tools(agent_builder)
builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
)
return builder_agent, agent_builder
+115
View File
@@ -0,0 +1,115 @@
"""Loader agent."""
from typing import List, cast, Optional
from llama_index.tools import FunctionTool
from llama_index.agent.types import BaseAgent
from core.builder_config import BUILDER_LLM
from typing import Tuple, Callable
import streamlit as st
from core.param_cache import ParamCache
from core.utils import (
load_meta_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import RAGAgentBuilder, BaseRAGAgentBuilder
from core.agent_builder.multimodal import MultimodalRAGAgentBuilder
####################
#### META Agent ####
####################
RAG_BUILDER_SYS_STR = """\
You are helping to construct an agent given a user-specified task.
You should generally use the tools in this rough order to build the agent.
1) Create system prompt tool: to create the system prompt for the agent.
2) Load in user-specified data (based on file paths they specify).
3) Decide whether or not to add additional tools.
4) Set parameters for the RAG pipeline.
5) Build the agent
This will be a back and forth conversation with the user. You should
continue asking users if there's anything else they want to do until
they say they're done. To help guide them on the process,
you can give suggestions on parameters they can set based on the tools they
have available (e.g. "Do you want to set the number of documents to retrieve?")
"""
### DEFINE Agent ####
# NOTE: here we define a function that is dependent on the LLM,
# please make sure to update the LLM above if you change the function below
def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
"""Get list of builder agent tools to pass to the builder agent."""
# see if metaphor api key is set, otherwise don't add web tool
# TODO: refactor this later
if "metaphor_key" in st.secrets:
fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.add_web_tool,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
else:
fns = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
return fn_tools
def _get_mm_builder_agent_tools(
agent_builder: MultimodalRAGAgentBuilder,
) -> List[FunctionTool]:
"""Get list of builder agent tools to pass to the builder agent."""
fns: List[Callable] = [
agent_builder.create_system_prompt,
agent_builder.load_data,
agent_builder.get_rag_params,
agent_builder.set_rag_params,
agent_builder.create_agent,
]
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
return fn_tools
# define agent
def load_meta_agent_and_tools(
cache: Optional[ParamCache] = None,
agent_registry: Optional[AgentCacheRegistry] = None,
is_multimodal: bool = False,
) -> Tuple[BaseAgent, BaseRAGAgentBuilder]:
"""Load meta agent and tools."""
if is_multimodal:
agent_builder: BaseRAGAgentBuilder = MultimodalRAGAgentBuilder(
cache, agent_registry=agent_registry
)
fn_tools = _get_mm_builder_agent_tools(
cast(MultimodalRAGAgentBuilder, agent_builder)
)
builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
)
else:
# think of this as tools for the agent to use
agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
fn_tools = _get_builder_agent_tools(agent_builder)
builder_agent = load_meta_agent(
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
)
return builder_agent, agent_builder
+256
View File
@@ -0,0 +1,256 @@
"""Multimodal agent builder."""
from llama_index.llms import ChatMessage
from typing import List, cast, Optional
from core.builder_config import BUILDER_LLM
from typing import Dict, Any
import uuid
from core.constants import AGENT_CACHE_DIR
from core.param_cache import ParamCache, RAGParams
from core.utils import (
load_data,
construct_mm_agent,
)
from core.agent_builder.registry import AgentCacheRegistry
from core.agent_builder.base import GEN_SYS_PROMPT_TMPL, BaseRAGAgentBuilder
from llama_index.chat_engine.types import BaseChatEngine
from llama_index.callbacks import trace_method
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
StreamingAgentChatResponse,
AgentChatResponse,
)
from llama_index.llms.base import ChatResponse
from typing import Generator
class MultimodalChatEngine(BaseChatEngine):
"""Multimodal chat engine.
This chat engine is a light wrapper around a query engine.
Offers no real 'chat' functionality, is a beta feature.
"""
def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
"""Init params."""
self._mm_query_engine = mm_query_engine
def reset(self) -> None:
"""Reset conversation state."""
pass
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Main chat interface."""
# just return the top-k results
response = self._mm_query_engine.query(message)
return AgentChatResponse(response=str(response))
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Stream chat interface."""
response = self._mm_query_engine.query(message)
def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
yield ChatResponse(message=ChatMessage(role="assistant", content=response))
chat_stream = _chat_stream(str(response))
return StreamingAgentChatResponse(chat_stream=chat_stream)
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Async version of main chat interface."""
response = await self._mm_query_engine.aquery(message)
return AgentChatResponse(response=str(response))
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Async version of main chat interface."""
return self.stream_chat(message, chat_history)
class MultimodalRAGAgentBuilder(BaseRAGAgentBuilder):
"""Multimodal RAG Agent builder.
Contains a set of functions to construct a RAG agent, including:
- setting system prompts
- loading data
- adding web search
- setting parameters (e.g. top-k)
Must pass in a cache. This cache will be modified as the agent is built.
"""
def __init__(
self,
cache: Optional[ParamCache] = None,
agent_registry: Optional[AgentCacheRegistry] = None,
) -> None:
"""Init params."""
self._cache = cache or ParamCache()
self._agent_registry = agent_registry or AgentCacheRegistry(
str(AGENT_CACHE_DIR)
)
@property
def cache(self) -> ParamCache:
"""Cache."""
return self._cache
@property
def agent_registry(self) -> AgentCacheRegistry:
"""Agent registry."""
return self._agent_registry
def create_system_prompt(self, task: str) -> str:
"""Create system prompt for another agent given an input task."""
llm = BUILDER_LLM
fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
response = llm.chat(fmt_messages)
self._cache.system_prompt = response.message.content
return f"System prompt created: {response.message.content}"
def load_data(
self,
file_names: Optional[List[str]] = None,
directory: Optional[str] = None,
) -> str:
"""Load data for a given task.
Only ONE of file_names or directory should be specified.
**NOTE**: urls not supported in multi-modal setting.
Args:
file_names (Optional[List[str]]): List of file names to load.
Defaults to None.
directory (Optional[str]): Directory to load files from.
"""
file_names = file_names or []
directory = directory or ""
docs = load_data(file_names=file_names, directory=directory)
self._cache.docs = docs
self._cache.file_names = file_names
self._cache.directory = directory
return "Data loaded successfully."
def get_rag_params(self) -> Dict:
"""Get parameters used to configure the RAG pipeline.
Should be called before `set_rag_params` so that the agent is aware of the
schema.
"""
rag_params = self._cache.rag_params
return rag_params.dict()
def set_rag_params(self, **rag_params: Dict) -> str:
"""Set RAG parameters.
These parameters will then be used to actually initialize the agent.
Should call `get_rag_params` first to get the schema of the input dictionary.
Args:
**rag_params (Dict): dictionary of RAG parameters.
"""
new_dict = self._cache.rag_params.dict()
new_dict.update(rag_params)
rag_params_obj = RAGParams(**new_dict)
self._cache.rag_params = rag_params_obj
return "RAG parameters set successfully."
def create_agent(self, agent_id: Optional[str] = None) -> str:
"""Create an agent.
There are no parameters for this function because all the
functions should have already been called to set up the agent.
"""
if self._cache.system_prompt is None:
raise ValueError("Must set system prompt before creating agent.")
# construct additional tools
agent, extra_info = construct_mm_agent(
cast(str, self._cache.system_prompt),
cast(RAGParams, self._cache.rag_params),
self._cache.docs,
)
# if agent_id not specified, randomly generate one
agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
self._cache.builder_type = "multimodal"
self._cache.vector_index = extra_info["vector_index"]
self._cache.agent_id = agent_id
self._cache.agent = agent
# save the cache to disk
self._agent_registry.add_new_agent_cache(agent_id, self._cache)
return "Agent created successfully."
def update_agent(
self,
agent_id: str,
system_prompt: Optional[str] = None,
include_summarization: Optional[bool] = None,
top_k: Optional[int] = None,
chunk_size: Optional[int] = None,
embed_model: Optional[str] = None,
llm: Optional[str] = None,
additional_tools: Optional[List] = None,
) -> None:
"""Update agent.
Delete old agent by ID and create a new one.
Optionally update the system prompt and RAG parameters.
NOTE: Currently is manually called, not meant for agent use.
"""
self._agent_registry.delete_agent_cache(self.cache.agent_id)
# set agent id
self.cache.agent_id = agent_id
# set system prompt
if system_prompt is not None:
self.cache.system_prompt = system_prompt
# get agent_builder
# We call set_rag_params and create_agent, which will
# update the cache
# TODO: decouple functions from tool functions exposed to the agent
rag_params_dict: Dict[str, Any] = {}
if include_summarization is not None:
rag_params_dict["include_summarization"] = include_summarization
if top_k is not None:
rag_params_dict["top_k"] = top_k
if chunk_size is not None:
rag_params_dict["chunk_size"] = chunk_size
if embed_model is not None:
rag_params_dict["embed_model"] = embed_model
if llm is not None:
rag_params_dict["llm"] = llm
self.set_rag_params(**rag_params_dict)
# update tools
if additional_tools is not None:
self.cache.tools = additional_tools
# this will update the agent in the cache
self.create_agent()
+78
View File
@@ -0,0 +1,78 @@
"""Agent builder registry."""
from typing import List
from typing import Union
from pathlib import Path
import json
import shutil
from core.param_cache import ParamCache
class AgentCacheRegistry:
"""Registry for agent caches, in disk.
Can register new agent caches, load agent caches, delete agent caches, etc.
"""
def __init__(self, dir: Union[str, Path]) -> None:
"""Init params."""
self._dir = dir
def _add_agent_id_to_directory(self, agent_id: str) -> None:
"""Save agent id to directory."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
with open(full_path, "w") as f:
json.dump({"agent_ids": [agent_id]}, f)
else:
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
if agent_id in agent_ids:
raise ValueError(f"Agent id {agent_id} already exists.")
agent_ids_set = set(agent_ids)
agent_ids_set.add(agent_id)
with open(full_path, "w") as f:
json.dump({"agent_ids": list(agent_ids_set)}, f)
def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
"""Register agent."""
# save the cache to disk
agent_cache_path = f"{self._dir}/{agent_id}"
cache.save_to_disk(agent_cache_path)
# save to agent ids
self._add_agent_id_to_directory(agent_id)
def get_agent_ids(self) -> List[str]:
"""Get agent ids."""
full_path = Path(self._dir) / "agent_ids.json"
if not full_path.exists():
return []
with open(full_path, "r") as f:
agent_ids = json.load(f)["agent_ids"]
return agent_ids
def get_agent_cache(self, agent_id: str) -> ParamCache:
"""Get agent cache."""
full_path = Path(self._dir) / f"{agent_id}"
if not full_path.exists():
raise ValueError(f"Cache for agent {agent_id} does not exist.")
cache = ParamCache.load_from_disk(str(full_path))
return cache
def delete_agent_cache(self, agent_id: str) -> None:
"""Delete agent cache."""
# modify / resave agent_ids
agent_ids = self.get_agent_ids()
new_agent_ids = [id for id in agent_ids if id != agent_id]
full_path = Path(self._dir) / "agent_ids.json"
with open(full_path, "w") as f:
json.dump({"agent_ids": new_agent_ids}, f)
# remove agent cache
full_path = Path(self._dir) / f"{agent_id}"
if full_path.exists():
# recursive delete
shutil.rmtree(full_path)
+49 -13
View File
@@ -11,7 +11,13 @@ from llama_index.chat_engine.types import BaseChatEngine
from pathlib import Path
import json
import uuid
from core.utils import load_data, get_tool_objects, construct_agent, RAGParams
from core.utils import (
load_data,
get_tool_objects,
construct_agent,
RAGParams,
construct_mm_agent,
)
class ParamCache(BaseModel):
@@ -37,6 +43,10 @@ class ParamCache(BaseModel):
urls: List[str] = Field(
default_factory=list, description="URLs as data source (if specified)"
)
directory: Optional[str] = Field(
default=None, description="Directory as data source (if specified)"
)
docs: List = Field(default_factory=list, description="Documents for RAG agent.")
# tools
tools: List = Field(
@@ -48,6 +58,9 @@ class ParamCache(BaseModel):
)
# agent params
builder_type: str = Field(
default="default", description="Builder type (default, multimodal)."
)
vector_index: Optional[VectorStoreIndex] = Field(
default=None, description="Vector index for RAG agent."
)
@@ -66,9 +79,11 @@ class ParamCache(BaseModel):
"system_prompt": self.system_prompt,
"file_names": self.file_names,
"urls": self.urls,
"directory": self.directory,
# TODO: figure out tools
"tools": self.tools,
"rag_params": self.rag_params.dict(),
"builder_type": self.builder_type,
"agent_id": self.agent_id,
}
# store the vector store within the agent
@@ -88,13 +103,22 @@ class ParamCache(BaseModel):
save_dir: str,
) -> "ParamCache":
"""Load cache from disk."""
with open(Path(save_dir) / "cache.json", "r") as f:
cache_dict = json.load(f)
storage_context = StorageContext.from_defaults(
persist_dir=str(Path(save_dir) / "storage")
)
vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
if cache_dict["builder_type"] == "multimodal":
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
with open(Path(save_dir) / "cache.json", "r") as f:
cache_dict = json.load(f)
vector_index: VectorStoreIndex = cast(
MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
)
else:
vector_index = cast(
VectorStoreIndex, load_index_from_storage(storage_context)
)
# replace rag params with RAGParams object
cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
@@ -102,18 +126,30 @@ class ParamCache(BaseModel):
# add in the missing fields
# load docs
cache_dict["docs"] = load_data(
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
file_names=cache_dict["file_names"],
urls=cache_dict["urls"],
directory=cache_dict["directory"],
)
# load agent from index
additional_tools = get_tool_objects(cache_dict["tools"])
agent, _ = construct_agent(
cache_dict["system_prompt"],
cache_dict["rag_params"],
cache_dict["docs"],
vector_index=vector_index,
additional_tools=additional_tools,
# TODO: figure out tools
)
if cache_dict["builder_type"] == "multimodal":
vector_index = cast(MultiModalVectorStoreIndex, vector_index)
agent, _ = construct_mm_agent(
cache_dict["system_prompt"],
cache_dict["rag_params"],
cache_dict["docs"],
mm_vector_index=vector_index,
)
else:
agent, _ = construct_agent(
cache_dict["system_prompt"],
cache_dict["rag_params"],
cache_dict["docs"],
vector_index=vector_index,
additional_tools=additional_tools,
# TODO: figure out tools
)
cache_dict["vector_index"] = vector_index
cache_dict["agent"] = agent
+159 -7
View File
@@ -26,8 +26,25 @@ from core.builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any
import streamlit as st
from llama_index.callbacks import CallbackManager
from llama_index.callbacks import CallbackManager, trace_method
from core.callback_manager import StreamlitFunctionsCallbackHandler
from llama_index.schema import ImageNode, NodeWithScore
### BETA: Multi-modal
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.indices.multi_modal.retriever import (
MultiModalVectorIndexRetriever,
)
from llama_index.llms import ChatMessage
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
from llama_index.chat_engine.types import (
AGENT_CHAT_RESPONSE_TYPE,
StreamingAgentChatResponse,
AgentChatResponse,
)
from llama_index.llms.base import ChatResponse
from typing import Generator
class RAGParams(BaseModel):
@@ -82,18 +99,28 @@ def _resolve_llm(llm_str: str) -> LLM:
def load_data(
file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
file_names: Optional[List[str]] = None,
directory: Optional[str] = None,
urls: Optional[List[str]] = None,
) -> List[Document]:
"""Load data."""
file_names = file_names or []
directory = directory or ""
urls = urls or []
if not file_names and not urls:
raise ValueError("Must specify either file_names or urls.")
elif file_names and urls:
raise ValueError("Must specify only one of file_names or urls.")
# get number depending on whether specified
num_specified = sum(1 for v in [file_names, urls, directory] if v)
if num_specified == 0:
raise ValueError("Must specify either file_names or urls or directory.")
elif num_specified > 1:
raise ValueError("Must specify only one of file_names or urls or directory.")
elif file_names:
reader = SimpleDirectoryReader(input_files=file_names)
docs = reader.load_data()
elif directory:
reader = SimpleDirectoryReader(input_dir=directory)
docs = reader.load_data()
elif urls:
from llama_hub.web.simple_web.base import SimpleWebPageReader
@@ -101,7 +128,7 @@ def load_data(
loader = SimpleWebPageReader()
docs = loader.load_data(urls=urls)
else:
raise ValueError("Must specify either file_names or urls.")
raise ValueError("Must specify either file_names or urls or directory.")
return docs
@@ -326,3 +353,128 @@ def get_tool_objects(tool_names: List[str]) -> List:
raise ValueError(f"Tool {tool_name} not recognized.")
return tool_objs
class MultimodalChatEngine(BaseChatEngine):
"""Multimodal chat engine.
This chat engine is a light wrapper around a query engine.
Offers no real 'chat' functionality, is a beta feature.
"""
def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
"""Init params."""
self._mm_query_engine = mm_query_engine
def reset(self) -> None:
"""Reset conversation state."""
pass
@property
def chat_history(self) -> List[ChatMessage]:
return []
@trace_method("chat")
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Main chat interface."""
# just return the top-k results
response = self._mm_query_engine.query(message)
return AgentChatResponse(
response=str(response), source_nodes=response.source_nodes
)
@trace_method("chat")
def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Stream chat interface."""
response = self._mm_query_engine.query(message)
def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
yield ChatResponse(message=ChatMessage(role="assistant", content=response))
chat_stream = _chat_stream(str(response))
return StreamingAgentChatResponse(
chat_stream=chat_stream, source_nodes=response.source_nodes
)
@trace_method("chat")
async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Async version of main chat interface."""
response = await self._mm_query_engine.aquery(message)
return AgentChatResponse(
response=str(response), source_nodes=response.source_nodes
)
@trace_method("chat")
async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> StreamingAgentChatResponse:
"""Async version of main chat interface."""
return self.stream_chat(message, chat_history)
def construct_mm_agent(
system_prompt: str,
rag_params: RAGParams,
docs: List[Document],
mm_vector_index: Optional[VectorStoreIndex] = None,
additional_tools: Optional[List] = None,
) -> Tuple[BaseChatEngine, Dict]:
"""Construct agent from docs / parameters / indices.
NOTE: system prompt isn't used right now
"""
extra_info = {}
additional_tools = additional_tools or []
# first resolve llm and embedding model
embed_model = resolve_embed_model(rag_params.embed_model)
# TODO: use OpenAI for now
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1500)
# first let's index the data with the right parameters
service_context = ServiceContext.from_defaults(
chunk_size=rag_params.chunk_size,
embed_model=embed_model,
)
if mm_vector_index is None:
mm_vector_index = MultiModalVectorStoreIndex.from_documents(
docs, service_context=service_context
)
else:
pass
mm_retriever = mm_vector_index.as_retriever(similarity_top_k=rag_params.top_k)
mm_query_engine = SimpleMultiModalQueryEngine(
cast(MultiModalVectorIndexRetriever, mm_retriever),
multi_modal_llm=openai_mm_llm,
)
extra_info["vector_index"] = mm_vector_index
# use condense + context chat engine
agent = MultimodalChatEngine(mm_query_engine)
return agent, extra_info
def get_image_and_text_nodes(
nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
image_nodes = []
text_nodes = []
for res_node in nodes:
if isinstance(res_node.node, ImageNode):
image_nodes.append(res_node)
else:
text_nodes.append(res_node)
return image_nodes, text_nodes
+19 -9
View File
@@ -4,7 +4,7 @@ import streamlit as st
from core.param_cache import (
RAGParams,
)
from core.agent_builder import (
from core.agent_builder.loader import (
RAGAgentBuilder,
AgentCacheRegistry,
)
@@ -93,14 +93,24 @@ if current_state.agent_builder is not None:
)
rag_params = cast(RAGParams, current_state.cache.rag_params)
file_names = st.text_input(
"File names (not editable)",
value=",".join(current_state.cache.file_names),
disabled=True,
)
urls = st.text_input(
"URLs (not editable)", value=",".join(current_state.cache.urls), disabled=True
)
with st.expander("Loaded Data (Expand to view)"):
file_names = st.text_input(
"File names (not editable)",
value=",".join(current_state.cache.file_names),
disabled=True,
)
directory = st.text_input(
"Directory (not editable)",
value=current_state.cache.directory,
disabled=True,
)
urls = st.text_input(
"URLs (not editable)",
value=",".join(current_state.cache.urls),
disabled=True,
)
include_summarization_st = st.checkbox(
"Include Summarization (only works for GPT-4)",
value=rag_params.include_summarization,
+48 -3
View File
@@ -1,6 +1,11 @@
"""Streamlit page showing builder config."""
import streamlit as st
from st_utils import add_sidebar, get_current_state
from core.utils import get_image_and_text_nodes
from llama_index.schema import MetadataMode
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE
from typing import Dict, Optional
import pandas as pd
####################
@@ -28,8 +33,36 @@ if (
]
def add_to_message_history(role: str, content: str) -> None:
message = {"role": role, "content": str(content)}
def display_sources(response: AGENT_CHAT_RESPONSE_TYPE) -> None:
image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
if len(image_nodes) > 0 or len(text_nodes) > 0:
with st.expander("Sources"):
# get image nodes
if len(image_nodes) > 0:
st.subheader("Images")
for image_node in image_nodes:
st.image(image_node.metadata["file_path"])
if len(text_nodes) > 0:
st.subheader("Text")
sources_df_list = []
for text_node in text_nodes:
sources_df_list.append(
{
"ID": text_node.id_,
"Text": text_node.node.get_content(
metadata_mode=MetadataMode.ALL
),
}
)
sources_df = pd.DataFrame(sources_df_list)
st.dataframe(sources_df)
def add_to_message_history(
role: str, content: str, extra: Optional[Dict] = None
) -> None:
message = {"role": role, "content": str(content), "extra": extra}
st.session_state.agent_messages.append(message) # Add response to message history
@@ -45,6 +78,11 @@ def display_messages() -> None:
else:
raise ValueError(f"Unknown message type: {msg_type}")
# display sources
if "extra" in message and isinstance(message["extra"], dict):
if "response" in message["extra"].keys():
display_sources(message["extra"]["response"])
# if agent is created, then we can chat with it
if current_state.cache is not None and current_state.cache.agent is not None:
@@ -68,6 +106,13 @@ if current_state.cache is not None and current_state.cache.agent is not None:
with st.spinner("Thinking..."):
response = agent.chat(str(prompt))
st.write(str(response))
add_to_message_history("assistant", str(response))
# display sources
# Multi-modal: check if image nodes are present
display_sources(response)
add_to_message_history(
"assistant", str(response), extra={"response": response}
)
else:
st.info("Agent not created. Please create an agent in the above section.")
+10 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rags"
version = "0.0.4"
version = "0.0.5"
description = "Build RAG with natural language."
authors = ["Jerry Liu"]
# New attributes
@@ -22,6 +22,7 @@ llama-hub = "0.0.44"
# NOTE: this is due to a trivial dependency in the web tool, will refactor
langchain = "0.0.305"
pypdf = "3.17.1"
clip = { git = "https://github.com/openai/CLIP.git" }
[tool.poetry.dev-dependencies]
# pytest = "7.2.1"
@@ -65,3 +66,11 @@ exclude = [
[tool.ruff.per-file-ignores]
"base.py" = ["E402", "F811", "E501"]
[tool.poetry.extras]
multimodal = [
"torch",
"torchvision",
"clip",
]
+49 -4
View File
@@ -1,9 +1,9 @@
"""Streamlit utils."""
from core.agent_builder import (
from core.agent_builder.loader import (
load_meta_agent_and_tools,
AgentCacheRegistry,
RAGAgentBuilder,
)
from core.agent_builder.base import BaseRAGAgentBuilder
from core.param_cache import ParamCache
from core.constants import (
AGENT_CACHE_DIR,
@@ -38,6 +38,47 @@ def update_selected_agent() -> None:
update_selected_agent_with_id(selected_id)
def get_cached_is_multimodal() -> bool:
"""Get default multimodal st."""
if (
"selected_cache" not in st.session_state.keys()
or st.session_state.selected_cache is None
):
default_val = False
else:
selected_cache = cast(ParamCache, st.session_state.selected_cache)
default_val = True if selected_cache.builder_type == "multimodal" else False
return default_val
def get_is_multimodal() -> bool:
"""Get is multimodal."""
if "is_multimodal_st" not in st.session_state.keys():
st.session_state.is_multimodal_st = False
return st.session_state.is_multimodal_st
def add_builder_config() -> None:
"""Add builder config."""
with st.expander("Builder Config (Advanced)"):
# add a few options - openai api key, and
if (
"selected_cache" not in st.session_state.keys()
or st.session_state.selected_cache is None
):
is_locked = False
else:
is_locked = True
st.checkbox(
"Enable multimodal search (beta)",
key="is_multimodal_st",
on_change=update_selected_agent,
value=get_cached_is_multimodal(),
disabled=is_locked,
)
def add_sidebar() -> None:
"""Add sidebar."""
with st.sidebar:
@@ -70,7 +111,7 @@ class CurrentSessionState(BaseModel):
agent_registry: AgentCacheRegistry
selected_id: Optional[str]
selected_cache: Optional[ParamCache]
agent_builder: RAGAgentBuilder
agent_builder: BaseRAGAgentBuilder
cache: ParamCache
builder_agent: BaseAgent
@@ -126,11 +167,15 @@ def get_current_state() -> CurrentSessionState:
builder_agent, agent_builder = load_meta_agent_and_tools(
cache=st.session_state.selected_cache,
agent_registry=st.session_state.agent_registry,
# NOTE: we will probably generalize this later into different
# builder configs
is_multimodal=get_cached_is_multimodal(),
)
else:
# create builder agent / tools from new cache
builder_agent, agent_builder = load_meta_agent_and_tools(
agent_registry=st.session_state.agent_registry
agent_registry=st.session_state.agent_registry,
is_multimodal=get_is_multimodal(),
)
st.session_state.builder_agent = builder_agent