This commit is contained in:
Eugene Yurtsev
2025-05-01 22:27:06 -04:00
parent 955ce81a96
commit c90f6f137f
3 changed files with 116 additions and 73 deletions
@@ -0,0 +1,3 @@
from langgraph_tutorials.customer_support.db import DB
__all__ = ["DB"]
+112 -72
View File
@@ -1,91 +1,131 @@
"""Database utilities for the customer support application.
This module handles database initialization, updates, and maintenance for the
customer support system. It downloads a SQLite database if needed and provides
functionality to update flight dates to current time.
Handles downloading, initializing, and managing the SQLite database,
with a 'dirty' state that updates flight dates to current time.
This implementation is used for the tutorial purposes only and is not
meant for production use.
"""
import os
import shutil
import sqlite3
from contextlib import contextmanager
from typing import Generator, Optional
import pandas as pd
import requests
# Database configuration
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
# The backup lets us restart for each tutorial section
backup_file = "travel2.backup.sqlite"
overwrite = False
# Download database if needed
if overwrite or not os.path.exists(local_file):
response = requests.get(db_url)
response.raise_for_status() # Ensure the request was successful
with open(local_file, "wb") as f:
f.write(response.content)
# Backup - we will use this to "reset" our DB in each section
shutil.copy(local_file, backup_file)
class DatabaseManager:
"""Manages the customer support database with original and dirty states."""
def update_dates(file):
"""Updates flight and booking dates to current time.
Creates a fresh copy of the database from backup and updates all datetime
fields to be relative to the current date, maintaining the same relative
time differences from the original data.
Args:
file (str): Path to the database file to update
Returns:
str: Path to the updated database file
Note:
This function modifies the following tables:
- flights: Updates scheduled/actual departure/arrival times
- bookings: Updates booking dates
"""
shutil.copy(backup_file, file)
conn = sqlite3.connect(file)
tables = pd.read_sql(
"SELECT name FROM sqlite_master WHERE type='table';", conn
).name.tolist()
tdf = {}
for t in tables:
tdf[t] = pd.read_sql(f"SELECT * from {t}", conn)
example_time = pd.to_datetime(
tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
).max()
current_time = pd.to_datetime("now").tz_localize(example_time.tz)
time_diff = current_time - example_time
tdf["bookings"]["book_date"] = (
pd.to_datetime(tdf["bookings"]["book_date"].replace("\\N", pd.NaT), utc=True)
+ time_diff
DEFAULT_DB_URL = (
"https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
)
ORIGINAL_FILE = "travel_original.sqlite"
DIRTY_FILE = "travel_dirty.sqlite"
datetime_columns = [
"scheduled_departure",
"scheduled_arrival",
"actual_departure",
"actual_arrival",
]
for column in datetime_columns:
tdf["flights"][column] = (
pd.to_datetime(tdf["flights"][column].replace("\\N", pd.NaT)) + time_diff
)
def __init__(self, db_url: str = DEFAULT_DB_URL) -> None:
self.db_url: str = db_url
self.original_file: str = self.ORIGINAL_FILE
self.dirty_file: str = self.DIRTY_FILE
for table_name, df in tdf.items():
df.to_sql(table_name, conn, if_exists="replace", index=False)
del df
del tdf
conn.commit()
conn.close()
def initialize(self, *, force_download: bool = False) -> str:
"""Ensure 'original' database is present and create 'dirty' copy with updated dates.
return file
Args:
force_download: Force re-download of the original database.
Returns:
Path to the dirty database.
"""
if force_download or not os.path.exists(self.original_file):
self._download_original()
self._reset_dirty_with_updated_dates()
return self.dirty_file
def _download_original(self) -> None:
"""Download the original database from the remote URL."""
print("Downloading original database...")
response: requests.Response = requests.get(self.db_url)
response.raise_for_status()
with open(self.original_file, "wb") as f:
f.write(response.content)
def _reset_dirty_with_updated_dates(self) -> None:
"""Create a fresh dirty database by copying and updating the original."""
shutil.copy(self.original_file, self.dirty_file)
self._update_dates(self.dirty_file)
def _update_dates(self, db_path: str) -> None:
"""Update flight and booking dates in the database to match current time."""
conn: sqlite3.Connection = sqlite3.connect(db_path)
tables: list[str] = pd.read_sql(
"SELECT name FROM sqlite_master WHERE type='table';", conn
).name.tolist()
dataframes: dict[str, pd.DataFrame] = {
t: pd.read_sql(f"SELECT * FROM {t}", conn) for t in tables
}
flights: Optional[pd.DataFrame] = dataframes.get("flights")
if flights is not None:
example_time = pd.to_datetime(
flights["actual_departure"].replace("\\N", pd.NaT)
).max()
current_time = pd.Timestamp.now(tz=example_time.tz)
time_diff = current_time - example_time
for col in [
"scheduled_departure",
"scheduled_arrival",
"actual_departure",
"actual_arrival",
]:
flights[col] = (
pd.to_datetime(flights[col].replace("\\N", pd.NaT)) + time_diff
)
bookings: Optional[pd.DataFrame] = dataframes.get("bookings")
if bookings is not None:
bookings["book_date"] = (
pd.to_datetime(bookings["book_date"].replace("\\N", pd.NaT), utc=True)
+ time_diff
)
for table_name, df in dataframes.items():
df.to_sql(table_name, conn, if_exists="replace", index=False)
conn.commit()
conn.close()
@contextmanager
def get_cursor(self) -> Generator[sqlite3.Cursor, None, None]:
"""Context manager to provide a SQLite cursor for the dirty database.
Raises:
FileNotFoundError: If the dirty database file is missing.
Yields:
sqlite3.Cursor: A database cursor for performing operations.
"""
# File check is done here to provide a helpful failure message to
# potential users. This is not meant to be production code, but
# rather a tutorial example.
if not os.path.exists(self.dirty_file):
raise FileNotFoundError(
f"The dirty database file '{self.dirty_file}' does not exist.\n"
f"Please run 'manager.initialize()' first to set up the database."
)
conn: sqlite3.Connection = sqlite3.connect(self.dirty_file)
cursor: sqlite3.Cursor = conn.cursor()
try:
yield cursor
conn.commit()
finally:
cursor.close()
conn.close()
db = update_dates(local_file)
# Global manager instance
DB: DatabaseManager = DatabaseManager()
+1 -1
View File
@@ -3,7 +3,7 @@ name = "langgraph-tutorials"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.10"
dependencies = [
"langchain-core>=0.3.56",
"langgraph>=0.4.1",