Files
streamlit-agent/streamlit_agent/basic_memory.py
T
2024-01-15 22:51:33 -08:00

67 lines
2.5 KiB
Python

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_openai import OpenAI
import streamlit as st
# Page chrome: browser-tab metadata first, then the visible heading.
st.set_page_config(
    page_title="StreamlitChatMessageHistory",
    page_icon="📖",
)
st.title("📖 StreamlitChatMessageHistory")

# NOTE(review): this bare string expression is presumably rendered on the page by
# Streamlit's "magic" handling of standalone expressions - keep it as an
# expression statement (do not assign it or convert it to a comment).
"""
A basic example of using StreamlitChatMessageHistory to help LLMChain remember messages in a conversation.
The messages are stored in Session State across re-runs automatically. You can view the contents of Session State
in the expander below. View the
[source code for this app](https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/basic_memory.py).
"""
# Set up memory
# The chat history is created under the "langchain_messages" key and handed to
# ConversationBufferMemory as its chat_memory.
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)

# Seed an empty history with a greeting so the conversation does not start blank.
if not msgs.messages:
    msgs.add_ai_message("How can I help you?")

# The expander is created here (fixing its position on the page) and is only
# filled in at the end of the script, after any new messages have been added.
view_messages = st.expander("View the message contents in session state")
# Get an OpenAI API Key before continuing
# A key found in the app's secrets wins; otherwise the user is asked for one
# in a password field in the sidebar.
openai_api_key = (
    st.secrets.openai_api_key
    if "openai_api_key" in st.secrets
    else st.sidebar.text_input("OpenAI API Key", type="password")
)

# Without a key there is nothing useful to do: show a hint and stop this run.
if not openai_api_key:
    st.info("Enter an OpenAI API Key to continue")
    st.stop()
# Set up the LLMChain, passing in memory
# The two placeholders in the template ({history}, {human_input}) are the same
# names declared in `input_variables` below.
template = """You are an AI chatbot having a conversation with a human.
{history}
Human: {human_input}
AI: """
prompt = PromptTemplate(
    input_variables=["history", "human_input"],
    template=template,
)
llm_chain = LLMChain(
    llm=OpenAI(openai_api_key=openai_api_key),
    prompt=prompt,
    memory=memory,
)
# Render current messages from StreamlitChatMessageHistory
for msg in msgs.messages:
    st.chat_message(msg.type).write(msg.content)

# If user inputs a new prompt, generate and draw a new response.
# The input is bound to `user_input` rather than `prompt` so that the
# module-level `prompt` (the PromptTemplate given to the chain) is not
# overwritten by a plain string.
if user_input := st.chat_input():
    st.chat_message("human").write(user_input)
    # Note: new messages are saved to history automatically by Langchain during run
    response = llm_chain.invoke(user_input)
    # The generated reply is read from the "text" key of the chain's result.
    st.chat_message("ai").write(response["text"])
# Draw the messages at the end, so newly generated ones show up immediately
with view_messages:
    # NOTE(review): the bare string below is presumably rendered inside the
    # expander by Streamlit's handling of standalone expressions. Its lines are
    # kept flush-left so the string's content stays exactly as written.
    """
Memory initialized with:
```python
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)
```
Contents of `st.session_state.langchain_messages`:
"""
    # Show the raw stored messages as JSON in the same expander.
    stored_messages = st.session_state.langchain_messages
    view_messages.json(stored_messages)