mirror of
https://github.com/langchain-ai/langgraph-tutorials.git
synced 2026-07-01 14:40:46 -04:00
x
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from langgraph_tutorials.customer_support.db import DB
|
||||
|
||||
__all__ = ["DB"]
|
||||
|
||||
@@ -1,91 +1,131 @@
|
||||
"""Database utilities for the customer support application.
|
||||
|
||||
This module handles database initialization, updates, and maintenance for the
|
||||
customer support system. It downloads a SQLite database if needed and provides
|
||||
functionality to update flight dates to current time.
|
||||
Handles downloading, initializing, and managing the SQLite database,
|
||||
with a 'dirty' state that updates flight dates to current time.
|
||||
|
||||
This implementation is used for tutorial purposes only and is not
|
||||
meant for production use.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
# Database configuration
|
||||
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
|
||||
local_file = "travel2.sqlite"
|
||||
# The backup lets us restart for each tutorial section
|
||||
backup_file = "travel2.backup.sqlite"
|
||||
overwrite = False
|
||||
|
||||
# Download database if needed
|
||||
if overwrite or not os.path.exists(local_file):
|
||||
response = requests.get(db_url)
|
||||
response.raise_for_status() # Ensure the request was successful
|
||||
with open(local_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
# Backup - we will use this to "reset" our DB in each section
|
||||
shutil.copy(local_file, backup_file)
|
||||
class DatabaseManager:
|
||||
"""Manages the customer support database with original and dirty states."""
|
||||
|
||||
|
||||
def update_dates(file):
|
||||
"""Updates flight and booking dates to current time.
|
||||
|
||||
Creates a fresh copy of the database from backup and updates all datetime
|
||||
fields to be relative to the current date, maintaining the same relative
|
||||
time differences from the original data.
|
||||
|
||||
Args:
|
||||
file (str): Path to the database file to update
|
||||
|
||||
Returns:
|
||||
str: Path to the updated database file
|
||||
|
||||
Note:
|
||||
This function modifies the following tables:
|
||||
- flights: Updates scheduled/actual departure/arrival times
|
||||
- bookings: Updates booking dates
|
||||
"""
|
||||
shutil.copy(backup_file, file)
|
||||
conn = sqlite3.connect(file)
|
||||
tables = pd.read_sql(
|
||||
"SELECT name FROM sqlite_master WHERE type='table';", conn
|
||||
).name.tolist()
|
||||
tdf = {}
|
||||
for t in tables:
|
||||
tdf[t] = pd.read_sql(f"SELECT * from {t}", conn)
|
||||
|
||||
example_time = pd.to_datetime(
|
||||
tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
|
||||
).max()
|
||||
current_time = pd.to_datetime("now").tz_localize(example_time.tz)
|
||||
time_diff = current_time - example_time
|
||||
|
||||
tdf["bookings"]["book_date"] = (
|
||||
pd.to_datetime(tdf["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True)
|
||||
+ time_diff
|
||||
DEFAULT_DB_URL = (
|
||||
"https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
|
||||
)
|
||||
ORIGINAL_FILE = "travel_original.sqlite"
|
||||
DIRTY_FILE = "travel_dirty.sqlite"
|
||||
|
||||
datetime_columns = [
|
||||
"scheduled_departure",
|
||||
"scheduled_arrival",
|
||||
"actual_departure",
|
||||
"actual_arrival",
|
||||
]
|
||||
for column in datetime_columns:
|
||||
tdf["flights"][column] = (
|
||||
pd.to_datetime(tdf["flights"][column].replace("\\N", pd.NaT)) + time_diff
|
||||
)
|
||||
def __init__(self, db_url: str = DEFAULT_DB_URL) -> None:
    """Record the download URL and the two database file locations.

    Args:
        db_url: URL stored on the instance as ``db_url``.
    """
    # Both file paths are taken from the class-level constants.
    self.original_file: str = self.ORIGINAL_FILE
    self.dirty_file: str = self.DIRTY_FILE
    self.db_url: str = db_url
|
||||
|
||||
for table_name, df in tdf.items():
|
||||
df.to_sql(table_name, conn, if_exists="replace", index=False)
|
||||
del df
|
||||
del tdf
|
||||
conn.commit()
|
||||
conn.close()
|
||||
def initialize(self, *, force_download: bool = False) -> str:
    """Prepare the databases and report where the working copy lives.

    The 'original' database is downloaded when it is missing on disk or when
    a re-download is explicitly requested. The 'dirty' copy is then rebuilt
    from it with updated dates.

    Args:
        force_download: Re-download the original database even if a local
            copy already exists.

    Returns:
        Path to the dirty database.
    """
    # Short-circuits exactly like the original: the disk is only consulted
    # when no forced download was requested.
    already_available = not force_download and os.path.exists(self.original_file)
    if not already_available:
        self._download_original()

    self._reset_dirty_with_updated_dates()
    return self.dirty_file
|
||||
|
||||
def _download_original(self) -> None:
    """Download the original database from the remote URL.

    The payload is first written to a sibling ``.part`` file and moved into
    place only once it has been written completely, so an interrupted write
    never leaves a truncated file at ``self.original_file``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    print("Downloading original database...")
    # Without a timeout, requests waits indefinitely on an unresponsive server.
    response: requests.Response = requests.get(self.db_url, timeout=60)
    response.raise_for_status()

    partial_path = f"{self.original_file}.part"
    try:
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, self.original_file)
    finally:
        # After a successful move the partial file is gone; this only fires
        # when the write or the move failed part-way.
        if os.path.exists(partial_path):
            os.remove(partial_path)
