Merge pull request #4 from sfc-gh-jcarroll/session-state-index

This commit is contained in:
Lance Martin
2023-07-28 10:07:07 -07:00
committed by GitHub
+13 -11
View File
@@ -3,7 +3,8 @@ from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
@st.cache_resource
st.set_page_config(page_title="Interweb Explorer", page_icon="🌐")
def settings():
# Vectorstore
@@ -26,10 +27,10 @@ def settings():
# Initialize
web_retriever = WebResearchRetriever.from_llm(
vectorstore=vectorstore_public,
llm=llm,
search=search,
num_search_results=3
vectorstore=vectorstore_public,
llm=llm,
search=search,
num_search_results=3
)
return web_retriever, llm
@@ -62,10 +63,13 @@ class PrintRetrievalHandler(BaseCallbackHandler):
st.sidebar.image("img/ai.png")
st.header("`Interweb Explorer`")
st.info("`I am an AI that can answer questions by exploring, reading, and summarizing web pages."
"I can be configured to use different moddes: public API or private (no data sharing).`")
"I can be configured to use different modes: public API or private (no data sharing).`")
# Make retriever and llm
web_retriever, llm = settings()
if 'retriever' not in st.session_state:
st.session_state['retriever'], st.session_state['llm'] = settings()
web_retriever = st.session_state.retriever
llm = st.session_state.llm
# User input
question = st.text_input("`Ask a question:`")
@@ -76,9 +80,8 @@ if question:
import logging
logging.basicConfig()
logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,
retriever=web_retriever)
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=web_retriever)
# Write answer and sources
retrieval_streamer_cb = PrintRetrievalHandler(st.container())
answer = st.empty()
@@ -86,4 +89,3 @@ if question:
result = qa_chain({"question": question},callbacks=[retrieval_streamer_cb, stream_handler])
answer.info('`Answer:`\n\n' + result['answer'])
st.info('`Sources:`\n\n' + result['sources'])