mrkl_demo source

This commit is contained in:
Tim Conkling
2023-06-26 12:31:39 -04:00
parent be2ffeb262
commit 262c9bd1f1
7 changed files with 352 additions and 0 deletions
Binary file not shown.
@@ -0,0 +1,159 @@
"""Callback Handler captures all callbacks in a session for future offline playback."""
from __future__ import annotations
import pickle
import time
from typing import Any, TypedDict
from langchain.callbacks.base import BaseCallbackHandler
# This is intentionally not an enum so that we avoid serializing a
# custom class with pickle.
class CallbackType:
    """String constants naming each callback that can be recorded and replayed.

    Each value is spelled exactly like the handler method it stands for
    (for example ``ON_LLM_START`` is ``"on_llm_start"``).
    """

    # LLM callbacks
    ON_LLM_START = "on_llm_start"
    ON_LLM_NEW_TOKEN = "on_llm_new_token"
    ON_LLM_END = "on_llm_end"
    ON_LLM_ERROR = "on_llm_error"
    # Tool callbacks
    ON_TOOL_START = "on_tool_start"
    ON_TOOL_END = "on_tool_end"
    ON_TOOL_ERROR = "on_tool_error"
    # Free-form text callback
    ON_TEXT = "on_text"
    # Chain callbacks
    ON_CHAIN_START = "on_chain_start"
    ON_CHAIN_END = "on_chain_end"
    ON_CHAIN_ERROR = "on_chain_error"
    # Agent callbacks
    ON_AGENT_ACTION = "on_agent_action"
    ON_AGENT_FINISH = "on_agent_finish"
# We use TypedDict, rather than NamedTuple, so that we avoid serializing a
# custom class with pickle. All of this class's members should be basic Python types.
class CallbackRecord(TypedDict):
    """One captured callback invocation, stored as a plain dict."""

    # Which callback was invoked; one of the CallbackType string constants.
    callback_type: str
    # Positional arguments the callback was invoked with.
    args: tuple[Any, ...]
    # Keyword arguments the callback was invoked with.
    kwargs: dict[str, Any]
    # Number of seconds between this record and the previous one
    # (0 for the first record of a capture).
    time_delta: float
def load_records_from_file(path: str) -> list[CallbackRecord]:
    """Load the list of CallbackRecords from a pickle file at the given path.

    SECURITY: ``pickle.load`` can execute arbitrary code while deserializing.
    Only call this on session files that you created yourself or otherwise
    trust; never on files received from an untrusted source.

    Args:
        path: Location of a pickle file holding a list of record dicts.

    Returns:
        The records stored in the file, in their stored order.

    Raises:
        RuntimeError: If the file's payload is not a list, or if any element
            is not a dict carrying the ``callback_type``, ``args``,
            ``kwargs`` and ``time_delta`` keys.
        OSError: If the file cannot be opened.
    """
    required_keys = ("callback_type", "args", "kwargs", "time_delta")

    with open(path, "rb") as file:
        # NOTE: unpickling is only safe for trusted input (see docstring).
        records = pickle.load(file)

    if not isinstance(records, list):
        raise RuntimeError(f"Bad CallbackRecord data in {path}")

    # Reject malformed records here, with a clear message, instead of letting
    # them fail later with a KeyError in the middle of a playback.
    for record in records:
        if not isinstance(record, dict) or any(
            key not in record for key in required_keys
        ):
            raise RuntimeError(f"Bad CallbackRecord data in {path}")

    return records
