mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-02 19:14:28 -04:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 73dfac54c4 | |||
| 589e1e8cd9 |
@@ -1,13 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
@@ -20,18 +21,6 @@ class _Message(BaseModel):
|
||||
class _ChatData(BaseModel):
|
||||
messages: List[_Message]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class _SourceNodes(BaseModel):
|
||||
id: str
|
||||
@@ -91,13 +80,25 @@ async def chat(
|
||||
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
|
||||
async def event_generator():
|
||||
async def event_generator(request: Request, response: StreamingAgentChatResponse):
|
||||
# Yield the text response
|
||||
async for token in response.async_response_gen():
|
||||
# If client closes connection, stop sending events
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield token
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/plain")
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"nodes": [
|
||||
_SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
return VercelStreamResponse(content=event_generator(request, response))
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
    """
    Streaming response that frames chat-engine output in the line-based
    stream format expected by Vercel.

    Each frame is a single line: a type prefix followed by a JSON payload
    and a terminating newline. Text tokens use ``TEXT_PREFIX`` and
    structured data uses ``DATA_PREFIX``.
    """

    # Frame type prefixes placed at the start of every emitted line.
    TEXT_PREFIX = "0:"
    DATA_PREFIX = "2:"
    # Headers sent with every response of this class.
    VERCEL_HEADERS = {
        "X-Experimental-Stream-Data": "true",
        "Content-Type": "text/plain; charset=utf-8",
        "Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
    }

    @classmethod
    def convert_text(cls, token: str) -> str:
        """
        Frame a text token as one stream line.

        The token is serialized as a JSON string so that newlines, double
        quotes, backslashes and control characters are all escaped. Escaping
        only newlines would let a token such as ``say "hi"`` produce an
        invalid JSON string and corrupt the frame.

        ``ensure_ascii=False`` keeps non-ASCII characters verbatim rather
        than turning them into ``\\uXXXX`` escapes.
        """
        return f"{cls.TEXT_PREFIX}{json.dumps(token, ensure_ascii=False)}\n"

    @classmethod
    def convert_data(cls, data: dict) -> str:
        """
        Frame a data payload as one stream line.

        The payload is JSON-encoded and wrapped in a single-element JSON
        array after the data prefix.
        """
        data_str = json.dumps(data)
        return f"{cls.DATA_PREFIX}[{data_str}]\n"

    def __init__(self, content: Any, **kwargs):
        """
        Build the streaming response with the Vercel headers attached.

        ``content`` is forwarded to ``StreamingResponse`` unchanged; any
        extra keyword arguments are passed through as well. ``headers`` is
        always set here, so callers must not pass their own ``headers``.
        """
        super().__init__(
            content=content,
            headers=self.VERCEL_HEADERS,
            **kwargs,
        )
|
||||
Reference in New Issue
Block a user