mirror of
https://github.com/langchain-ai/streamlit-agent.git
synced 2026-07-01 09:25:05 -04:00
mrkl_demo source
This commit is contained in:
Binary file not shown.
@@ -0,0 +1,159 @@
|
||||
"""Callback Handler captures all callbacks in a session for future offline playback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import time
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
||||
# This is intentionally not an enum so that we avoid serializing a
|
||||
# custom class with pickle.
|
||||
class CallbackType:
|
||||
ON_LLM_START = "on_llm_start"
|
||||
ON_LLM_NEW_TOKEN = "on_llm_new_token"
|
||||
ON_LLM_END = "on_llm_end"
|
||||
ON_LLM_ERROR = "on_llm_error"
|
||||
ON_TOOL_START = "on_tool_start"
|
||||
ON_TOOL_END = "on_tool_end"
|
||||
ON_TOOL_ERROR = "on_tool_error"
|
||||
ON_TEXT = "on_text"
|
||||
ON_CHAIN_START = "on_chain_start"
|
||||
ON_CHAIN_END = "on_chain_end"
|
||||
ON_CHAIN_ERROR = "on_chain_error"
|
||||
ON_AGENT_ACTION = "on_agent_action"
|
||||
ON_AGENT_FINISH = "on_agent_finish"
|
||||
|
||||
|
||||
# We use TypedDict, rather than NamedTuple, so that we avoid serializing a
|
||||
# custom class with pickle. All of this class's members should be basic Python types.
|
||||
class CallbackRecord(TypedDict):
|
||||
callback_type: str
|
||||
args: tuple[Any, ...]
|
||||
kwargs: dict[str, Any]
|
||||
time_delta: float # Number of seconds between this record and the previous one
|
||||
|
||||
|
||||
def load_records_from_file(path: str) -> list[CallbackRecord]:
|
||||
"""Load the list of CallbackRecords from a pickle file at the given path."""
|
||||
with open(path, "rb") as file:
|
||||
records = pickle.load(file)
|
||||
|
||||
if not isinstance(records, list):
|
||||
raise RuntimeError(f"Bad CallbackRecord data in {path}")
|
||||
return records
|
||||
|
||||
|
||||
def playback_callbacks(
|
||||
handlers: list[BaseCallbackHandler],
|
||||
records_or_filename: list[CallbackRecord] | str,
|
||||
max_pause_time: float,
|
||||
) -> str:
|
||||
if isinstance(records_or_filename, list):
|
||||
records = records_or_filename
|
||||
else:
|
||||
records = load_records_from_file(records_or_filename)
|
||||
|
||||
for record in records:
|
||||
pause_time = min(record["time_delta"], max_pause_time)
|
||||
if pause_time > 0:
|
||||
time.sleep(pause_time)
|
||||
|
||||
for handler in handlers:
|
||||
if record["callback_type"] == CallbackType.ON_LLM_START:
|
||||
handler.on_llm_start(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_LLM_NEW_TOKEN:
|
||||
handler.on_llm_new_token(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_LLM_END:
|
||||
handler.on_llm_end(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_LLM_ERROR:
|
||||
handler.on_llm_error(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_TOOL_START:
|
||||
handler.on_tool_start(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_TOOL_END:
|
||||
handler.on_tool_end(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_TOOL_ERROR:
|
||||
handler.on_tool_error(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_TEXT:
|
||||
handler.on_text(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_CHAIN_START:
|
||||
handler.on_chain_start(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_CHAIN_END:
|
||||
handler.on_chain_end(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_CHAIN_ERROR:
|
||||
handler.on_chain_error(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_AGENT_ACTION:
|
||||
handler.on_agent_action(*record["args"], **record["kwargs"])
|
||||
elif record["callback_type"] == CallbackType.ON_AGENT_FINISH:
|
||||
handler.on_agent_finish(*record["args"], **record["kwargs"])
|
||||
|
||||
# Return the agent's result
|
||||
for record in records:
|
||||
if record["callback_type"] == CallbackType.ON_AGENT_FINISH:
|
||||
return record["args"][0][0]["output"]
|
||||
|
||||
return "[Missing Agent Result]"
|
||||
|
||||
|
||||
class CapturingCallbackHandler(BaseCallbackHandler):
|
||||
def __init__(self) -> None:
|
||||
self._records: list[CallbackRecord] = []
|
||||
self._last_time: float | None = None
|
||||
|
||||
def dump_records_to_file(self, path: str) -> None:
|
||||
"""Write the list of CallbackRecords to a pickle file at the given path."""
|
||||
with open(path, "wb") as file:
|
||||
pickle.dump(self._records, file)
|
||||
|
||||
def _append_record(
|
||||
self, type: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
time_now = time.time()
|
||||
time_delta = time_now - self._last_time if self._last_time is not None else 0
|
||||
self._last_time = time_now
|
||||
self._records.append(
|
||||
CallbackRecord(
|
||||
callback_type=type, args=args, kwargs=kwargs, time_delta=time_delta
|
||||
)
|
||||
)
|
||||
|
||||
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_LLM_START, args, kwargs)
|
||||
|
||||
def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_LLM_NEW_TOKEN, args, kwargs)
|
||||
|
||||
def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_LLM_END, args, kwargs)
|
||||
|
||||
def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_LLM_ERROR, args, kwargs)
|
||||
|
||||
def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_TOOL_START, args, kwargs)
|
||||
|
||||
def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_TOOL_END, args, kwargs)
|
||||
|
||||
def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_TOOL_ERROR, args, kwargs)
|
||||
|
||||
def on_text(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_TEXT, args, kwargs)
|
||||
|
||||
def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_CHAIN_START, args, kwargs)
|
||||
|
||||
def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_CHAIN_END, args, kwargs)
|
||||
|
||||
def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_CHAIN_ERROR, args, kwargs)
|
||||
|
||||
def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
|
||||
self._append_record(CallbackType.ON_AGENT_ACTION, args, kwargs)
|
||||
|
||||
def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._append_record(CallbackType.ON_AGENT_FINISH, args, kwargs)
|
||||
@@ -0,0 +1,31 @@
|
||||
import streamlit as st
|
||||
|
||||
# A hack to "clear" the previous result when submitting a new prompt. This avoids
|
||||
# the "previous run's text is grayed-out but visible during rerun" Streamlit behavior.
|
||||
class DirtyState:
|
||||
NOT_DIRTY = "NOT_DIRTY"
|
||||
DIRTY = "DIRTY"
|
||||
UNHANDLED_SUBMIT = "UNHANDLED_SUBMIT"
|
||||
|
||||
|
||||
def get_dirty_state() -> str:
|
||||
return st.session_state.get("dirty_state", DirtyState.NOT_DIRTY)
|
||||
|
||||
|
||||
def set_dirty_state(state: str) -> None:
|
||||
st.session_state["dirty_state"] = state
|
||||
|
||||
|
||||
def with_clear_container(submit_clicked):
|
||||
if get_dirty_state() == DirtyState.DIRTY:
|
||||
if submit_clicked:
|
||||
set_dirty_state(DirtyState.UNHANDLED_SUBMIT)
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
set_dirty_state(DirtyState.NOT_DIRTY)
|
||||
|
||||
if submit_clicked or get_dirty_state() == DirtyState.UNHANDLED_SUBMIT:
|
||||
set_dirty_state(DirtyState.DIRTY)
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,162 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
|
||||
st.set_page_config(page_title="MRKL", page_icon="🦜", layout="wide")
|
||||
|
||||
"# 🦜🔗 MRKL"
|
||||
|
||||
"""
|
||||
This Streamlit app showcases a LangChain agent that replicates the
|
||||
[MRKL chain](https://arxiv.org/abs/2205.00445).
|
||||
|
||||
This uses the [example Chinook database](https://github.com/lerocha/chinook-database).
|
||||
To set it up follow the instructions [here](https://database.guide/2-sample-databases-sqlite/),
|
||||
placing the .db file in the same directory as this app.
|
||||
|
||||
"""
|
||||
|
||||
# Setup credentials in Streamlit
|
||||
user_openai_api_key = st.sidebar.text_input(
|
||||
"OpenAI API Key", type="password", help="Set this to run your own custom questions."
|
||||
)
|
||||
user_serpapi_api_key = st.sidebar.text_input(
|
||||
"SerpAPI API Key",
|
||||
type="password",
|
||||
help="Set this to run your own custom questions. Get yours at https://serpapi.com/manage-api-key.",
|
||||
)
|
||||
|
||||
if user_openai_api_key and user_serpapi_api_key:
|
||||
openai_api_key = user_openai_api_key
|
||||
serpapi_api_key = user_serpapi_api_key
|
||||
enable_custom = True
|
||||
else:
|
||||
openai_api_key = "not_supplied"
|
||||
serpapi_api_key = "not_supplied"
|
||||
enable_custom = False
|
||||
|
||||
with st.expander("👉 View the source code"), st.echo():
|
||||
# LangChain imports
|
||||
from langchain import (
|
||||
LLMMathChain,
|
||||
OpenAI,
|
||||
SerpAPIWrapper,
|
||||
SQLDatabase,
|
||||
SQLDatabaseChain,
|
||||
)
|
||||
from langchain.agents import AgentType
|
||||
from langchain.agents import initialize_agent, Tool
|
||||
from langchain.callbacks import StreamlitCallbackHandler
|
||||
|
||||
from callbacks.capturing_callback_handler import playback_callbacks
|
||||
|
||||
# Tools setup
|
||||
DB_PATH = (Path(__file__).parent / "Chinook.db").absolute()
|
||||
|
||||
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
|
||||
search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
|
||||
llm_math_chain = LLMMathChain(llm=llm, verbose=True)
|
||||
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
|
||||
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=search.run,
|
||||
description="useful for when you need to answer questions about current events. You should ask targeted questions",
|
||||
),
|
||||
Tool(
|
||||
name="Calculator",
|
||||
func=llm_math_chain.run,
|
||||
description="useful for when you need to answer questions about math",
|
||||
),
|
||||
Tool(
|
||||
name="FooBar DB",
|
||||
func=db_chain.run,
|
||||
description="useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context",
|
||||
),
|
||||
]
|
||||
|
||||
# Initialize agent
|
||||
mrkl = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
# To run the agent, use `mrkl.run(mrkl_input)`
|
||||
|
||||
# More Streamlit here!
|
||||
|
||||
expand_new_thoughts = st.sidebar.checkbox(
|
||||
"Expand New Thoughts",
|
||||
value=True,
|
||||
help="True if LLM thoughts should be expanded by default",
|
||||
)
|
||||
|
||||
collapse_completed_thoughts = st.sidebar.checkbox(
|
||||
"Collapse Completed Thoughts",
|
||||
value=True,
|
||||
help="True if LLM thoughts should be collapsed when they complete",
|
||||
)
|
||||
|
||||
max_thought_containers = st.sidebar.number_input(
|
||||
"Max Thought Containers",
|
||||
value=4,
|
||||
min_value=1,
|
||||
help="Max number of completed thoughts to show. When exceeded, older thoughts will be moved into a 'History' expander.",
|
||||
)
|
||||
|
||||
SAVED_SESSIONS = {
|
||||
"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?": "leo.pickle",
|
||||
"What is the full name of the artist who recently released an album called "
|
||||
"'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs "
|
||||
"are in the FooBar database?": "alanis.pickle",
|
||||
}
|
||||
|
||||
key = "input"
|
||||
shadow_key = "_input"
|
||||
|
||||
if key in st.session_state and shadow_key not in st.session_state:
|
||||
st.session_state[shadow_key] = st.session_state[key]
|
||||
|
||||
with st.form(key="form"):
|
||||
if not enable_custom:
|
||||
"Ask one of the sample questions, or enter your API Keys in the sidebar to ask your own custom questions."
|
||||
prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""
|
||||
mrkl_input = ""
|
||||
|
||||
if enable_custom:
|
||||
mrkl_input = st.text_input("Or, ask your own question", key=shadow_key)
|
||||
st.session_state[key] = mrkl_input
|
||||
if not mrkl_input:
|
||||
mrkl_input = prefilled
|
||||
submit_clicked = st.form_submit_button("Submit Question")
|
||||
|
||||
question_container = st.empty()
|
||||
results_container = st.empty()
|
||||
|
||||
# A hack to "clear" the previous result when submitting a new prompt.
|
||||
from clear_results import with_clear_container
|
||||
|
||||
if with_clear_container(submit_clicked):
|
||||
# Create our StreamlitCallbackHandler
|
||||
res = results_container.container()
|
||||
streamlit_handler = StreamlitCallbackHandler(
|
||||
parent_container=res,
|
||||
max_thought_containers=int(max_thought_containers),
|
||||
expand_new_thoughts=expand_new_thoughts,
|
||||
collapse_completed_thoughts=collapse_completed_thoughts,
|
||||
)
|
||||
|
||||
question_container.write(f"**Question:** {mrkl_input}")
|
||||
|
||||
# If we've saved this question, play it back instead of actually running LangChain
|
||||
# (so that we don't exhaust our API calls unnecessarily)
|
||||
if mrkl_input in SAVED_SESSIONS:
|
||||
session_name = SAVED_SESSIONS[mrkl_input]
|
||||
session_path = Path(__file__).parent / "runs" / session_name
|
||||
print(f"Playing saved session: {session_path}")
|
||||
answer = playback_callbacks(
|
||||
[streamlit_handler], str(session_path), max_pause_time=3
|
||||
)
|
||||
res.write(f"**Answer:** {answer}")
|
||||
else:
|
||||
answer = mrkl.run(mrkl_input, callbacks=[streamlit_handler])
|
||||
res.write(f"**Answer:** {answer}")
|
||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user