def playback_callbacks(
    handlers: list[BaseCallbackHandler],
    records_or_filename: list[CallbackRecord] | str,
    max_pause_time: float,
) -> str:
    """Replay recorded callbacks against ``handlers`` and return the agent's result.

    Args:
        handlers: Handlers that receive every replayed callback, in order.
        records_or_filename: Either the records themselves, or the path of a
            pickle file to load them from.
        max_pause_time: Upper bound, in seconds, on the pause taken before
            each record; the pause is the record's ``time_delta`` capped at
            this value.

    Returns:
        The ``"output"`` entry found at ``args[0][0]`` of the first
        ``on_agent_finish`` record, or ``"[Missing Agent Result]"`` when no
        such record exists.
    """
    # Callback types that get forwarded; each constant is also the name of
    # the handler method to invoke. Records of any other type are skipped.
    replayable = (
        CallbackType.ON_LLM_START,
        CallbackType.ON_LLM_NEW_TOKEN,
        CallbackType.ON_LLM_END,
        CallbackType.ON_LLM_ERROR,
        CallbackType.ON_TOOL_START,
        CallbackType.ON_TOOL_END,
        CallbackType.ON_TOOL_ERROR,
        CallbackType.ON_TEXT,
        CallbackType.ON_CHAIN_START,
        CallbackType.ON_CHAIN_END,
        CallbackType.ON_CHAIN_ERROR,
        CallbackType.ON_AGENT_ACTION,
        CallbackType.ON_AGENT_FINISH,
    )

    records = (
        records_or_filename
        if isinstance(records_or_filename, list)
        else load_records_from_file(records_or_filename)
    )

    for entry in records:
        # Reproduce the original pacing, but never wait longer than the cap.
        delay = min(entry["time_delta"], max_pause_time)
        if delay > 0:
            time.sleep(delay)

        for handler in handlers:
            method_name = entry["callback_type"]
            if method_name not in replayable:
                continue
            getattr(handler, method_name)(*entry["args"], **entry["kwargs"])

    # Return the agent's result
    finish_entry = next(
        (
            entry
            for entry in records
            if entry["callback_type"] == CallbackType.ON_AGENT_FINISH
        ),
        None,
    )
    if finish_entry is None:
        return "[Missing Agent Result]"
    return finish_entry["args"][0][0]["output"]
class CapturingCallbackHandler(BaseCallbackHandler):
    """Callback handler that records every callback it receives.

    Each callback is stored as a CallbackRecord holding the callback's type,
    its positional and keyword arguments, and the time elapsed since the
    previously recorded callback. The records can be written to a pickle
    file with ``dump_records_to_file``.
    """

    def __init__(self) -> None:
        self._records: list[CallbackRecord] = []
        # Monotonic-clock reading taken when the latest record was appended;
        # None until the first record arrives.
        self._last_time: float | None = None

    def dump_records_to_file(self, path: str) -> None:
        """Write the list of CallbackRecords to a pickle file at the given path."""
        with open(path, "wb") as file:
            pickle.dump(self._records, file)

    def _append_record(
        self, callback_type: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        """Append one record, stamping it with the time since the previous one.

        Args:
            callback_type: One of the CallbackType string constants.
            args: Positional arguments the callback received.
            kwargs: Keyword arguments the callback received.
        """
        # A monotonic clock is used because only the difference between two
        # readings matters, and wall-clock time can be adjusted between
        # callbacks, which would yield a wrong (even negative) delta.
        now = time.monotonic()
        time_delta = 0.0 if self._last_time is None else now - self._last_time
        self._last_time = now
        self._records.append(
            CallbackRecord(
                callback_type=callback_type,
                args=args,
                kwargs=kwargs,
                time_delta=time_delta,
            )
        )

    # Every callback below simply records its own type and arguments.

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_START, args, kwargs)

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_NEW_TOKEN, args, kwargs)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_END, args, kwargs)

    def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_ERROR, args, kwargs)

    def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_START, args, kwargs)

    def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_END, args, kwargs)

    def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_ERROR, args, kwargs)

    def on_text(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TEXT, args, kwargs)

    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_START, args, kwargs)

    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_END, args, kwargs)

    def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_ERROR, args, kwargs)

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        self._append_record(CallbackType.ON_AGENT_ACTION, args, kwargs)

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_AGENT_FINISH, args, kwargs)
+31
View File
@@ -0,0 +1,31 @@
import streamlit as st
# A hack to "clear" the previous result when submitting a new prompt. This avoids
# the "previous run's text is grayed-out but visible during rerun" Streamlit behavior.
class DirtyState:
    """String constants for the "dirty_state" entry kept in st.session_state."""

    # Default state, used when nothing has been stored yet.
    NOT_DIRTY = "NOT_DIRTY"
    # Set by with_clear_container() when it reports that a submission should
    # be handled on the current run.
    DIRTY = "DIRTY"
    # Set by with_clear_container() when a submit arrives while the state is
    # DIRTY, right before it requests a rerun.
    UNHANDLED_SUBMIT = "UNHANDLED_SUBMIT"
