From 634e1cecf23d7ca3a4c5e708944673e057765b2a Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Thu, 8 Feb 2024 10:56:44 -0800 Subject: [PATCH] Update SQL agents to use read-only connections (#51) --- streamlit_agent/chat_with_sql_db.py | 18 ++++++++++++++++-- streamlit_agent/mrkl_demo.py | 9 ++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/streamlit_agent/chat_with_sql_db.py b/streamlit_agent/chat_with_sql_db.py index 8cf930c..c5fc3d1 100644 --- a/streamlit_agent/chat_with_sql_db.py +++ b/streamlit_agent/chat_with_sql_db.py @@ -6,20 +6,28 @@ from langchain.sql_database import SQLDatabase from langchain.agents.agent_types import AgentType from langchain.callbacks import StreamlitCallbackHandler from langchain.agents.agent_toolkits import SQLDatabaseToolkit +from sqlalchemy import create_engine +import sqlite3 st.set_page_config(page_title="LangChain: Chat with SQL DB", page_icon="🦜") st.title("🦜 LangChain: Chat with SQL DB") +INJECTION_WARNING = """ + SQL agent can be vulnerable to prompt injection. Use a DB role with limited permissions. + Read more [here](https://python.langchain.com/docs/security). + """ +LOCALDB = "USE_LOCALDB" + # User inputs radio_opt = ["Use sample database - Chinook.db", "Connect to your SQL database"] selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt) if radio_opt.index(selected_opt) == 1: + st.sidebar.warning(INJECTION_WARNING, icon="⚠️") db_uri = st.sidebar.text_input( label="Database URI", placeholder="mysql://user:pass@hostname:port/db" ) else: - db_filepath = (Path(__file__).parent / "Chinook.db").absolute() - db_uri = f"sqlite:////{db_filepath}" + db_uri = LOCALDB openai_api_key = st.sidebar.text_input( label="OpenAI API Key", @@ -41,6 +49,12 @@ llm = OpenAI(openai_api_key=openai_api_key, temperature=0, streaming=True) @st.cache_resource(ttl="2h") def configure_db(db_uri): + if db_uri == LOCALDB: + # Make the DB connection read-only to reduce risk of injection attacks + # See: https://python.langchain.com/docs/security + db_filepath = (Path(__file__).parent / "Chinook.db").absolute() + creator = lambda: sqlite3.connect(f"file:{db_filepath}?mode=ro", uri=True) + return SQLDatabase(create_engine("sqlite:///", creator=creator)) return SQLDatabase.from_uri(database_uri=db_uri) diff --git a/streamlit_agent/mrkl_demo.py b/streamlit_agent/mrkl_demo.py index 3aa93e9..86fece9 100644 --- a/streamlit_agent/mrkl_demo.py +++ b/streamlit_agent/mrkl_demo.py @@ -10,6 +10,8 @@ from langchain_community.utilities import DuckDuckGoSearchAPIWrapper, SQLDatabas from langchain_core.runnables import RunnableConfig from langchain_experimental.sql import SQLDatabaseChain from langchain_openai import OpenAI +from sqlalchemy import create_engine +import sqlite3 from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks from streamlit_agent.clear_results import with_clear_container @@ -45,7 +47,12 @@ else: llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True) search = DuckDuckGoSearchAPIWrapper() llm_math_chain = LLMMathChain.from_llm(llm) -db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}") + +# Make the DB connection read-only to reduce risk of injection attacks +# See: https://python.langchain.com/docs/security +creator = lambda: sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True) +db = SQLDatabase(create_engine("sqlite:///", creator=creator)) + db_chain = SQLDatabaseChain.from_llm(llm, db) tools = [ Tool(