expose raw search results in the output

This commit is contained in:
vbarda
2025-01-24 17:53:16 -05:00
parent 758c4a8f5a
commit c94df7a225
4 changed files with 49 additions and 14 deletions
+3
View File
@@ -12,6 +12,9 @@ class Configuration:
max_search_queries: int = 3 # Max search queries per company
max_search_results: int = 3 # Max search results per query
max_reflection_steps: int = 0 # Max reflection steps
include_search_results: bool = (
False # Whether to include search results in the output
)
@classmethod
def from_runnable_config(
+11 -4
View File
@@ -11,7 +11,7 @@ from pydantic import BaseModel, Field
from agent.configuration import Configuration
from agent.state import InputState, OutputState, OverallState
from agent.utils import deduplicate_and_format_sources, format_all_notes
from agent.utils import deduplicate_sources, format_sources, format_all_notes
from agent.prompts import (
EXTRACTION_PROMPT,
REFLECTION_PROMPT,
@@ -120,8 +120,9 @@ async def research_company(
search_docs = await asyncio.gather(*search_tasks)
# Deduplicate and format sources
source_str = deduplicate_and_format_sources(
search_docs, max_tokens_per_source=1000, include_raw_content=True
deduplicated_search_docs = deduplicate_sources(search_docs)
source_str = format_sources(
deduplicated_search_docs, max_tokens_per_source=1000, include_raw_content=True
)
# Generate structured notes relevant to the extraction schema
@@ -132,7 +133,13 @@ async def research_company(
user_notes=state.user_notes,
)
result = await claude_3_5_sonnet.ainvoke(p)
return {"completed_notes": [str(result.content)]}
state_update = {
"completed_notes": [str(result.content)],
}
if configurable.include_search_results:
state_update["search_results"] = deduplicated_search_docs
return state_update
def gather_notes_extract_schema(state: OverallState) -> dict[str, Any]:
+6
View File
@@ -68,6 +68,9 @@ class OverallState:
search_queries: list[str] = field(default=None)
"List of generated search queries to find relevant information"
search_results: list[dict] = field(default=None)
"List of search results"
completed_notes: Annotated[list, operator.add] = field(default_factory=list)
"Notes from completed research related to the schema"
@@ -99,3 +102,6 @@ class OutputState:
based on the user's query and the graph's execution.
This is the primary output of the enrichment process.
"""
search_results: list[dict] = field(default=None)
"List of search results"
+29 -10
View File
@@ -1,10 +1,6 @@
def deduplicate_and_format_sources(
search_response, max_tokens_per_source, include_raw_content=True
):
def deduplicate_sources(search_response: dict | list[dict]) -> list[dict]:
"""
Takes either a single search response or list of responses from Tavily API and formats them.
Limits the raw_content to approximately max_tokens_per_source.
include_raw_content specifies whether to include the raw_content from Tavily in the formatted string.
Takes either a single search response or a list of responses from the Tavily API and deduplicates them by URL.
Args:
search_response: Either:
@@ -30,14 +26,37 @@ def deduplicate_and_format_sources(
)
# Deduplicate by URL
unique_sources = {}
unique_urls = set()
unique_sources_list = []
for source in sources_list:
if source["url"] not in unique_sources:
unique_sources[source["url"]] = source
if source["url"] not in unique_urls:
unique_urls.add(source["url"])
unique_sources_list.append(source)
return unique_sources_list
def format_sources(
sources_list: list[dict],
include_raw_content: bool = True,
max_tokens_per_source: int = 1000,
) -> str:
"""
Takes a list of unique results from Tavily API and formats them.
Limits the raw_content to approximately max_tokens_per_source.
include_raw_content specifies whether to include the raw_content from Tavily in the formatted string.
Args:
sources_list: list of unique results from Tavily API
max_tokens_per_source: int, maximum number of tokens per each search result to include in the formatted string
include_raw_content: bool, whether to include the raw_content from Tavily in the formatted string
Returns:
str: Formatted string of the provided sources
"""
# Format output
formatted_text = "Sources:\n\n"
for i, source in enumerate(unique_sources.values(), 1):
for source in sources_list:
formatted_text += f"Source {source['title']}:\n===\n"
formatted_text += f"URL: {source['url']}\n===\n"
formatted_text += (