diff --git a/streamlit_agent/clear_results.py b/streamlit_agent/clear_results.py new file mode 100644 index 0000000..a8c38e4 --- /dev/null +++ b/streamlit_agent/clear_results.py @@ -0,0 +1,32 @@ +import streamlit as st + + +# A hack to "clear" the previous result when submitting a new prompt. This avoids +# the "previous run's text is grayed-out but visible during rerun" Streamlit behavior. +class DirtyState: + NOT_DIRTY = "NOT_DIRTY" + DIRTY = "DIRTY" + UNHANDLED_SUBMIT = "UNHANDLED_SUBMIT" + + +def get_dirty_state() -> str: + return st.session_state.get("dirty_state", DirtyState.NOT_DIRTY) + + +def set_dirty_state(state: str) -> None: + st.session_state["dirty_state"] = state + + +def with_clear_container(submit_clicked: bool) -> bool: + if get_dirty_state() == DirtyState.DIRTY: + if submit_clicked: + set_dirty_state(DirtyState.UNHANDLED_SUBMIT) + st.experimental_rerun() + else: + set_dirty_state(DirtyState.NOT_DIRTY) + + if submit_clicked or get_dirty_state() == DirtyState.UNHANDLED_SUBMIT: + set_dirty_state(DirtyState.DIRTY) + return True + + return False diff --git a/streamlit_agent/mrkl_demo.py b/streamlit_agent/mrkl_demo.py index 525347d..8a0ddfa 100644 --- a/streamlit_agent/mrkl_demo.py +++ b/streamlit_agent/mrkl_demo.py @@ -14,6 +14,7 @@ from langchain.callbacks import StreamlitCallbackHandler from langchain.utilities import DuckDuckGoSearchAPIWrapper from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks +from streamlit_agent.clear_results import with_clear_container DB_PATH = (Path(__file__).parent / "Chinook.db").absolute() @@ -24,9 +25,7 @@ SAVED_SESSIONS = { "are in the FooBar database?": "alanis.pickle", } -st.set_page_config( - page_title="MRKL", page_icon="🦜", layout="wide", initial_sidebar_state="collapsed" -) +st.set_page_config(page_title="MRKL", page_icon="🦜", layout="wide", initial_sidebar_state="collapsed") "# 🦜🔗 MRKL" @@ -102,10 +101,7 @@ if st.session_state["dirty_state"] == "dirty": st.empty() st.experimental_rerun() -if ( - not st.session_state["latest_user_input_executed"] - and st.session_state["dirty_state"] == "initial" -): +if not st.session_state["latest_user_input_executed"] and st.session_state["dirty_state"] == "initial": if st.session_state["latest_user_input"]: st.chat_message("user").write(st.session_state["latest_user_input"]) @@ -122,9 +118,7 @@ if ( [st_callback], str(session_path), max_pause_time=3 ) else: - answer = mrkl.run( - st.session_state["latest_user_input"], callbacks=[st_callback] - ) + answer = mrkl.run(st.session_state["latest_user_input"], callbacks=[st_callback]) result_container.write(answer) st.session_state["dirty_state"] = "dirty"