def get_dirty_state() -> str:
    """Return the "dirty_state" stored in st.session_state.

    Falls back to DirtyState.NOT_DIRTY when no value has been stored yet.
    """
    return st.session_state.get("dirty_state", DirtyState.NOT_DIRTY)
def set_dirty_state(state: str) -> None:
    """Store ``state`` under the "dirty_state" key of st.session_state.

    Args:
        state: The new state; callers in this file pass DirtyState constants.
    """
    st.session_state["dirty_state"] = state
def with_clear_container(submit_clicked: bool) -> bool:
    """Advance the dirty-state machine and report whether to handle a submit now.

    Args:
        submit_clicked: Truthy when the submit button was clicked on this run.

    Returns:
        True when a submission should be handled on this run (the state is
        DIRTY afterwards); False otherwise.
    """
    if get_dirty_state() == DirtyState.DIRTY:
        if submit_clicked:
            # A new submit arrived while the state is still DIRTY: remember
            # the submit, then request a rerun.
            set_dirty_state(DirtyState.UNHANDLED_SUBMIT)
            # NOTE(review): presumably st.experimental_rerun() stops this run
            # and restarts the script, so the code below only runs on the
            # next pass — confirm against the Streamlit docs.
            st.experimental_rerun()
        else:
            # No new submit on this run: go back to the resting state.
            set_dirty_state(DirtyState.NOT_DIRTY)

    # Handle either a submit clicked on this run, or one remembered from the
    # run that requested the rerun.
    if submit_clicked or get_dirty_state() == DirtyState.UNHANDLED_SUBMIT:
        set_dirty_state(DirtyState.DIRTY)
        return True

    return False
+162
View File
@@ -0,0 +1,162 @@
from pathlib import Path
import streamlit as st
st.set_page_config(page_title="MRKL", page_icon="🦜", layout="wide")

# The two bare string expressions below are page content, not docstrings.
# NOTE(review): presumably rendered as markdown by Streamlit "magic" — confirm.
"# 🦜🔗 MRKL"

"""
This Streamlit app showcases a LangChain agent that replicates the
[MRKL chain](https://arxiv.org/abs/2205.00445).
This uses the [example Chinook database](https://github.com/lerocha/chinook-database).
To set it up follow the instructions [here](https://database.guide/2-sample-databases-sqlite/),
placing the .db file in the same directory as this app.
"""

# Setup credentials in Streamlit
user_openai_api_key = st.sidebar.text_input(
    "OpenAI API Key",
    type="password",
    help="Set this to run your own custom questions.",
)
user_serpapi_api_key = st.sidebar.text_input(
    "SerpAPI API Key",
    type="password",
    help="Set this to run your own custom questions. Get yours at https://serpapi.com/manage-api-key.",
)

