From 0f42868f36d4e4fc561b0a26fc8de4b168c68208 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Tue, 21 May 2024 14:02:44 -0700
Subject: [PATCH] init

---
 .gitignore |   1 +
 config.py  |   2 +
 dev.sh     |   3 ++
 main.py    | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 start.sh   |   5 +++
 5 files changed, 134 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 config.py
 create mode 100755 dev.sh
 create mode 100644 main.py
 create mode 100755 start.sh

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ed8ebf5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__
\ No newline at end of file
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..11fa8ad
--- /dev/null
+++ b/config.py
@@ -0,0 +1,2 @@
+MODEL_ID = "rag-api"
+MODEL_NAME = "RAG Model"
diff --git a/dev.sh b/dev.sh
new file mode 100755
index 0000000..715aeca
--- /dev/null
+++ b/dev.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+PORT="${PORT:-9099}"
+uvicorn main:app --port "$PORT" --host 0.0.0.0 --forwarded-allow-ips '*' --reload
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..d7f718a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,123 @@
+from fastapi import FastAPI, Request, Depends, status
+from fastapi.middleware.cors import CORSMiddleware
+
+from starlette.responses import StreamingResponse, Response
+from pydantic import BaseModel, ConfigDict
+from typing import List
+
+
+import time
+import json
+import uuid
+
+from config import MODEL_ID, MODEL_NAME
+
+
+app = FastAPI(docs_url="/docs", redoc_url=None)
+
+
+origins = ["*"]
+
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,  # NOTE(review): paired with wildcard origins above; confirm intended
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.middleware("http")
+async def check_url(request: Request, call_next):  # adds X-Process-Time to every response
+    start_time = time.perf_counter()  # monotonic clock with sub-second resolution
+    response = await call_next(request)
+    process_time = time.perf_counter() - start_time  # elapsed seconds as a float
+    response.headers["X-Process-Time"] = str(process_time)
+
+    return response
+
+
+@app.get("/")
+async def get_status():  # status endpoint: always answers {"status": True}
+    return {"status": True}
+
+
+@app.get("/models")
+@app.get("/v1/models")
+async def get_models():
+    """
+    Returns the configured model (MODEL_ID / MODEL_NAME) in an OpenAI-style model list.
+    """
+    return {
+        "data": [
+            {
+                "id": MODEL_ID,
+                "name": MODEL_NAME,
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "openai",
+            }
+        ]
+    }
+
+
+class OpenAIChatMessage(BaseModel):
+    role: str
+    content: str
+
+    model_config = ConfigDict(extra="allow")  # keep any extra fields the client sends
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    model: str
+    messages: List[OpenAIChatMessage]
+
+    model_config = ConfigDict(extra="allow")  # keep any extra fields the client sends
+
+
+def stream_message_template(message: str):  # builds one chunk with `message` as the delta
+    return {
+        "id": f"rag-{uuid.uuid4()}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": MODEL_ID,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
+def get_response():  # placeholder reply: returns a fixed string
+    return "rag response"
+
+
+@app.post("/chat/completions")
+@app.post("/v1/chat/completions")
+async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
+    """Stream the reply as server-sent `chat.completion.chunk` events, then `[DONE]`."""
+    res = get_response()
+
+    finish_message = {
+        "id": f"rag-{uuid.uuid4()}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": MODEL_ID,
+        "choices": [
+            {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
+        ],
+    }
+
+    def stream_content():
+        message = stream_message_template(res)
+
+        yield f"data: {json.dumps(message)}\n\n"
+        yield f"data: {json.dumps(finish_message)}\n\n"
+        yield "data: [DONE]\n\n"  # blank line ends the event; without it SSE parsers drop it
+
+    return StreamingResponse(stream_content(), media_type="text/event-stream")
diff --git a/start.sh b/start.sh
new file mode 100755
index 0000000..cbbdcd7
--- /dev/null
+++ b/start.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+PORT="${PORT:-9099}"
+HOST="${HOST:-0.0.0.0}"
+
+uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'