|
||||
|
||||
def _reset_dirty_with_updated_dates(self) -> None:
    """Rebuild the dirty database from the original and refresh its dates.

    The original file is copied over the dirty path, then the date update
    is applied to that copy.
    """
    working_copy = self.dirty_file
    shutil.copy(self.original_file, working_copy)
    self._update_dates(working_copy)
|
||||
|
||||
def _update_dates(self, db_path: str) -> None:
    """Update flight and booking dates in the database to match current time.

    Every table is loaded into memory, the date columns are shifted, and each
    table is written back with ``if_exists="replace"``. The shift is the gap
    between now and the latest ``actual_departure`` in ``flights``, so the
    relative spacing between dates is preserved.

    When there is no ``flights`` table, or it holds no usable
    ``actual_departure`` value, no reference time exists. In that case the
    date columns are left untouched instead of failing on an undefined offset
    or overwriting every date with NaT.

    Args:
        db_path: Path to the SQLite database file to modify in place.
    """
    conn: sqlite3.Connection = sqlite3.connect(db_path)
    # try/finally so the connection is released even if pandas raises.
    try:
        tables: list[str] = pd.read_sql(
            "SELECT name FROM sqlite_master WHERE type='table';", conn
        ).name.tolist()
        dataframes: dict[str, pd.DataFrame] = {
            t: pd.read_sql(f"SELECT * FROM {t}", conn) for t in tables
        }

        # Stays None until a reference time is found; the bookings branch
        # checks it rather than relying on the flights branch having run.
        time_diff: Optional[pd.Timedelta] = None

        flights: Optional[pd.DataFrame] = dataframes.get("flights")
        if flights is not None:
            # "\N" is replaced with NaT before parsing so it is treated as a
            # missing timestamp.
            example_time = pd.to_datetime(
                flights["actual_departure"].replace("\\N", pd.NaT)
            ).max()
            # max() is NaT for an empty table or one with no departures;
            # shifting by a NaT offset would erase every date.
            if not pd.isna(example_time):
                current_time = pd.Timestamp.now(tz=example_time.tz)
                time_diff = current_time - example_time

                for col in (
                    "scheduled_departure",
                    "scheduled_arrival",
                    "actual_departure",
                    "actual_arrival",
                ):
                    flights[col] = (
                        pd.to_datetime(flights[col].replace("\\N", pd.NaT))
                        + time_diff
                    )

        bookings: Optional[pd.DataFrame] = dataframes.get("bookings")
        if bookings is not None and time_diff is not None:
            bookings["book_date"] = (
                pd.to_datetime(bookings["book_date"].replace("\\N", pd.NaT), utc=True)
                + time_diff
            )

        for table_name, df in dataframes.items():
            df.to_sql(table_name, conn, if_exists="replace", index=False)

        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
@contextmanager
def get_cursor(self) -> Generator[sqlite3.Cursor, None, None]:
    """Yield a cursor on the dirty database, committing on a clean exit.

    The commit happens only when the ``with`` body finishes without raising.
    The cursor and the connection are closed in every case.

    Raises:
        FileNotFoundError: If the dirty database file is missing.

    Yields:
        sqlite3.Cursor: A database cursor for performing operations.
    """
    # The explicit existence check exists purely to give tutorial users an
    # actionable message; this is not meant to be production code.
    dirty_missing = not os.path.exists(self.dirty_file)
    if dirty_missing:
        message = (
            f"The dirty database file '{self.dirty_file}' does not exist.\n"
            f"Please run 'manager.initialize()' first to set up the database."
        )
        raise FileNotFoundError(message)

    connection: sqlite3.Connection = sqlite3.connect(self.dirty_file)
    db_cursor: sqlite3.Cursor = connection.cursor()
    try:
        yield db_cursor
        # Reached only when the caller's block did not raise.
        connection.commit()
    finally:
        db_cursor.close()
        connection.close()
|
||||
|
||||
|
||||
db = update_dates(local_file)
|
||||
# Global manager instance
|
||||
DB: DatabaseManager = DatabaseManager()
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ name = "langgraph-tutorials"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"langchain-core>=0.3.56",
|
||||
"langgraph>=0.4.1",
|
||||
|
||||
Reference in New Issue
Block a user