Files
llama-api/poe-api/poe_api/base_handler.py
T
Simon Suo d0ae5b75ba wip
2023-04-17 13:10:26 -07:00

119 lines
4.0 KiB
Python

import json
import logging
from typing import AsyncIterable, Dict, Optional, Union
from fastapi.responses import JSONResponse
from sse_starlette.sse import ServerSentEvent
from poe_api.types import (AddDocumentsRequest, ContentType, QueryRequest,
ReportErrorRequest, ReportFeedbackRequest,
SettingsRequest, SettingsResponse)
# Module-level logger. Its dotted name places it under the "uvicorn" logger
# hierarchy, so records propagate to any handlers configured on ancestors.
logger = logging.getLogger(name="uvicorn.default")
class PoeHandler:
    """Base class for a Poe server bot.

    Subclasses override the hook methods in the first section
    (``get_response``, ``get_settings``, ``on_feedback``, ``on_error``,
    ``add_documents``, ``shutdown``). The static helpers build the
    server-sent events a response is made of. The ``handle_*`` methods are
    the internal entry points that wrap the hooks: they return
    ``JSONResponse`` objects for request/response calls and an async stream
    of ``ServerSentEvent`` objects for queries.
    """

    # Override these for your bot

    async def get_response(self, query: QueryRequest) -> AsyncIterable[ServerSentEvent]:
        """Override this to return a response to user queries.

        Yield server-sent events (e.g. built with ``text_event``). The
        terminating ``done`` event is appended by ``handle_query``, so
        implementations need not yield it themselves.
        """
        yield self.text_event("hello")

    async def get_settings(self, setting: SettingsRequest) -> SettingsResponse:
        """Override this to return non-standard settings.

        The default returns a ``SettingsResponse`` built with no arguments.
        """
        return SettingsResponse()

    async def on_feedback(self, feedback_request: ReportFeedbackRequest) -> None:
        """Override this to record feedback from the user.

        The default does nothing.
        """

    async def on_error(self, error_request: ReportErrorRequest) -> None:
        """Override this to record errors from the Poe server.

        The default logs the report at ERROR level.
        """
        # Lazy %-style args: the message is only formatted if it is emitted.
        logger.error("Error from Poe server: %s", error_request)

    async def add_documents(self, add_documents_request: AddDocumentsRequest) -> None:
        """Override this to ingest the documents in an add-documents request.

        The default does nothing.
        """

    def shutdown(self) -> None:
        """Override this to run cleanup; it is invoked by ``handle_shutdown``.

        The default does nothing.
        """

    # Helpers for generating responses

    @staticmethod
    def text_event(text: str) -> ServerSentEvent:
        """Return a ``text`` event whose JSON payload is ``{"text": text}``."""
        return ServerSentEvent(data=json.dumps({"text": text}), event="text")

    @staticmethod
    def replace_response_event(text: str) -> ServerSentEvent:
        """Return a ``replace_response`` event with payload ``{"text": text}``.

        Per the event name, presumably tells the client to replace the text
        streamed so far -- confirm against the Poe protocol spec.
        """
        return ServerSentEvent(
            data=json.dumps({"text": text}), event="replace_response"
        )

    @staticmethod
    def done_event() -> ServerSentEvent:
        """Return the ``done`` event, whose payload is an empty JSON object."""
        return ServerSentEvent(data="{}", event="done")

    @staticmethod
    def suggested_reply_event(text: str) -> ServerSentEvent:
        """Return a ``suggested_reply`` event with payload ``{"text": text}``."""
        return ServerSentEvent(data=json.dumps({"text": text}), event="suggested_reply")

    @staticmethod
    def meta_event(
        *,
        content_type: ContentType = "text/markdown",
        refetch_settings: bool = False,
        linkify: bool = True,
        suggested_replies: bool = True,
    ) -> ServerSentEvent:
        """Return a ``meta`` event carrying the given response options as JSON."""
        return ServerSentEvent(
            data=json.dumps(
                {
                    "content_type": content_type,
                    "refetch_settings": refetch_settings,
                    "linkify": linkify,
                    "suggested_replies": suggested_replies,
                }
            ),
            event="meta",
        )

    @staticmethod
    def error_event(
        text: Optional[str] = None, *, allow_retry: bool = True
    ) -> ServerSentEvent:
        """Return an ``error`` event.

        The payload always contains ``allow_retry``. The ``text`` key is
        included only when ``text`` is not ``None``.
        """
        data: Dict[str, Union[bool, str]] = {"allow_retry": allow_retry}
        if text is not None:
            data["text"] = text
        return ServerSentEvent(data=json.dumps(data), event="error")

    # Internal handlers

    async def handle_report_feedback(
        self, feedback_request: ReportFeedbackRequest
    ) -> JSONResponse:
        """Forward a feedback report to ``on_feedback``; reply with ``{}``."""
        await self.on_feedback(feedback_request)
        return JSONResponse({})

    async def handle_report_error(
        self, error_request: ReportErrorRequest
    ) -> JSONResponse:
        """Forward an error report to ``on_error``; reply with ``{}``."""
        await self.on_error(error_request)
        return JSONResponse({})

    async def handle_settings(self, settings_request: SettingsRequest) -> JSONResponse:
        """Return the bot's ``get_settings`` result serialized as JSON."""
        settings = await self.get_settings(settings_request)
        return JSONResponse(settings.dict())

    async def handle_query(self, query: QueryRequest) -> AsyncIterable[ServerSentEvent]:
        """Stream the events from ``get_response``, always ending with ``done``.

        If ``get_response`` raises, the exception is logged and an ``error``
        event (``allow_retry=False``) is streamed in its place. The stream is
        still terminated with a ``done`` event instead of being cut off
        mid-response.
        """
        try:
            async for event in self.get_response(query):
                yield event
        except Exception as exc:
            # Boundary between bot code and the wire protocol: report the
            # failure as a protocol-level error rather than aborting the
            # stream. asyncio.CancelledError and GeneratorExit are
            # BaseException subclasses (3.8+), so cancellation and client
            # disconnects still propagate.
            logger.exception("Error responding to query")
            yield self.error_event(repr(exc), allow_retry=False)
        yield self.done_event()

    async def handle_add_documents(self, request: AddDocumentsRequest) -> JSONResponse:
        """Forward an add-documents request to ``add_documents``; reply with ``{}``."""
        await self.add_documents(request)
        return JSONResponse({})

    def handle_shutdown(self) -> None:
        """Invoke the ``shutdown`` hook."""
        self.shutdown()