get index docstring update (#25)

This commit is contained in:
Patricia
2025-10-08 10:20:02 -07:00
committed by GitHub
parent 0cc4f4d86c
commit 2e7a5f4ce9
2 changed files with 67 additions and 30 deletions
+9
View File
@@ -56,6 +56,15 @@ def get_llama_parse_client() -> LlamaParse:
@functools.lru_cache(maxsize=None)
def get_index(index_name: str) -> LlamaCloudIndex:
"""
Retrieve an existing LlamaCloudIndex or create a new one if it doesn't exist.
Args:
index_name: The name of the index to retrieve or create.
Returns:
LlamaCloudIndex: The retrieved or newly created index instance.
"""
return LlamaCloudIndex.create_index(
name=index_name,
project_id=LLAMA_CLOUD_PROJECT_ID,
+58 -30
View File
@@ -7,10 +7,6 @@ from typing import Any, Literal
import httpx
from llama_index.core import Settings
from llama_index.core.chat_engine.types import (
BaseChatEngine,
ChatMode,
)
from llama_index.core.llms import ChatMessage
import asyncio
from llama_index.embeddings.openai import OpenAIEmbedding
@@ -25,6 +21,11 @@ from workflows.events import (
HumanResponseEvent,
)
from workflows.retry_policy import ConstantDelayRetryPolicy
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.agent.workflow import AgentStream
from .clients import (
get_index,
@@ -204,12 +205,31 @@ class ConversationMessage(BaseModel):
timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())
def get_chat_engine(index_name: str) -> BaseChatEngine:
def get_agent(index_name: str) -> FunctionAgent:
index = get_index(index_name)
return index.as_chat_engine(
chat_mode=ChatMode.CONTEXT,
chunk_retriever = index.as_retriever(
dense_similarity_top_k=5,
sparse_similarity_top_k=5,
enable_reranking=True,
rerank_top_n=3,
)
query_engine = RetrieverQueryEngine.from_args(
chunk_retriever, llm=Settings.llm, response_mode="compact"
)
query_tool = QueryEngineTool(
query_engine=query_engine,
metadata=ToolMetadata(
name="document_search",
description=(
"Provides information about relevant documents."
"Use a detailed plain text question as input to the tool."
),
),
)
return FunctionAgent(
tools=[query_tool],
llm=Settings.llm,
context_prompt=(
system_prompt=(
"You are a helpful assistant that answers questions based on the provided documents. "
"Always cite specific information from the documents when answering. "
"If you cannot find the answer in the documents, say so clearly."
@@ -285,34 +305,42 @@ class ChatWorkflow(Workflow):
}
)
chat_engine = get_chat_engine(index_name)
stream_response = await chat_engine.astream_chat(
user_input, chat_history=initial_state.chat_messages()
logger.info(f"Initializing agent for index: {index_name}")
agent = get_agent(index_name)
memory = ChatMemoryBuffer.from_defaults(
token_limit=3000,
chat_history=initial_state.chat_messages(),
)
# Run the agent
handler = agent.run(user_input, memory=memory)
full_text = ""
# Emit streaming deltas to the event stream
async for token in stream_response.async_response_gen():
full_text += token
ctx.write_event_to_stream(ChatDeltaEvent(delta=token))
await asyncio.sleep(
0
) # Temp workaround. Some sort of bug in the server drops events without flushing the event loop
async for event in handler.stream_events():
if isinstance(event, AgentStream):
full_text += event.delta
ctx.write_event_to_stream(ChatDeltaEvent(delta=event.delta))
await asyncio.sleep(0)
# Extract source nodes for citations
sources = []
if stream_response.source_nodes:
for node in stream_response.source_nodes:
sources.append(
SourceMessage(
text=node.text[:197] + "..."
if len(node.text) >= 200
else node.text,
score=float(node.score) if node.score else 0.0,
metadata=node.metadata,
)
)
response = await handler
all_source_nodes = [
node
for c in response.tool_calls
for node in c.tool_output.raw_output.source_nodes
]
sources = [
SourceMessage(
text=node.text[:197] + "..."
if len(node.text) >= 200
else node.text,
score=float(node.score) if node.score else 0.0,
metadata=node.metadata,
)
for node in all_source_nodes
]
# After streaming completes, emit a summary response event to stream for frontend/main printing
assistant_response = ConversationMessage(