diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..827c8a7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +workflows.db +.venv +.env +package-lock.json +node_modules diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f805e56 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "human-in-the-loop" +version = "0.1.0" +description = "A workflow showcasing how to use human in the loop" +requires-python = ">=3.10" +readme = "README.md" +dependencies = [ + "llama-index-workflows>=2.5.0,<3.0.0", + "llama-index-llms-openai" +] + +[dependency-groups] +dev = [ + "hatch>=1.14.2", + "pytest>=8.4.2", + "ruff>=0.13.2", + "ty>=0.0.1a21", +] + +[tool.hatch.envs.default.scripts] +format = "ruff format ." +format-check = "ruff format --check ." +lint = "ruff check --fix ." +lint-check = ["ruff check ."] +typecheck = "ty check src" +test = "pytest" +all-check = ["format-check", "lint-check", "test"] +all-fix = ["format", "lint", "test"] + +[tool.llamadeploy] +env_files = [".env"] + +[tool.llamadeploy.workflows] +default = "human_in_the_loop.workflow:workflow" diff --git a/src/human_in_the_loop/__init__.py b/src/human_in_the_loop/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/human_in_the_loop/workflow.py b/src/human_in_the_loop/workflow.py new file mode 100644 index 0000000..0d5c0c5 --- /dev/null +++ b/src/human_in_the_loop/workflow.py @@ -0,0 +1,171 @@ +import random + +from pydantic import BaseModel, Field +from workflows import Workflow, step, Context +from workflows.resource import Resource +from typing import Annotated +from workflows.events import ( + StartEvent, + StopEvent, + InputRequiredEvent, + HumanResponseEvent, +) + +from llama_index.llms.openai import OpenAIResponses +from llama_index.core.llms import ChatMessage +from llama_index.core.llms.structured_llm import StructuredLLM + + +# replace with an actual flight searcher +class FlightsAPI: + def __init__(self) -> None: + self.allowed_departure = [ + "San Francisco", + "San Jose", + "Los Angeles", + "New York", + ] + self.allowed_arrival = ["Paris", "London", "Berlin", "Rome"] + self.allowed_hours = ["7.00 AM", "12.00 AM", "5.00 PM", "10.00 PM"] + + def search_flights( + self, departure: str, arrival: str, date: str + ) -> str | list[str]: + if arrival not in self.allowed_arrival: + return "Sorry, we do not have planes that go to " + arrival + if departure not in self.allowed_departure: + return "Sorry, we do not have planes departing from " + departure + allowed_hours = self.allowed_hours[self.allowed_departure.index(departure) :] + flights = [] + for hour in allowed_hours: + flights.append( + f"Flight from {departure} to {arrival} at {hour} on {date} for {random.randint(200, 400)}$" + ) + return flights + + def book_flight(self, flight: str) -> str: + n = random.randint(0, 1) + if n == 0: + return f"Successfully booked: {flight}" + return "Sorry, something went wrong while booking your flight" + + +class FlightSearchEvent(InputRequiredEvent): + candidate_flights: list[str] + + +class FlightChoiceEvent(HumanResponseEvent): + chosen_flight: str + continue_booking: bool + + +async def get_flights_api(*args, **kwargs) -> FlightsAPI: + return FlightsAPI() + + +class FlightSearchDetails(BaseModel): + departure_location: str = Field(description="Departure location") + arrival_location: str = Field(description="Arrival location") + date: str = Field(description="Flight date") + + +async def get_llm(*args, **kwargs) -> StructuredLLM: + return OpenAIResponses("gpt-4.1").as_structured_llm(FlightSearchDetails) + + +class FlightSearchWorkflow(Workflow): + @step + async def search_for_flight( + self, + ev: StartEvent, + ctx: Context, + llm: Annotated[StructuredLLM, Resource(get_llm)], + flight_api: Annotated[FlightsAPI, Resource(get_flights_api)], + ) -> StopEvent | FlightSearchEvent: + response = await llm.achat( + [ + ChatMessage( + content=f"Extract flight details from this request: {ev.message}" + ) + ] + ) + if response.message.content: + flight_details = FlightSearchDetails.model_validate_json( + response.message.content + ) + else: + return StopEvent(result="Unable to get details for your flight") + flights = flight_api.search_flights( + departure=flight_details.departure_location, + arrival=flight_details.arrival_location, + date=flight_details.date, + ) + if isinstance(flights, str): + return StopEvent(result=flights) + else: + return FlightSearchEvent(candidate_flights=flights) + + @step + async def chosen_flight( + self, + ev: FlightChoiceEvent, + flight_api: Annotated[FlightsAPI, Resource(get_flights_api)], + ctx: Context, + ) -> StopEvent: + if ev.continue_booking: + booking = flight_api.book_flight(ev.chosen_flight) + return StopEvent(result=booking) + else: + return StopEvent(result="No permission to book, exiting...") + + +async def main(message: str) -> None: + w = FlightSearchWorkflow(timeout=100, verbose=False) + handler = w.run(message=message) + async for ev in handler.stream_events(): + if isinstance(ev, FlightSearchEvent): + print("Flights:\n" + "\n- ".join(ev.candidate_flights) + "\n\n") + are_ok = input("Are the flights ok for you? [yes/no] ") + if are_ok.lower().strip() != "yes": + handler.ctx.send_event( + FlightChoiceEvent(chosen_flight="", continue_booking=False) + ) # type: ignore + break + res = input("Choose a flight: ") + while res not in ev.candidate_flights: + res = input( + "Sorry, that flight is not available, can you choose one flight from the above, please? Your choice: " + ) + appr = input(f"Do you wish to continue with booking for {res}? [yes/no] ") + if appr.lower().strip() == "yes": + handler.ctx.send_event( + FlightChoiceEvent(chosen_flight=res, continue_booking=True) + ) # type: ignore + else: + handler.ctx.send_event( + FlightChoiceEvent(chosen_flight=res, continue_booking=False) + ) # type: ignore + result = await handler + print(str(result)) + + +workflow = FlightSearchWorkflow(timeout=None) + + +if __name__ == "__main__": + import asyncio + import os + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument( + "-m", "--message", required=True, help="Flight you would like to take" + ) + args = parser.parse_args() + + if not os.getenv("OPENAI_API_KEY", None): + raise ValueError( + "You need to set OPENAI_API_KEY in your environment before using this workflow" + ) + + asyncio.run(main(message=args.message)) diff --git a/tests/test_placeholder.py b/tests/test_placeholder.py new file mode 100644 index 0000000..3384ea0 --- /dev/null +++ b/tests/test_placeholder.py @@ -0,0 +1,12 @@ +"""Placeholder test file. + +Replace this with actual tests for your project. +""" + + +def test_placeholder() -> None: + """Placeholder test that always passes. + + Remove this test once you add real tests to your project. + """ + assert True