Compare commits

...

2 Commits

Author SHA1 Message Date
leehuwuj 73dfac54c4 fix vercel stream breaking 2024-04-10 16:56:24 +07:00
leehuwuj 589e1e8cd9 add vercel streaming format 2024-04-10 15:24:50 +07:00
2 changed files with 52 additions and 16 deletions
@@ -1,13 +1,14 @@
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
chat_router = r = APIRouter()
@@ -20,18 +21,6 @@ class _Message(BaseModel):
class _ChatData(BaseModel):
messages: List[_Message]
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
class _SourceNodes(BaseModel):
id: str
@@ -91,13 +80,25 @@ async def chat(
response = await chat_engine.astream_chat(last_message_content, messages)
async def event_generator():
async def event_generator(request: Request, response: StreamingAgentChatResponse):
# Yield the text response
async for token in response.async_response_gen():
# If client closes connection, stop sending events
if await request.is_disconnected():
break
yield token
yield VercelStreamResponse.convert_text(token)
return StreamingResponse(event_generator(), media_type="text/plain")
# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"nodes": [
_SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
}
)
return VercelStreamResponse(content=event_generator(request, response))
# non-streaming endpoint - delete if not needed
@@ -0,0 +1,35 @@
import json
from typing import Any
from fastapi.responses import StreamingResponse
class VercelStreamResponse(StreamingResponse):
"""
Class to convert the response from the chat engine to the streaming format expected by Vercel
"""
TEXT_PREFIX = "0:"
DATA_PREFIX = "2:"
VERCEL_HEADERS = {
"X-Experimental-Stream-Data": "true",
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
}
@classmethod
def convert_text(cls, token: str):
# Escape newlines to avoid breaking the stream
token = token.replace("\n", "\\n")
return f'{cls.TEXT_PREFIX}"{token}"\n'
@classmethod
def convert_data(cls, data: dict):
data_str = json.dumps(data)
return f"{cls.DATA_PREFIX}[{data_str}]\n"
def __init__(self, content: Any, **kwargs):
super().__init__(
content=content,
headers=self.VERCEL_HEADERS,
**kwargs,
)