# Custom questions are only enabled when BOTH keys were entered; otherwise
# placeholder keys are used.
enable_custom = bool(user_openai_api_key and user_serpapi_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
serpapi_api_key = user_serpapi_api_key if enable_custom else "not_supplied"
# Builds the LLM, the three tools and the agent inside a "View the source code"
# expander.
# NOTE(review): st.echo() presumably displays the source of this block on the
# page verbatim, so any comment added inside the block becomes user-visible
# output — confirm before editing the body.
with st.expander("👉 View the source code"), st.echo():
    # LangChain imports
    from langchain import (
        LLMMathChain,
        OpenAI,
        SerpAPIWrapper,
        SQLDatabase,
        SQLDatabaseChain,
    )
    from langchain.agents import AgentType
    from langchain.agents import initialize_agent, Tool
    from langchain.callbacks import StreamlitCallbackHandler
    from callbacks.capturing_callback_handler import playback_callbacks
    # Tools setup
    DB_PATH = (Path(__file__).parent / "Chinook.db").absolute()
    llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
    search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
    llm_math_chain = LLMMathChain(llm=llm, verbose=True)
    db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
    db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
    tools = [
        Tool(
            name="Search",
            func=search.run,
            description="useful for when you need to answer questions about current events. You should ask targeted questions",
        ),
        Tool(
            name="Calculator",
            func=llm_math_chain.run,
            description="useful for when you need to answer questions about math",
        ),
        Tool(
            name="FooBar DB",
            func=db_chain.run,
            description="useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context",
        ),
    ]
    # Initialize agent
    mrkl = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
    # To run the agent, use `mrkl.run(mrkl_input)`
# More Streamlit here!
# Sidebar controls; their values are passed to StreamlitCallbackHandler when
# a question is run.
expand_new_thoughts = st.sidebar.checkbox(
    "Expand New Thoughts",
    value=True,
    help="True if LLM thoughts should be expanded by default",
)
collapse_completed_thoughts = st.sidebar.checkbox(
    "Collapse Completed Thoughts",
    value=True,
    help="True if LLM thoughts should be collapsed when they complete",
)
max_thought_containers = st.sidebar.number_input(
    "Max Thought Containers",
    value=4,
    min_value=1,
    help="Max number of completed thoughts to show. When exceeded, older thoughts will be moved into a 'History' expander.",
)
# Maps each sample question to the file name of its recorded run. The files
# are looked up in the "runs" directory next to this script, and a question
# found here is played back instead of being sent to LangChain.
SAVED_SESSIONS = {
    "Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?": "leo.pickle",
    "What is the full name of the artist who recently released an album called "
    "'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs "
    "are in the FooBar database?": "alanis.pickle",
}
# The custom-question widget is keyed by `shadow_key`; its value is mirrored
# into `key` on every run, and `shadow_key` is seeded from `key` whenever only
# `key` is present in the session state.
# NOTE(review): presumably this keeps the typed question alive across runs in
# which the text input is not rendered — confirm.
key = "input"
shadow_key = "_input"
if key in st.session_state and shadow_key not in st.session_state:
    st.session_state[shadow_key] = st.session_state[key]
with st.form(key="form"):
    if not enable_custom:
        # Bare string expression: page content, not a docstring.
        # NOTE(review): presumably rendered by Streamlit "magic" — confirm.
        "Ask one of the sample questions, or enter your API Keys in the sidebar to ask your own custom questions."
    # `or ""` turns a falsy selectbox result into an empty string.
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""
    mrkl_input = ""
    if enable_custom:
        mrkl_input = st.text_input("Or, ask your own question", key=shadow_key)
        st.session_state[key] = mrkl_input
    # Without a custom question, fall back to the selected sample question.
    if not mrkl_input:
        mrkl_input = prefilled
    submit_clicked = st.form_submit_button("Submit Question")
# Placeholders that are filled in when a question is run.
question_container = st.empty()
results_container = st.empty()
# A hack to "clear" the previous result when submitting a new prompt.
from clear_results import with_clear_container

if with_clear_container(submit_clicked):
    # Everything for this run is rendered into a fresh container, which is
    # also the parent container of the StreamlitCallbackHandler.
    res = results_container.container()
    streamlit_handler = StreamlitCallbackHandler(
        parent_container=res,
        max_thought_containers=int(max_thought_containers),
        expand_new_thoughts=expand_new_thoughts,
        collapse_completed_thoughts=collapse_completed_thoughts,
    )

    question_container.write(f"**Question:** {mrkl_input}")

    if mrkl_input not in SAVED_SESSIONS:
        # Not a saved question: actually run the agent.
        answer = mrkl.run(mrkl_input, callbacks=[streamlit_handler])
    else:
        # If we've saved this question, play it back instead of actually
        # running LangChain (so that we don't exhaust our API calls
        # unnecessarily).
        session_path = Path(__file__).parent / "runs" / SAVED_SESSIONS[mrkl_input]
        print(f"Playing saved session: {session_path}")
        answer = playback_callbacks(
            [streamlit_handler], str(session_path), max_pause_time=3
        )

    res.write(f"**Answer:** {answer}")
Binary file not shown.
Binary file not shown.