mirror of
https://github.com/run-llama/template-workflow-document-qa.git
synced 2026-06-30 21:47:58 -04:00
get index docstring update (#25)
This commit is contained in:
@@ -56,6 +56,15 @@ def get_llama_parse_client() -> LlamaParse:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_index(index_name: str) -> LlamaCloudIndex:
|
||||
"""
|
||||
Retrieve an existing LlamaCloudIndex or create a new one if it doesn't exist.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to retrieve or create.
|
||||
|
||||
Returns:
|
||||
LlamaCloudIndex: The retrieved or newly created index instance.
|
||||
"""
|
||||
return LlamaCloudIndex.create_index(
|
||||
name=index_name,
|
||||
project_id=LLAMA_CLOUD_PROJECT_ID,
|
||||
|
||||
@@ -7,10 +7,6 @@ from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from llama_index.core import Settings
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
ChatMode,
|
||||
)
|
||||
from llama_index.core.llms import ChatMessage
|
||||
import asyncio
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
@@ -25,6 +21,11 @@ from workflows.events import (
|
||||
HumanResponseEvent,
|
||||
)
|
||||
from workflows.retry_policy import ConstantDelayRetryPolicy
|
||||
from llama_index.core.agent.workflow import FunctionAgent
|
||||
from llama_index.core.tools import QueryEngineTool, ToolMetadata
|
||||
from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.agent.workflow import AgentStream
|
||||
|
||||
from .clients import (
|
||||
get_index,
|
||||
@@ -204,12 +205,31 @@ class ConversationMessage(BaseModel):
|
||||
timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
def get_chat_engine(index_name: str) -> BaseChatEngine:
|
||||
def get_agent(index_name: str) -> FunctionAgent:
|
||||
index = get_index(index_name)
|
||||
return index.as_chat_engine(
|
||||
chat_mode=ChatMode.CONTEXT,
|
||||
chunk_retriever = index.as_retriever(
|
||||
dense_similarity_top_k=5,
|
||||
sparse_similarity_top_k=5,
|
||||
enable_reranking=True,
|
||||
rerank_top_n=3,
|
||||
)
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
chunk_retriever, llm=Settings.llm, response_mode="compact"
|
||||
)
|
||||
query_tool = QueryEngineTool(
|
||||
query_engine=query_engine,
|
||||
metadata=ToolMetadata(
|
||||
name="document_search",
|
||||
description=(
|
||||
"Provides information about relevant documents."
|
||||
"Use a detailed plain text question as input to the tool."
|
||||
),
|
||||
),
|
||||
)
|
||||
return FunctionAgent(
|
||||
tools=[query_tool],
|
||||
llm=Settings.llm,
|
||||
context_prompt=(
|
||||
system_prompt=(
|
||||
"You are a helpful assistant that answers questions based on the provided documents. "
|
||||
"Always cite specific information from the documents when answering. "
|
||||
"If you cannot find the answer in the documents, say so clearly."
|
||||
@@ -285,34 +305,42 @@ class ChatWorkflow(Workflow):
|
||||
}
|
||||
)
|
||||
|
||||
chat_engine = get_chat_engine(index_name)
|
||||
|
||||
stream_response = await chat_engine.astream_chat(
|
||||
user_input, chat_history=initial_state.chat_messages()
|
||||
logger.info(f"Initializing agent for index: {index_name}")
|
||||
agent = get_agent(index_name)
|
||||
memory = ChatMemoryBuffer.from_defaults(
|
||||
token_limit=3000,
|
||||
chat_history=initial_state.chat_messages(),
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
handler = agent.run(user_input, memory=memory)
|
||||
full_text = ""
|
||||
|
||||
# Emit streaming deltas to the event stream
|
||||
async for token in stream_response.async_response_gen():
|
||||
full_text += token
|
||||
ctx.write_event_to_stream(ChatDeltaEvent(delta=token))
|
||||
await asyncio.sleep(
|
||||
0
|
||||
) # Temp workaround. Some sort of bug in the server drops events without flushing the event loop
|
||||
async for event in handler.stream_events():
|
||||
if isinstance(event, AgentStream):
|
||||
full_text += event.delta
|
||||
ctx.write_event_to_stream(ChatDeltaEvent(delta=event.delta))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Extract source nodes for citations
|
||||
sources = []
|
||||
if stream_response.source_nodes:
|
||||
for node in stream_response.source_nodes:
|
||||
sources.append(
|
||||
SourceMessage(
|
||||
text=node.text[:197] + "..."
|
||||
if len(node.text) >= 200
|
||||
else node.text,
|
||||
score=float(node.score) if node.score else 0.0,
|
||||
metadata=node.metadata,
|
||||
)
|
||||
)
|
||||
response = await handler
|
||||
all_source_nodes = [
|
||||
node
|
||||
for c in response.tool_calls
|
||||
for node in c.tool_output.raw_output.source_nodes
|
||||
]
|
||||
|
||||
sources = [
|
||||
SourceMessage(
|
||||
text=node.text[:197] + "..."
|
||||
if len(node.text) >= 200
|
||||
else node.text,
|
||||
score=float(node.score) if node.score else 0.0,
|
||||
metadata=node.metadata,
|
||||
)
|
||||
for node in all_source_nodes
|
||||
]
|
||||
|
||||
# After streaming completes, emit a summary response event to stream for frontend/main printing
|
||||
assistant_response = ConversationMessage(
|
||||
|
||||
Reference in New Issue
Block a user