This commit is contained in:
Eugene Yurtsev
2025-05-19 16:12:48 -04:00
parent 01526ef665
commit 1a9ea5ff9b
4 changed files with 16 additions and 29 deletions
+6 -1
View File
@@ -1,8 +1,11 @@
from datetime import date, datetime
from langchain_core.tools import tool
from langgraph_tutorials.customer_support.db import DB
@tool
def search_car_rentals(
location: str | None = None,
name: str | None = None,
@@ -43,6 +46,7 @@ def search_car_rentals(
]
@tool
def book_car_rental(rental_id: int) -> str:
"""Book a car rental by its ID.
@@ -59,6 +63,7 @@ def book_car_rental(rental_id: int) -> str:
return f"No car rental found with ID {rental_id}."
@tool
def update_car_rental(
rental_id: int,
start_date: datetime | date | None = None,
@@ -93,7 +98,7 @@ def update_car_rental(
return f"Car rental {rental_id} successfully updated."
return f"No car rental found with ID {rental_id}."
@tool
def cancel_car_rental(rental_id: int) -> str:
"""Cancel a car rental by its ID.
@@ -1,8 +1,11 @@
"""Information about excursions."""
from langchain_core.tools import tool
from langgraph_tutorials.customer_support.db import DB
@tool
def search_trip_recommendations(
location: str | None = None,
name: str | None = None,
@@ -43,6 +46,7 @@ def search_trip_recommendations(
]
@tool
def book_excursion(recommendation_id: int) -> str:
"""Book an excursion by its recommendation ID.
@@ -62,6 +66,7 @@ def book_excursion(recommendation_id: int) -> str:
return f"No trip recommendation found with ID {recommendation_id}."
@tool
def update_excursion(recommendation_id: int, details: str) -> str:
"""Update a trip recommendation's details by its ID.
@@ -82,6 +87,7 @@ def update_excursion(recommendation_id: int, details: str) -> str:
return f"No trip recommendation found with ID {recommendation_id}."
@tool
def cancel_excursion(recommendation_id: int) -> str:
"""Cancel a trip recommendation by its ID.
@@ -36,6 +36,7 @@ def fetch_user_flight_information(config: RunnableConfig) -> list[dict]:
return [dict(zip(column_names, row, strict=False)) for row in rows]
@tool
def search_flights(
departure_airport: str | None = None,
arrival_airport: str | None = None,
@@ -73,6 +74,7 @@ def search_flights(
return [dict(zip(column_names, row, strict=False)) for row in rows]
@tool
def update_ticket_to_new_flight(
ticket_no: str, new_flight_id: int, *, config: RunnableConfig
) -> str:
@@ -131,6 +133,7 @@ def update_ticket_to_new_flight(
return "Ticket successfully updated to new flight."
@tool
def cancel_ticket(ticket_no: str, *, config: RunnableConfig) -> str:
"""Cancel the user's ticket and remove it from the database."""
configuration = config.get("configurable", {})
@@ -158,3 +161,4 @@ def cancel_ticket(ticket_no: str, *, config: RunnableConfig) -> str:
cursor.execute("DELETE FROM ticket_flights WHERE ticket_no = ?", (ticket_no,))
return "Ticket successfully cancelled."
@@ -10,7 +10,6 @@ import numpy as np
import requests
from langchain.embeddings import init_embeddings
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool, tool
class PolicyRetriever:
@@ -89,30 +88,3 @@ class PolicyRetriever:
return [
{**self._docs[idx], "similarity": scores[idx]} for idx in top_k_idx_sorted
]
def as_tool(self, k: int = 2) -> BaseTool:
"""Return a LangChain tool for looking up policy information.
Args:
k (int, optional): Number of top documents to return in the tool.
Defaults to 2.
Returns:
Callable: A LangChain tool function.
"""
retriever = self # capture self in closure
@tool
def lookup_policy(query: str) -> str:
"""Consult company policies to check whether certain options are permitted.
Args:
query (str): The user query about policy information.
Returns:
str: Concatenated content of the most relevant policy documents.
"""
top_docs = retriever.query(query, k=k)
return "\n\n".join(doc["page_content"] for doc in top_docs)
return lookup_policy