diff --git a/src/{{ project_name_snake }}/clients.py b/src/{{ project_name_snake }}/clients.py index ee67170..dfcd00d 100644 --- a/src/{{ project_name_snake }}/clients.py +++ b/src/{{ project_name_snake }}/clients.py @@ -56,6 +56,15 @@ def get_llama_parse_client() -> LlamaParse: @functools.lru_cache(maxsize=None) def get_index(index_name: str) -> LlamaCloudIndex: + """ + Retrieve an existing LlamaCloudIndex or create a new one if it doesn't exist. + + Args: + index_name: The name of the index to retrieve or create. + + Returns: + LlamaCloudIndex: The retrieved or newly created index instance. + """ return LlamaCloudIndex.create_index( name=index_name, project_id=LLAMA_CLOUD_PROJECT_ID, diff --git a/src/{{ project_name_snake }}/qa_workflows.py b/src/{{ project_name_snake }}/qa_workflows.py index a9a2bc4..db29af1 100644 --- a/src/{{ project_name_snake }}/qa_workflows.py +++ b/src/{{ project_name_snake }}/qa_workflows.py @@ -7,10 +7,6 @@ from typing import Any, Literal import httpx from llama_index.core import Settings -from llama_index.core.chat_engine.types import ( - BaseChatEngine, - ChatMode, -) from llama_index.core.llms import ChatMessage import asyncio from llama_index.embeddings.openai import OpenAIEmbedding @@ -25,6 +21,11 @@ from workflows.events import ( HumanResponseEvent, ) from workflows.retry_policy import ConstantDelayRetryPolicy +from llama_index.core.agent.workflow import FunctionAgent +from llama_index.core.tools import QueryEngineTool, ToolMetadata +from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.agent.workflow import AgentStream from .clients import ( get_index, @@ -204,12 +205,31 @@ class ConversationMessage(BaseModel): timestamp: str = Field(default_factory=lambda: datetime.now().isoformat()) -def get_chat_engine(index_name: str) -> BaseChatEngine: +def get_agent(index_name: str) -> FunctionAgent: index = get_index(index_name) - return index.as_chat_engine( - chat_mode=ChatMode.CONTEXT, + chunk_retriever = index.as_retriever( + dense_similarity_top_k=5, + sparse_similarity_top_k=5, + enable_reranking=True, + rerank_top_n=3, + ) + query_engine = RetrieverQueryEngine.from_args( + chunk_retriever, llm=Settings.llm, response_mode="compact" + ) + query_tool = QueryEngineTool( + query_engine=query_engine, + metadata=ToolMetadata( + name="document_search", + description=( + "Provides information about relevant documents." + "Use a detailed plain text question as input to the tool." + ), + ), + ) + return FunctionAgent( + tools=[query_tool], llm=Settings.llm, - context_prompt=( + system_prompt=( "You are a helpful assistant that answers questions based on the provided documents. " "Always cite specific information from the documents when answering. " "If you cannot find the answer in the documents, say so clearly." @@ -285,34 +305,42 @@ class ChatWorkflow(Workflow): } ) - chat_engine = get_chat_engine(index_name) - - stream_response = await chat_engine.astream_chat( - user_input, chat_history=initial_state.chat_messages() + logger.info(f"Initializing agent for index: {index_name}") + agent = get_agent(index_name) + memory = ChatMemoryBuffer.from_defaults( + token_limit=3000, + chat_history=initial_state.chat_messages(), ) + + # Run the agent + handler = agent.run(user_input, memory=memory) full_text = "" # Emit streaming deltas to the event stream - async for token in stream_response.async_response_gen(): - full_text += token - ctx.write_event_to_stream(ChatDeltaEvent(delta=token)) - await asyncio.sleep( - 0 - ) # Temp workaround. Some sort of bug in the server drops events without flushing the event loop + async for event in handler.stream_events(): + if isinstance(event, AgentStream): + full_text += event.delta + ctx.write_event_to_stream(ChatDeltaEvent(delta=event.delta)) + await asyncio.sleep(0) # Extract source nodes for citations - sources = [] - if stream_response.source_nodes: - for node in stream_response.source_nodes: - sources.append( - SourceMessage( - text=node.text[:197] + "..." - if len(node.text) >= 200 - else node.text, - score=float(node.score) if node.score else 0.0, - metadata=node.metadata, - ) - ) + response = await handler + all_source_nodes = [ + node + for c in response.tool_calls + for node in c.tool_output.raw_output.source_nodes + ] + + sources = [ + SourceMessage( + text=node.text[:197] + "..." + if len(node.text) >= 200 + else node.text, + score=float(node.score) if node.score else 0.0, + metadata=node.metadata, + ) + for node in all_source_nodes + ] # After streaming completes, emit a summary response event to stream for frontend/main printing assistant_response = ConversationMessage(