mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-02 19:14:28 -04:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1664be0e4e | |||
| 5bb3db8b8b | |||
| df77a1e6d8 | |||
| 8f893e2f62 |
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Add nodes to the response and support Vercel streaming format
|
||||
+1
-1
@@ -505,7 +505,7 @@ export const askQuestions = async (
|
||||
|
||||
if (program.framework === "nextjs" || program.frontend) {
|
||||
if (!program.ui) {
|
||||
program.ui = getPrefOrDefault("ui");
|
||||
program.ui = defaults.ui;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
from typing import List, Tuple
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
@@ -19,8 +22,27 @@ class _ChatData(BaseModel):
|
||||
messages: List[_Message]
|
||||
|
||||
|
||||
class _SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=source_node.node.metadata,
|
||||
score=source_node.score,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
|
||||
class _Result(BaseModel):
|
||||
result: _Message
|
||||
nodes: List[_SourceNodes]
|
||||
|
||||
|
||||
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
|
||||
@@ -58,13 +80,25 @@ async def chat(
|
||||
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
|
||||
async def event_generator():
|
||||
async def event_generator(request: Request, response: StreamingAgentChatResponse):
|
||||
# Yield the text response
|
||||
async for token in response.async_response_gen():
|
||||
# If client closes connection, stop sending events
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield token
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/plain")
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"nodes": [
|
||||
_SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
return VercelStreamResponse(content=event_generator(request, response))
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@@ -77,5 +111,6 @@ async def chat_request(
|
||||
|
||||
response = await chat_engine.achat(last_message_content, messages)
|
||||
return _Result(
|
||||
result=_Message(role=MessageRole.ASSISTANT, content=response.response)
|
||||
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
|
||||
"""
|
||||
Class to convert the response from the chat engine to the streaming format expected by Vercel
|
||||
"""
|
||||
|
||||
TEXT_PREFIX = "0:"
|
||||
DATA_PREFIX = "2:"
|
||||
VERCEL_HEADERS = {
|
||||
"X-Experimental-Stream-Data": "true",
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
return f'{cls.TEXT_PREFIX}"{token}"\n'
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}[{data_str}]\n"
|
||||
|
||||
def __init__(self, content: Any, **kwargs):
|
||||
super().__init__(
|
||||
content=content,
|
||||
headers=self.VERCEL_HEADERS,
|
||||
**kwargs,
|
||||
)
|
||||
Reference in New Issue
Block a user