mirror of
https://github.com/run-llama/rags.git
synced 2026-07-01 20:54:00 -04:00
upgrade v3 (#49)
support multi-modality allow loading an entire directory (and all the files in the directory) add sources in the response
This commit is contained in:
+7
-36
@@ -1,8 +1,13 @@
|
||||
import streamlit as st
|
||||
from streamlit_pills import pills
|
||||
|
||||
from st_utils import add_sidebar, get_current_state
|
||||
from st_utils import (
|
||||
add_builder_config,
|
||||
add_sidebar,
|
||||
get_current_state,
|
||||
)
|
||||
|
||||
current_state = get_current_state()
|
||||
|
||||
####################
|
||||
#### STREAMLIT #####
|
||||
@@ -28,7 +33,7 @@ if "metaphor_key" in st.secrets:
|
||||
st.info("**NOTE**: The ability to add web search is enabled.")
|
||||
|
||||
|
||||
current_state = get_current_state()
|
||||
add_builder_config()
|
||||
add_sidebar()
|
||||
|
||||
|
||||
@@ -60,40 +65,6 @@ for message in st.session_state.messages: # Display the prior chat messages
|
||||
with st.chat_message(message["role"]):
|
||||
st.write(message["content"])
|
||||
|
||||
|
||||
# def handle_user_input() -> None:
|
||||
# """Handle user input."""
|
||||
# prompt = st.session_state.user_question_st
|
||||
# print(f"USER PROMPT: {prompt}")
|
||||
# add_to_message_history("user", prompt)
|
||||
# with st.chat_message("user"):
|
||||
# st.write(prompt)
|
||||
# # If last message is not from assistant, generate a new response
|
||||
# if st.session_state.messages[-1]["role"] != "assistant":
|
||||
# with st.chat_message("assistant"):
|
||||
# with st.spinner("Thinking..."):
|
||||
# response = current_state.builder_agent.chat(prompt)
|
||||
# st.write(str(response))
|
||||
# add_to_message_history("assistant", str(response))
|
||||
|
||||
# else:
|
||||
# pass
|
||||
|
||||
# # check agent_ids again
|
||||
# # if it doesn't match, add to directory and refresh
|
||||
# agent_ids = current_state.agent_registry.get_agent_ids()
|
||||
# # check diff between agent_ids and cur agent ids
|
||||
# diff_ids = list(set(agent_ids) - set(st.session_state.cur_agent_ids))
|
||||
# if len(diff_ids) > 0:
|
||||
# # # clear streamlit cache, to allow you to generate a new agent
|
||||
# # st.cache_resource.clear()
|
||||
# st.rerun()
|
||||
|
||||
# handle user input
|
||||
# st.chat_input(
|
||||
# "Your question", key="user_question_st", on_submit=handle_user_input
|
||||
# ) # Prompt for user input and save to chat history
|
||||
|
||||
# TODO: this is really hacky, only because st.rerun is jank
|
||||
if prompt := st.chat_input(
|
||||
"Your question",
|
||||
|
||||
@@ -40,6 +40,8 @@ streamlit run 1_🏠_Home.py
|
||||
|
||||
```
|
||||
|
||||
**NOTE**: If you've upgraded the version of RAGs, and you're running into issues on launch, you may need to delete the `cache` folder in your home directory (we may have introduced breaking changes in the stored data structure between versions).
|
||||
|
||||
## Detailed Overview
|
||||
|
||||
The app contains the following sections, corresponding to the steps listed above.
|
||||
|
||||
@@ -3,93 +3,19 @@
|
||||
from llama_index.llms import ChatMessage
|
||||
from llama_index.prompts import ChatPromptTemplate
|
||||
from typing import List, cast, Optional
|
||||
from llama_index.tools import FunctionTool
|
||||
from llama_index.agent.types import BaseAgent
|
||||
from core.builder_config import BUILDER_LLM
|
||||
from typing import Dict, Tuple, Any, Callable, Union
|
||||
import streamlit as st
|
||||
from pathlib import Path
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from core.constants import AGENT_CACHE_DIR
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.param_cache import ParamCache, RAGParams
|
||||
from core.utils import (
|
||||
load_data,
|
||||
get_tool_objects,
|
||||
construct_agent,
|
||||
load_meta_agent,
|
||||
)
|
||||
|
||||
|
||||
class AgentCacheRegistry:
|
||||
"""Registry for agent caches, in disk.
|
||||
|
||||
Can register new agent caches, load agent caches, delete agent caches, etc.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dir: Union[str, Path]) -> None:
|
||||
"""Init params."""
|
||||
self._dir = dir
|
||||
|
||||
def _add_agent_id_to_directory(self, agent_id: str) -> None:
|
||||
"""Save agent id to directory."""
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
if not full_path.exists():
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": [agent_id]}, f)
|
||||
else:
|
||||
with open(full_path, "r") as f:
|
||||
agent_ids = json.load(f)["agent_ids"]
|
||||
if agent_id in agent_ids:
|
||||
raise ValueError(f"Agent id {agent_id} already exists.")
|
||||
agent_ids_set = set(agent_ids)
|
||||
agent_ids_set.add(agent_id)
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": list(agent_ids_set)}, f)
|
||||
|
||||
def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
|
||||
"""Register agent."""
|
||||
# save the cache to disk
|
||||
agent_cache_path = f"{self._dir}/{agent_id}"
|
||||
cache.save_to_disk(agent_cache_path)
|
||||
# save to agent ids
|
||||
self._add_agent_id_to_directory(agent_id)
|
||||
|
||||
def get_agent_ids(self) -> List[str]:
|
||||
"""Get agent ids."""
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
if not full_path.exists():
|
||||
return []
|
||||
with open(full_path, "r") as f:
|
||||
agent_ids = json.load(f)["agent_ids"]
|
||||
|
||||
return agent_ids
|
||||
|
||||
def get_agent_cache(self, agent_id: str) -> ParamCache:
|
||||
"""Get agent cache."""
|
||||
full_path = Path(self._dir) / f"{agent_id}"
|
||||
if not full_path.exists():
|
||||
raise ValueError(f"Cache for agent {agent_id} does not exist.")
|
||||
cache = ParamCache.load_from_disk(str(full_path))
|
||||
return cache
|
||||
|
||||
def delete_agent_cache(self, agent_id: str) -> None:
|
||||
"""Delete agent cache."""
|
||||
# modify / resave agent_ids
|
||||
agent_ids = self.get_agent_ids()
|
||||
new_agent_ids = [id for id in agent_ids if id != agent_id]
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": new_agent_ids}, f)
|
||||
|
||||
# remove agent cache
|
||||
full_path = Path(self._dir) / f"{agent_id}"
|
||||
if full_path.exists():
|
||||
# recursive delete
|
||||
shutil.rmtree(full_path)
|
||||
from core.agent_builder.registry import AgentCacheRegistry
|
||||
|
||||
|
||||
# System prompt tool
|
||||
@@ -121,7 +47,21 @@ gen_sys_prompt_messages = [
|
||||
GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages)
|
||||
|
||||
|
||||
class RAGAgentBuilder:
|
||||
class BaseRAGAgentBuilder(ABC):
|
||||
"""Base RAG Agent builder class."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cache(self) -> ParamCache:
|
||||
"""Cache."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def agent_registry(self) -> AgentCacheRegistry:
|
||||
"""Agent registry."""
|
||||
|
||||
|
||||
class RAGAgentBuilder(BaseRAGAgentBuilder):
|
||||
"""RAG Agent builder.
|
||||
|
||||
Contains a set of functions to construct a RAG agent, including:
|
||||
@@ -165,25 +105,31 @@ class RAGAgentBuilder:
|
||||
return f"System prompt created: {response.message.content}"
|
||||
|
||||
def load_data(
|
||||
self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
|
||||
self,
|
||||
file_names: Optional[List[str]] = None,
|
||||
directory: Optional[str] = None,
|
||||
urls: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Load data for a given task.
|
||||
|
||||
Only ONE of file_names or urls should be specified.
|
||||
Only ONE of file_names or directory or urls should be specified.
|
||||
|
||||
Args:
|
||||
file_names (Optional[List[str]]): List of file names to load.
|
||||
Defaults to None.
|
||||
directory (Optional[str]): Directory to load files from.
|
||||
urls (Optional[List[str]]): List of urls to load.
|
||||
Defaults to None.
|
||||
|
||||
"""
|
||||
file_names = file_names or []
|
||||
urls = urls or []
|
||||
docs = load_data(file_names=file_names, urls=urls)
|
||||
directory = directory or ""
|
||||
docs = load_data(file_names=file_names, directory=directory, urls=urls)
|
||||
self._cache.docs = docs
|
||||
self._cache.file_names = file_names
|
||||
self._cache.urls = urls
|
||||
self._cache.directory = directory
|
||||
return "Data loaded successfully."
|
||||
|
||||
def add_web_tool(self) -> str:
|
||||
@@ -302,77 +248,3 @@ class RAGAgentBuilder:
|
||||
|
||||
# this will update the agent in the cache
|
||||
self.create_agent()
|
||||
|
||||
|
||||
####################
|
||||
#### META Agent ####
|
||||
####################
|
||||
|
||||
RAG_BUILDER_SYS_STR = """\
|
||||
You are helping to construct an agent given a user-specified task.
|
||||
You should generally use the tools in this rough order to build the agent.
|
||||
|
||||
1) Create system prompt tool: to create the system prompt for the agent.
|
||||
2) Load in user-specified data (based on file paths they specify).
|
||||
3) Decide whether or not to add additional tools.
|
||||
4) Set parameters for the RAG pipeline.
|
||||
5) Build the agent
|
||||
|
||||
This will be a back and forth conversation with the user. You should
|
||||
continue asking users if there's anything else they want to do until
|
||||
they say they're done. To help guide them on the process,
|
||||
you can give suggestions on parameters they can set based on the tools they
|
||||
have available (e.g. "Do you want to set the number of documents to retrieve?")
|
||||
|
||||
"""
|
||||
|
||||
|
||||
### DEFINE Agent ####
|
||||
# NOTE: here we define a function that is dependent on the LLM,
|
||||
# please make sure to update the LLM above if you change the function below
|
||||
|
||||
|
||||
def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
|
||||
"""Get list of builder agent tools to pass to the builder agent."""
|
||||
# see if metaphor api key is set, otherwise don't add web tool
|
||||
# TODO: refactor this later
|
||||
|
||||
if "metaphor_key" in st.secrets:
|
||||
fns: List[Callable] = [
|
||||
agent_builder.create_system_prompt,
|
||||
agent_builder.load_data,
|
||||
agent_builder.add_web_tool,
|
||||
agent_builder.get_rag_params,
|
||||
agent_builder.set_rag_params,
|
||||
agent_builder.create_agent,
|
||||
]
|
||||
else:
|
||||
fns = [
|
||||
agent_builder.create_system_prompt,
|
||||
agent_builder.load_data,
|
||||
agent_builder.get_rag_params,
|
||||
agent_builder.set_rag_params,
|
||||
agent_builder.create_agent,
|
||||
]
|
||||
|
||||
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
|
||||
return fn_tools
|
||||
|
||||
|
||||
# define agent
|
||||
# @st.cache_resource
|
||||
def load_meta_agent_and_tools(
|
||||
cache: Optional[ParamCache] = None,
|
||||
agent_registry: Optional[AgentCacheRegistry] = None,
|
||||
) -> Tuple[BaseAgent, RAGAgentBuilder]:
|
||||
|
||||
# think of this as tools for the agent to use
|
||||
agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
|
||||
|
||||
fn_tools = _get_builder_agent_tools(agent_builder)
|
||||
|
||||
builder_agent = load_meta_agent(
|
||||
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
|
||||
)
|
||||
|
||||
return builder_agent, agent_builder
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Loader agent."""
|
||||
|
||||
from typing import List, cast, Optional
|
||||
from llama_index.tools import FunctionTool
|
||||
from llama_index.agent.types import BaseAgent
|
||||
from core.builder_config import BUILDER_LLM
|
||||
from typing import Tuple, Callable
|
||||
import streamlit as st
|
||||
|
||||
from core.param_cache import ParamCache
|
||||
from core.utils import (
|
||||
load_meta_agent,
|
||||
)
|
||||
from core.agent_builder.registry import AgentCacheRegistry
|
||||
from core.agent_builder.base import RAGAgentBuilder, BaseRAGAgentBuilder
|
||||
from core.agent_builder.multimodal import MultimodalRAGAgentBuilder
|
||||
|
||||
####################
|
||||
#### META Agent ####
|
||||
####################
|
||||
|
||||
RAG_BUILDER_SYS_STR = """\
|
||||
You are helping to construct an agent given a user-specified task.
|
||||
You should generally use the tools in this rough order to build the agent.
|
||||
|
||||
1) Create system prompt tool: to create the system prompt for the agent.
|
||||
2) Load in user-specified data (based on file paths they specify).
|
||||
3) Decide whether or not to add additional tools.
|
||||
4) Set parameters for the RAG pipeline.
|
||||
5) Build the agent
|
||||
|
||||
This will be a back and forth conversation with the user. You should
|
||||
continue asking users if there's anything else they want to do until
|
||||
they say they're done. To help guide them on the process,
|
||||
you can give suggestions on parameters they can set based on the tools they
|
||||
have available (e.g. "Do you want to set the number of documents to retrieve?")
|
||||
|
||||
"""
|
||||
|
||||
|
||||
### DEFINE Agent ####
|
||||
# NOTE: here we define a function that is dependent on the LLM,
|
||||
# please make sure to update the LLM above if you change the function below
|
||||
|
||||
|
||||
def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
|
||||
"""Get list of builder agent tools to pass to the builder agent."""
|
||||
# see if metaphor api key is set, otherwise don't add web tool
|
||||
# TODO: refactor this later
|
||||
|
||||
if "metaphor_key" in st.secrets:
|
||||
fns: List[Callable] = [
|
||||
agent_builder.create_system_prompt,
|
||||
agent_builder.load_data,
|
||||
agent_builder.add_web_tool,
|
||||
agent_builder.get_rag_params,
|
||||
agent_builder.set_rag_params,
|
||||
agent_builder.create_agent,
|
||||
]
|
||||
else:
|
||||
fns = [
|
||||
agent_builder.create_system_prompt,
|
||||
agent_builder.load_data,
|
||||
agent_builder.get_rag_params,
|
||||
agent_builder.set_rag_params,
|
||||
agent_builder.create_agent,
|
||||
]
|
||||
|
||||
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
|
||||
return fn_tools
|
||||
|
||||
|
||||
def _get_mm_builder_agent_tools(
|
||||
agent_builder: MultimodalRAGAgentBuilder,
|
||||
) -> List[FunctionTool]:
|
||||
"""Get list of builder agent tools to pass to the builder agent."""
|
||||
fns: List[Callable] = [
|
||||
agent_builder.create_system_prompt,
|
||||
agent_builder.load_data,
|
||||
agent_builder.get_rag_params,
|
||||
agent_builder.set_rag_params,
|
||||
agent_builder.create_agent,
|
||||
]
|
||||
|
||||
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
|
||||
return fn_tools
|
||||
|
||||
|
||||
# define agent
|
||||
def load_meta_agent_and_tools(
|
||||
cache: Optional[ParamCache] = None,
|
||||
agent_registry: Optional[AgentCacheRegistry] = None,
|
||||
is_multimodal: bool = False,
|
||||
) -> Tuple[BaseAgent, BaseRAGAgentBuilder]:
|
||||
"""Load meta agent and tools."""
|
||||
|
||||
if is_multimodal:
|
||||
agent_builder: BaseRAGAgentBuilder = MultimodalRAGAgentBuilder(
|
||||
cache, agent_registry=agent_registry
|
||||
)
|
||||
fn_tools = _get_mm_builder_agent_tools(
|
||||
cast(MultimodalRAGAgentBuilder, agent_builder)
|
||||
)
|
||||
builder_agent = load_meta_agent(
|
||||
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
|
||||
)
|
||||
else:
|
||||
# think of this as tools for the agent to use
|
||||
agent_builder = RAGAgentBuilder(cache, agent_registry=agent_registry)
|
||||
fn_tools = _get_builder_agent_tools(agent_builder)
|
||||
builder_agent = load_meta_agent(
|
||||
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True
|
||||
)
|
||||
|
||||
return builder_agent, agent_builder
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Multimodal agent builder."""
|
||||
|
||||
from llama_index.llms import ChatMessage
|
||||
from typing import List, cast, Optional
|
||||
from core.builder_config import BUILDER_LLM
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from core.constants import AGENT_CACHE_DIR
|
||||
|
||||
from core.param_cache import ParamCache, RAGParams
|
||||
from core.utils import (
|
||||
load_data,
|
||||
construct_mm_agent,
|
||||
)
|
||||
from core.agent_builder.registry import AgentCacheRegistry
|
||||
from core.agent_builder.base import GEN_SYS_PROMPT_TMPL, BaseRAGAgentBuilder
|
||||
|
||||
from llama_index.chat_engine.types import BaseChatEngine
|
||||
|
||||
from llama_index.callbacks import trace_method
|
||||
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
StreamingAgentChatResponse,
|
||||
AgentChatResponse,
|
||||
)
|
||||
from llama_index.llms.base import ChatResponse
|
||||
from typing import Generator
|
||||
|
||||
|
||||
class MultimodalChatEngine(BaseChatEngine):
|
||||
"""Multimodal chat engine.
|
||||
|
||||
This chat engine is a light wrapper around a query engine.
|
||||
Offers no real 'chat' functionality, is a beta feature.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
|
||||
"""Init params."""
|
||||
self._mm_query_engine = mm_query_engine
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset conversation state."""
|
||||
pass
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Main chat interface."""
|
||||
# just return the top-k results
|
||||
response = self._mm_query_engine.query(message)
|
||||
return AgentChatResponse(response=str(response))
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Stream chat interface."""
|
||||
response = self._mm_query_engine.query(message)
|
||||
|
||||
def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
|
||||
yield ChatResponse(message=ChatMessage(role="assistant", content=response))
|
||||
|
||||
chat_stream = _chat_stream(str(response))
|
||||
return StreamingAgentChatResponse(chat_stream=chat_stream)
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Async version of main chat interface."""
|
||||
response = await self._mm_query_engine.aquery(message)
|
||||
return AgentChatResponse(response=str(response))
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Async version of main chat interface."""
|
||||
return self.stream_chat(message, chat_history)
|
||||
|
||||
|
||||
class MultimodalRAGAgentBuilder(BaseRAGAgentBuilder):
|
||||
"""Multimodal RAG Agent builder.
|
||||
|
||||
Contains a set of functions to construct a RAG agent, including:
|
||||
- setting system prompts
|
||||
- loading data
|
||||
- adding web search
|
||||
- setting parameters (e.g. top-k)
|
||||
|
||||
Must pass in a cache. This cache will be modified as the agent is built.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache: Optional[ParamCache] = None,
|
||||
agent_registry: Optional[AgentCacheRegistry] = None,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
self._cache = cache or ParamCache()
|
||||
self._agent_registry = agent_registry or AgentCacheRegistry(
|
||||
str(AGENT_CACHE_DIR)
|
||||
)
|
||||
|
||||
@property
|
||||
def cache(self) -> ParamCache:
|
||||
"""Cache."""
|
||||
return self._cache
|
||||
|
||||
@property
|
||||
def agent_registry(self) -> AgentCacheRegistry:
|
||||
"""Agent registry."""
|
||||
return self._agent_registry
|
||||
|
||||
def create_system_prompt(self, task: str) -> str:
|
||||
"""Create system prompt for another agent given an input task."""
|
||||
llm = BUILDER_LLM
|
||||
fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task)
|
||||
response = llm.chat(fmt_messages)
|
||||
self._cache.system_prompt = response.message.content
|
||||
|
||||
return f"System prompt created: {response.message.content}"
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file_names: Optional[List[str]] = None,
|
||||
directory: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Load data for a given task.
|
||||
|
||||
Only ONE of file_names or directory should be specified.
|
||||
**NOTE**: urls not supported in multi-modal setting.
|
||||
|
||||
Args:
|
||||
file_names (Optional[List[str]]): List of file names to load.
|
||||
Defaults to None.
|
||||
directory (Optional[str]): Directory to load files from.
|
||||
|
||||
"""
|
||||
file_names = file_names or []
|
||||
directory = directory or ""
|
||||
docs = load_data(file_names=file_names, directory=directory)
|
||||
self._cache.docs = docs
|
||||
self._cache.file_names = file_names
|
||||
self._cache.directory = directory
|
||||
return "Data loaded successfully."
|
||||
|
||||
def get_rag_params(self) -> Dict:
|
||||
"""Get parameters used to configure the RAG pipeline.
|
||||
|
||||
Should be called before `set_rag_params` so that the agent is aware of the
|
||||
schema.
|
||||
|
||||
"""
|
||||
rag_params = self._cache.rag_params
|
||||
return rag_params.dict()
|
||||
|
||||
def set_rag_params(self, **rag_params: Dict) -> str:
|
||||
"""Set RAG parameters.
|
||||
|
||||
These parameters will then be used to actually initialize the agent.
|
||||
Should call `get_rag_params` first to get the schema of the input dictionary.
|
||||
|
||||
Args:
|
||||
**rag_params (Dict): dictionary of RAG parameters.
|
||||
|
||||
"""
|
||||
new_dict = self._cache.rag_params.dict()
|
||||
new_dict.update(rag_params)
|
||||
rag_params_obj = RAGParams(**new_dict)
|
||||
self._cache.rag_params = rag_params_obj
|
||||
return "RAG parameters set successfully."
|
||||
|
||||
def create_agent(self, agent_id: Optional[str] = None) -> str:
|
||||
"""Create an agent.
|
||||
|
||||
There are no parameters for this function because all the
|
||||
functions should have already been called to set up the agent.
|
||||
|
||||
"""
|
||||
if self._cache.system_prompt is None:
|
||||
raise ValueError("Must set system prompt before creating agent.")
|
||||
|
||||
# construct additional tools
|
||||
agent, extra_info = construct_mm_agent(
|
||||
cast(str, self._cache.system_prompt),
|
||||
cast(RAGParams, self._cache.rag_params),
|
||||
self._cache.docs,
|
||||
)
|
||||
|
||||
# if agent_id not specified, randomly generate one
|
||||
agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}"
|
||||
self._cache.builder_type = "multimodal"
|
||||
self._cache.vector_index = extra_info["vector_index"]
|
||||
self._cache.agent_id = agent_id
|
||||
self._cache.agent = agent
|
||||
|
||||
# save the cache to disk
|
||||
self._agent_registry.add_new_agent_cache(agent_id, self._cache)
|
||||
return "Agent created successfully."
|
||||
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
include_summarization: Optional[bool] = None,
|
||||
top_k: Optional[int] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
embed_model: Optional[str] = None,
|
||||
llm: Optional[str] = None,
|
||||
additional_tools: Optional[List] = None,
|
||||
) -> None:
|
||||
"""Update agent.
|
||||
|
||||
Delete old agent by ID and create a new one.
|
||||
Optionally update the system prompt and RAG parameters.
|
||||
|
||||
NOTE: Currently is manually called, not meant for agent use.
|
||||
|
||||
"""
|
||||
self._agent_registry.delete_agent_cache(self.cache.agent_id)
|
||||
|
||||
# set agent id
|
||||
self.cache.agent_id = agent_id
|
||||
|
||||
# set system prompt
|
||||
if system_prompt is not None:
|
||||
self.cache.system_prompt = system_prompt
|
||||
# get agent_builder
|
||||
# We call set_rag_params and create_agent, which will
|
||||
# update the cache
|
||||
# TODO: decouple functions from tool functions exposed to the agent
|
||||
rag_params_dict: Dict[str, Any] = {}
|
||||
if include_summarization is not None:
|
||||
rag_params_dict["include_summarization"] = include_summarization
|
||||
if top_k is not None:
|
||||
rag_params_dict["top_k"] = top_k
|
||||
if chunk_size is not None:
|
||||
rag_params_dict["chunk_size"] = chunk_size
|
||||
if embed_model is not None:
|
||||
rag_params_dict["embed_model"] = embed_model
|
||||
if llm is not None:
|
||||
rag_params_dict["llm"] = llm
|
||||
|
||||
self.set_rag_params(**rag_params_dict)
|
||||
|
||||
# update tools
|
||||
if additional_tools is not None:
|
||||
self.cache.tools = additional_tools
|
||||
|
||||
# this will update the agent in the cache
|
||||
self.create_agent()
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Agent builder registry."""
|
||||
|
||||
from typing import List
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
import json
|
||||
import shutil
|
||||
|
||||
from core.param_cache import ParamCache
|
||||
|
||||
|
||||
class AgentCacheRegistry:
|
||||
"""Registry for agent caches, in disk.
|
||||
|
||||
Can register new agent caches, load agent caches, delete agent caches, etc.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dir: Union[str, Path]) -> None:
|
||||
"""Init params."""
|
||||
self._dir = dir
|
||||
|
||||
def _add_agent_id_to_directory(self, agent_id: str) -> None:
|
||||
"""Save agent id to directory."""
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
if not full_path.exists():
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": [agent_id]}, f)
|
||||
else:
|
||||
with open(full_path, "r") as f:
|
||||
agent_ids = json.load(f)["agent_ids"]
|
||||
if agent_id in agent_ids:
|
||||
raise ValueError(f"Agent id {agent_id} already exists.")
|
||||
agent_ids_set = set(agent_ids)
|
||||
agent_ids_set.add(agent_id)
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": list(agent_ids_set)}, f)
|
||||
|
||||
def add_new_agent_cache(self, agent_id: str, cache: ParamCache) -> None:
|
||||
"""Register agent."""
|
||||
# save the cache to disk
|
||||
agent_cache_path = f"{self._dir}/{agent_id}"
|
||||
cache.save_to_disk(agent_cache_path)
|
||||
# save to agent ids
|
||||
self._add_agent_id_to_directory(agent_id)
|
||||
|
||||
def get_agent_ids(self) -> List[str]:
|
||||
"""Get agent ids."""
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
if not full_path.exists():
|
||||
return []
|
||||
with open(full_path, "r") as f:
|
||||
agent_ids = json.load(f)["agent_ids"]
|
||||
|
||||
return agent_ids
|
||||
|
||||
def get_agent_cache(self, agent_id: str) -> ParamCache:
|
||||
"""Get agent cache."""
|
||||
full_path = Path(self._dir) / f"{agent_id}"
|
||||
if not full_path.exists():
|
||||
raise ValueError(f"Cache for agent {agent_id} does not exist.")
|
||||
cache = ParamCache.load_from_disk(str(full_path))
|
||||
return cache
|
||||
|
||||
def delete_agent_cache(self, agent_id: str) -> None:
|
||||
"""Delete agent cache."""
|
||||
# modify / resave agent_ids
|
||||
agent_ids = self.get_agent_ids()
|
||||
new_agent_ids = [id for id in agent_ids if id != agent_id]
|
||||
full_path = Path(self._dir) / "agent_ids.json"
|
||||
with open(full_path, "w") as f:
|
||||
json.dump({"agent_ids": new_agent_ids}, f)
|
||||
|
||||
# remove agent cache
|
||||
full_path = Path(self._dir) / f"{agent_id}"
|
||||
if full_path.exists():
|
||||
# recursive delete
|
||||
shutil.rmtree(full_path)
|
||||
+49
-13
@@ -11,7 +11,13 @@ from llama_index.chat_engine.types import BaseChatEngine
|
||||
from pathlib import Path
|
||||
import json
|
||||
import uuid
|
||||
from core.utils import load_data, get_tool_objects, construct_agent, RAGParams
|
||||
from core.utils import (
|
||||
load_data,
|
||||
get_tool_objects,
|
||||
construct_agent,
|
||||
RAGParams,
|
||||
construct_mm_agent,
|
||||
)
|
||||
|
||||
|
||||
class ParamCache(BaseModel):
|
||||
@@ -37,6 +43,10 @@ class ParamCache(BaseModel):
|
||||
urls: List[str] = Field(
|
||||
default_factory=list, description="URLs as data source (if specified)"
|
||||
)
|
||||
directory: Optional[str] = Field(
|
||||
default=None, description="Directory as data source (if specified)"
|
||||
)
|
||||
|
||||
docs: List = Field(default_factory=list, description="Documents for RAG agent.")
|
||||
# tools
|
||||
tools: List = Field(
|
||||
@@ -48,6 +58,9 @@ class ParamCache(BaseModel):
|
||||
)
|
||||
|
||||
# agent params
|
||||
builder_type: str = Field(
|
||||
default="default", description="Builder type (default, multimodal)."
|
||||
)
|
||||
vector_index: Optional[VectorStoreIndex] = Field(
|
||||
default=None, description="Vector index for RAG agent."
|
||||
)
|
||||
@@ -66,9 +79,11 @@ class ParamCache(BaseModel):
|
||||
"system_prompt": self.system_prompt,
|
||||
"file_names": self.file_names,
|
||||
"urls": self.urls,
|
||||
"directory": self.directory,
|
||||
# TODO: figure out tools
|
||||
"tools": self.tools,
|
||||
"rag_params": self.rag_params.dict(),
|
||||
"builder_type": self.builder_type,
|
||||
"agent_id": self.agent_id,
|
||||
}
|
||||
# store the vector store within the agent
|
||||
@@ -88,13 +103,22 @@ class ParamCache(BaseModel):
|
||||
save_dir: str,
|
||||
) -> "ParamCache":
|
||||
"""Load cache from disk."""
|
||||
with open(Path(save_dir) / "cache.json", "r") as f:
|
||||
cache_dict = json.load(f)
|
||||
|
||||
storage_context = StorageContext.from_defaults(
|
||||
persist_dir=str(Path(save_dir) / "storage")
|
||||
)
|
||||
vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
|
||||
if cache_dict["builder_type"] == "multimodal":
|
||||
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
|
||||
|
||||
with open(Path(save_dir) / "cache.json", "r") as f:
|
||||
cache_dict = json.load(f)
|
||||
vector_index: VectorStoreIndex = cast(
|
||||
MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
|
||||
)
|
||||
else:
|
||||
vector_index = cast(
|
||||
VectorStoreIndex, load_index_from_storage(storage_context)
|
||||
)
|
||||
|
||||
# replace rag params with RAGParams object
|
||||
cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
|
||||
@@ -102,18 +126,30 @@ class ParamCache(BaseModel):
|
||||
# add in the missing fields
|
||||
# load docs
|
||||
cache_dict["docs"] = load_data(
|
||||
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
|
||||
file_names=cache_dict["file_names"],
|
||||
urls=cache_dict["urls"],
|
||||
directory=cache_dict["directory"],
|
||||
)
|
||||
# load agent from index
|
||||
additional_tools = get_tool_objects(cache_dict["tools"])
|
||||
agent, _ = construct_agent(
|
||||
cache_dict["system_prompt"],
|
||||
cache_dict["rag_params"],
|
||||
cache_dict["docs"],
|
||||
vector_index=vector_index,
|
||||
additional_tools=additional_tools,
|
||||
# TODO: figure out tools
|
||||
)
|
||||
|
||||
if cache_dict["builder_type"] == "multimodal":
|
||||
vector_index = cast(MultiModalVectorStoreIndex, vector_index)
|
||||
agent, _ = construct_mm_agent(
|
||||
cache_dict["system_prompt"],
|
||||
cache_dict["rag_params"],
|
||||
cache_dict["docs"],
|
||||
mm_vector_index=vector_index,
|
||||
)
|
||||
else:
|
||||
agent, _ = construct_agent(
|
||||
cache_dict["system_prompt"],
|
||||
cache_dict["rag_params"],
|
||||
cache_dict["docs"],
|
||||
vector_index=vector_index,
|
||||
additional_tools=additional_tools,
|
||||
# TODO: figure out tools
|
||||
)
|
||||
cache_dict["vector_index"] = vector_index
|
||||
cache_dict["agent"] = agent
|
||||
|
||||
|
||||
+159
-7
@@ -26,8 +26,25 @@ from core.builder_config import BUILDER_LLM
|
||||
from typing import Dict, Tuple, Any
|
||||
import streamlit as st
|
||||
|
||||
from llama_index.callbacks import CallbackManager
|
||||
from llama_index.callbacks import CallbackManager, trace_method
|
||||
from core.callback_manager import StreamlitFunctionsCallbackHandler
|
||||
from llama_index.schema import ImageNode, NodeWithScore
|
||||
|
||||
### BETA: Multi-modal
|
||||
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
|
||||
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
|
||||
from llama_index.indices.multi_modal.retriever import (
|
||||
MultiModalVectorIndexRetriever,
|
||||
)
|
||||
from llama_index.llms import ChatMessage
|
||||
from llama_index.query_engine.multi_modal import SimpleMultiModalQueryEngine
|
||||
from llama_index.chat_engine.types import (
|
||||
AGENT_CHAT_RESPONSE_TYPE,
|
||||
StreamingAgentChatResponse,
|
||||
AgentChatResponse,
|
||||
)
|
||||
from llama_index.llms.base import ChatResponse
|
||||
from typing import Generator
|
||||
|
||||
|
||||
class RAGParams(BaseModel):
|
||||
@@ -82,18 +99,28 @@ def _resolve_llm(llm_str: str) -> LLM:
|
||||
|
||||
|
||||
def load_data(
|
||||
file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None
|
||||
file_names: Optional[List[str]] = None,
|
||||
directory: Optional[str] = None,
|
||||
urls: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
"""Load data."""
|
||||
file_names = file_names or []
|
||||
directory = directory or ""
|
||||
urls = urls or []
|
||||
if not file_names and not urls:
|
||||
raise ValueError("Must specify either file_names or urls.")
|
||||
elif file_names and urls:
|
||||
raise ValueError("Must specify only one of file_names or urls.")
|
||||
|
||||
# get number depending on whether specified
|
||||
num_specified = sum(1 for v in [file_names, urls, directory] if v)
|
||||
|
||||
if num_specified == 0:
|
||||
raise ValueError("Must specify either file_names or urls or directory.")
|
||||
elif num_specified > 1:
|
||||
raise ValueError("Must specify only one of file_names or urls or directory.")
|
||||
elif file_names:
|
||||
reader = SimpleDirectoryReader(input_files=file_names)
|
||||
docs = reader.load_data()
|
||||
elif directory:
|
||||
reader = SimpleDirectoryReader(input_dir=directory)
|
||||
docs = reader.load_data()
|
||||
elif urls:
|
||||
from llama_hub.web.simple_web.base import SimpleWebPageReader
|
||||
|
||||
@@ -101,7 +128,7 @@ def load_data(
|
||||
loader = SimpleWebPageReader()
|
||||
docs = loader.load_data(urls=urls)
|
||||
else:
|
||||
raise ValueError("Must specify either file_names or urls.")
|
||||
raise ValueError("Must specify either file_names or urls or directory.")
|
||||
|
||||
return docs
|
||||
|
||||
@@ -326,3 +353,128 @@ def get_tool_objects(tool_names: List[str]) -> List:
|
||||
raise ValueError(f"Tool {tool_name} not recognized.")
|
||||
|
||||
return tool_objs
|
||||
|
||||
|
||||
class MultimodalChatEngine(BaseChatEngine):
|
||||
"""Multimodal chat engine.
|
||||
|
||||
This chat engine is a light wrapper around a query engine.
|
||||
Offers no real 'chat' functionality, is a beta feature.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, mm_query_engine: SimpleMultiModalQueryEngine) -> None:
|
||||
"""Init params."""
|
||||
self._mm_query_engine = mm_query_engine
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset conversation state."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def chat_history(self) -> List[ChatMessage]:
|
||||
return []
|
||||
|
||||
@trace_method("chat")
|
||||
def chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Main chat interface."""
|
||||
# just return the top-k results
|
||||
response = self._mm_query_engine.query(message)
|
||||
return AgentChatResponse(
|
||||
response=str(response), source_nodes=response.source_nodes
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
def stream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Stream chat interface."""
|
||||
response = self._mm_query_engine.query(message)
|
||||
|
||||
def _chat_stream(response: str) -> Generator[ChatResponse, None, None]:
|
||||
yield ChatResponse(message=ChatMessage(role="assistant", content=response))
|
||||
|
||||
chat_stream = _chat_stream(str(response))
|
||||
return StreamingAgentChatResponse(
|
||||
chat_stream=chat_stream, source_nodes=response.source_nodes
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
async def achat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> AGENT_CHAT_RESPONSE_TYPE:
|
||||
"""Async version of main chat interface."""
|
||||
response = await self._mm_query_engine.aquery(message)
|
||||
return AgentChatResponse(
|
||||
response=str(response), source_nodes=response.source_nodes
|
||||
)
|
||||
|
||||
@trace_method("chat")
|
||||
async def astream_chat(
|
||||
self, message: str, chat_history: Optional[List[ChatMessage]] = None
|
||||
) -> StreamingAgentChatResponse:
|
||||
"""Async version of main chat interface."""
|
||||
return self.stream_chat(message, chat_history)
|
||||
|
||||
|
||||
def construct_mm_agent(
|
||||
system_prompt: str,
|
||||
rag_params: RAGParams,
|
||||
docs: List[Document],
|
||||
mm_vector_index: Optional[VectorStoreIndex] = None,
|
||||
additional_tools: Optional[List] = None,
|
||||
) -> Tuple[BaseChatEngine, Dict]:
|
||||
"""Construct agent from docs / parameters / indices.
|
||||
|
||||
NOTE: system prompt isn't used right now
|
||||
|
||||
"""
|
||||
extra_info = {}
|
||||
additional_tools = additional_tools or []
|
||||
|
||||
# first resolve llm and embedding model
|
||||
embed_model = resolve_embed_model(rag_params.embed_model)
|
||||
# TODO: use OpenAI for now
|
||||
os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
|
||||
openai_mm_llm = OpenAIMultiModal(model="gpt-4-vision-preview", max_new_tokens=1500)
|
||||
|
||||
# first let's index the data with the right parameters
|
||||
service_context = ServiceContext.from_defaults(
|
||||
chunk_size=rag_params.chunk_size,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
if mm_vector_index is None:
|
||||
mm_vector_index = MultiModalVectorStoreIndex.from_documents(
|
||||
docs, service_context=service_context
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
mm_retriever = mm_vector_index.as_retriever(similarity_top_k=rag_params.top_k)
|
||||
mm_query_engine = SimpleMultiModalQueryEngine(
|
||||
cast(MultiModalVectorIndexRetriever, mm_retriever),
|
||||
multi_modal_llm=openai_mm_llm,
|
||||
)
|
||||
|
||||
extra_info["vector_index"] = mm_vector_index
|
||||
|
||||
# use condense + context chat engine
|
||||
agent = MultimodalChatEngine(mm_query_engine)
|
||||
|
||||
return agent, extra_info
|
||||
|
||||
|
||||
def get_image_and_text_nodes(
|
||||
nodes: List[NodeWithScore],
|
||||
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
|
||||
image_nodes = []
|
||||
text_nodes = []
|
||||
for res_node in nodes:
|
||||
if isinstance(res_node.node, ImageNode):
|
||||
image_nodes.append(res_node)
|
||||
else:
|
||||
text_nodes.append(res_node)
|
||||
return image_nodes, text_nodes
|
||||
|
||||
@@ -4,7 +4,7 @@ import streamlit as st
|
||||
from core.param_cache import (
|
||||
RAGParams,
|
||||
)
|
||||
from core.agent_builder import (
|
||||
from core.agent_builder.loader import (
|
||||
RAGAgentBuilder,
|
||||
AgentCacheRegistry,
|
||||
)
|
||||
@@ -93,14 +93,24 @@ if current_state.agent_builder is not None:
|
||||
)
|
||||
|
||||
rag_params = cast(RAGParams, current_state.cache.rag_params)
|
||||
file_names = st.text_input(
|
||||
"File names (not editable)",
|
||||
value=",".join(current_state.cache.file_names),
|
||||
disabled=True,
|
||||
)
|
||||
urls = st.text_input(
|
||||
"URLs (not editable)", value=",".join(current_state.cache.urls), disabled=True
|
||||
)
|
||||
|
||||
with st.expander("Loaded Data (Expand to view)"):
|
||||
file_names = st.text_input(
|
||||
"File names (not editable)",
|
||||
value=",".join(current_state.cache.file_names),
|
||||
disabled=True,
|
||||
)
|
||||
directory = st.text_input(
|
||||
"Directory (not editable)",
|
||||
value=current_state.cache.directory,
|
||||
disabled=True,
|
||||
)
|
||||
urls = st.text_input(
|
||||
"URLs (not editable)",
|
||||
value=",".join(current_state.cache.urls),
|
||||
disabled=True,
|
||||
)
|
||||
|
||||
include_summarization_st = st.checkbox(
|
||||
"Include Summarization (only works for GPT-4)",
|
||||
value=rag_params.include_summarization,
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""Streamlit page showing builder config."""
|
||||
import streamlit as st
|
||||
from st_utils import add_sidebar, get_current_state
|
||||
from core.utils import get_image_and_text_nodes
|
||||
from llama_index.schema import MetadataMode
|
||||
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE
|
||||
from typing import Dict, Optional
|
||||
import pandas as pd
|
||||
|
||||
|
||||
####################
|
||||
@@ -28,8 +33,36 @@ if (
|
||||
]
|
||||
|
||||
|
||||
def add_to_message_history(role: str, content: str) -> None:
|
||||
message = {"role": role, "content": str(content)}
|
||||
def display_sources(response: AGENT_CHAT_RESPONSE_TYPE) -> None:
|
||||
image_nodes, text_nodes = get_image_and_text_nodes(response.source_nodes)
|
||||
if len(image_nodes) > 0 or len(text_nodes) > 0:
|
||||
with st.expander("Sources"):
|
||||
# get image nodes
|
||||
if len(image_nodes) > 0:
|
||||
st.subheader("Images")
|
||||
for image_node in image_nodes:
|
||||
st.image(image_node.metadata["file_path"])
|
||||
|
||||
if len(text_nodes) > 0:
|
||||
st.subheader("Text")
|
||||
sources_df_list = []
|
||||
for text_node in text_nodes:
|
||||
sources_df_list.append(
|
||||
{
|
||||
"ID": text_node.id_,
|
||||
"Text": text_node.node.get_content(
|
||||
metadata_mode=MetadataMode.ALL
|
||||
),
|
||||
}
|
||||
)
|
||||
sources_df = pd.DataFrame(sources_df_list)
|
||||
st.dataframe(sources_df)
|
||||
|
||||
|
||||
def add_to_message_history(
|
||||
role: str, content: str, extra: Optional[Dict] = None
|
||||
) -> None:
|
||||
message = {"role": role, "content": str(content), "extra": extra}
|
||||
st.session_state.agent_messages.append(message) # Add response to message history
|
||||
|
||||
|
||||
@@ -45,6 +78,11 @@ def display_messages() -> None:
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {msg_type}")
|
||||
|
||||
# display sources
|
||||
if "extra" in message and isinstance(message["extra"], dict):
|
||||
if "response" in message["extra"].keys():
|
||||
display_sources(message["extra"]["response"])
|
||||
|
||||
|
||||
# if agent is created, then we can chat with it
|
||||
if current_state.cache is not None and current_state.cache.agent is not None:
|
||||
@@ -68,6 +106,13 @@ if current_state.cache is not None and current_state.cache.agent is not None:
|
||||
with st.spinner("Thinking..."):
|
||||
response = agent.chat(str(prompt))
|
||||
st.write(str(response))
|
||||
add_to_message_history("assistant", str(response))
|
||||
|
||||
# display sources
|
||||
# Multi-modal: check if image nodes are present
|
||||
display_sources(response)
|
||||
|
||||
add_to_message_history(
|
||||
"assistant", str(response), extra={"response": response}
|
||||
)
|
||||
else:
|
||||
st.info("Agent not created. Please create an agent in the above section.")
|
||||
|
||||
+10
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "rags"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
description = "Build RAG with natural language."
|
||||
authors = ["Jerry Liu"]
|
||||
# New attributes
|
||||
@@ -22,6 +22,7 @@ llama-hub = "0.0.44"
|
||||
# NOTE: this is due to a trivial dependency in the web tool, will refactor
|
||||
langchain = "0.0.305"
|
||||
pypdf = "3.17.1"
|
||||
clip = { git = "https://github.com/openai/CLIP.git" }
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
# pytest = "7.2.1"
|
||||
@@ -65,3 +66,11 @@ exclude = [
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"base.py" = ["E402", "F811", "E501"]
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
multimodal = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"clip",
|
||||
]
|
||||
+49
-4
@@ -1,9 +1,9 @@
|
||||
"""Streamlit utils."""
|
||||
from core.agent_builder import (
|
||||
from core.agent_builder.loader import (
|
||||
load_meta_agent_and_tools,
|
||||
AgentCacheRegistry,
|
||||
RAGAgentBuilder,
|
||||
)
|
||||
from core.agent_builder.base import BaseRAGAgentBuilder
|
||||
from core.param_cache import ParamCache
|
||||
from core.constants import (
|
||||
AGENT_CACHE_DIR,
|
||||
@@ -38,6 +38,47 @@ def update_selected_agent() -> None:
|
||||
update_selected_agent_with_id(selected_id)
|
||||
|
||||
|
||||
def get_cached_is_multimodal() -> bool:
|
||||
"""Get default multimodal st."""
|
||||
if (
|
||||
"selected_cache" not in st.session_state.keys()
|
||||
or st.session_state.selected_cache is None
|
||||
):
|
||||
default_val = False
|
||||
else:
|
||||
selected_cache = cast(ParamCache, st.session_state.selected_cache)
|
||||
default_val = True if selected_cache.builder_type == "multimodal" else False
|
||||
return default_val
|
||||
|
||||
|
||||
def get_is_multimodal() -> bool:
|
||||
"""Get is multimodal."""
|
||||
if "is_multimodal_st" not in st.session_state.keys():
|
||||
st.session_state.is_multimodal_st = False
|
||||
return st.session_state.is_multimodal_st
|
||||
|
||||
|
||||
def add_builder_config() -> None:
|
||||
"""Add builder config."""
|
||||
with st.expander("Builder Config (Advanced)"):
|
||||
# add a few options - openai api key, and
|
||||
if (
|
||||
"selected_cache" not in st.session_state.keys()
|
||||
or st.session_state.selected_cache is None
|
||||
):
|
||||
is_locked = False
|
||||
else:
|
||||
is_locked = True
|
||||
|
||||
st.checkbox(
|
||||
"Enable multimodal search (beta)",
|
||||
key="is_multimodal_st",
|
||||
on_change=update_selected_agent,
|
||||
value=get_cached_is_multimodal(),
|
||||
disabled=is_locked,
|
||||
)
|
||||
|
||||
|
||||
def add_sidebar() -> None:
|
||||
"""Add sidebar."""
|
||||
with st.sidebar:
|
||||
@@ -70,7 +111,7 @@ class CurrentSessionState(BaseModel):
|
||||
agent_registry: AgentCacheRegistry
|
||||
selected_id: Optional[str]
|
||||
selected_cache: Optional[ParamCache]
|
||||
agent_builder: RAGAgentBuilder
|
||||
agent_builder: BaseRAGAgentBuilder
|
||||
cache: ParamCache
|
||||
builder_agent: BaseAgent
|
||||
|
||||
@@ -126,11 +167,15 @@ def get_current_state() -> CurrentSessionState:
|
||||
builder_agent, agent_builder = load_meta_agent_and_tools(
|
||||
cache=st.session_state.selected_cache,
|
||||
agent_registry=st.session_state.agent_registry,
|
||||
# NOTE: we will probably generalize this later into different
|
||||
# builder configs
|
||||
is_multimodal=get_cached_is_multimodal(),
|
||||
)
|
||||
else:
|
||||
# create builder agent / tools from new cache
|
||||
builder_agent, agent_builder = load_meta_agent_and_tools(
|
||||
agent_registry=st.session_state.agent_registry
|
||||
agent_registry=st.session_state.agent_registry,
|
||||
is_multimodal=get_is_multimodal(),
|
||||
)
|
||||
|
||||
st.session_state.builder_agent = builder_agent
|
||||
|
||||
Reference in New Issue
Block a user