mirror of
https://github.com/langchain-ai/langchain-benchmarks.git
synced 2026-07-01 22:34:02 -04:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 538533c2c5 | |||
| 6845999d35 | |||
| 3ff94b90f9 | |||
| 9d5e6695bc | |||
| 6b5a955078 | |||
| 32d2a50ea9 |
+1
-1
@@ -159,4 +159,4 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
.DS_Store
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
chromadb/
|
||||
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -1,3 +1,3 @@
|
||||
from openai_functions_agent.agent import agent_executor
|
||||
from openai_functions_agent.agent import agent_executor, create_executor
|
||||
|
||||
__all__ = ["agent_executor"]
|
||||
__all__ = ["agent_executor", "create_executor"]
|
||||
|
||||
+29
-26
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.format_scratchpad import format_to_openai_functions
|
||||
@@ -25,7 +25,6 @@ def search(query, callbacks=None):
|
||||
|
||||
tools = [search]
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
|
||||
assistant_system_message = """You are a helpful assistant tasked with answering technical questions about LangChain. \
|
||||
Use tools (only if necessary) to best answer the users questions. Do not make up information if you cannot find the answer using your tools."""
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
@@ -37,8 +36,6 @@ prompt = ChatPromptTemplate.from_messages(
|
||||
]
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
|
||||
|
||||
|
||||
def _format_chat_history(chat_history: List[Tuple[str, str]]):
|
||||
buffer = []
|
||||
@@ -48,30 +45,11 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
|
||||
return buffer
|
||||
|
||||
|
||||
agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
|
||||
"agent_scratchpad": lambda x: format_to_openai_functions(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
}
|
||||
| prompt
|
||||
| llm_with_tools
|
||||
| OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
input: str
|
||||
chat_history: List[Tuple[str, str]] = Field(..., extra={"widget": {"type": "chat"}})
|
||||
|
||||
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False).with_types(
|
||||
input_type=AgentInput
|
||||
)
|
||||
|
||||
|
||||
class ChainInput(BaseModel):
|
||||
question: str
|
||||
|
||||
@@ -80,6 +58,31 @@ def mapper(input: dict):
|
||||
return {"input": input["question"], "chat_history": []}
|
||||
|
||||
|
||||
agent_executor = (mapper | agent_executor | (lambda x: x["output"])).with_types(
|
||||
input_type=ChainInput
|
||||
)
|
||||
def create_executor(model_config: Optional[dict] = None):
|
||||
model = (model_config or {}).get("model", "gpt-3.5-turbo-16k")
|
||||
llm = ChatOpenAI(model=model, temperature=0)
|
||||
llm_with_tools = llm.bind(
|
||||
functions=[format_tool_to_openai_function(t) for t in tools]
|
||||
)
|
||||
|
||||
agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
|
||||
"agent_scratchpad": lambda x: format_to_openai_functions(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
}
|
||||
| prompt
|
||||
| llm_with_tools
|
||||
| OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False).with_types(
|
||||
input_type=AgentInput
|
||||
)
|
||||
return (mapper | agent_executor | (lambda x: x["output"])).with_types(
|
||||
input_type=ChainInput
|
||||
)
|
||||
|
||||
|
||||
agent_executor = create_executor()
|
||||
|
||||
@@ -1,19 +1,30 @@
|
||||
"""Copy the public dataset to your own langsmith tenant."""
|
||||
from typing import Optional
|
||||
from langsmith import Client
|
||||
from tqdm import tqdm
|
||||
|
||||
DATASET_NAME = "LangChain Docs Q&A"
|
||||
PUBLIC_DATASET_TOKEN = "452ccafc-18e1-4314-885b-edd735f17b9d"
|
||||
client = Client()
|
||||
|
||||
|
||||
def create_langchain_docs_dataset(
|
||||
dataset_name: str = DATASET_NAME, public_dataset_token: str = PUBLIC_DATASET_TOKEN
|
||||
dataset_name: str = DATASET_NAME,
|
||||
public_dataset_token: str = PUBLIC_DATASET_TOKEN,
|
||||
client: Optional[Client] = None,
|
||||
):
|
||||
shared_client = Client(
|
||||
api_url="https://api.smith.langchain.com", api_key="placeholder"
|
||||
)
|
||||
examples = list(shared_client.list_shared_examples(public_dataset_token))
|
||||
client = client or Client()
|
||||
if client.has_dataset(dataset_name=dataset_name):
|
||||
return
|
||||
ds = client.create_dataset(dataset_name=dataset_name)
|
||||
examples = tqdm(list(client.list_shared_examples(public_dataset_token)))
|
||||
loaded_examples = list(client.list_examples(dataset_name=dataset_name))
|
||||
if len(loaded_examples) == len(examples):
|
||||
return
|
||||
else:
|
||||
ds = client.read_dataset(dataset_name=dataset_name)
|
||||
else:
|
||||
ds = client.create_dataset(dataset_name=dataset_name)
|
||||
client.create_examples(
|
||||
inputs=[e.inputs for e in examples],
|
||||
outputs=[e.outputs for e in examples],
|
||||
@@ -23,4 +34,24 @@ def create_langchain_docs_dataset(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_langchain_docs_dataset()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--target-api-key", type=str, required=False)
|
||||
parser.add_argument("--target-endpoint", type=str, required=False)
|
||||
parser.add_argument("--dataset-name", type=str, default=DATASET_NAME)
|
||||
parser.add_argument(
|
||||
"--public-dataset-token", type=str, default=PUBLIC_DATASET_TOKEN
|
||||
)
|
||||
args = parser.parse_args()
|
||||
client = None
|
||||
if args.target_api_key or args.target_endpoint:
|
||||
client = Client(
|
||||
api_key=args.target_api_key,
|
||||
api_url=args.target_endpoint,
|
||||
)
|
||||
create_langchain_docs_dataset(
|
||||
dataset_name=args.dataset_name,
|
||||
public_dataset_token=args.public_dataset_token,
|
||||
client=client,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from functools import partial
|
||||
@@ -13,7 +12,7 @@ from langchain.schema.runnable import Runnable
|
||||
from langchain.smith import RunEvalConfig, run_on_dataset
|
||||
from langsmith import Client
|
||||
from oai_assistant.chain import agent_executor as openai_assistant_chain
|
||||
from openai_functions_agent import agent_executor as openai_functions_agent_chain
|
||||
from openai_functions_agent import create_executor
|
||||
|
||||
ls_client = Client()
|
||||
|
||||
@@ -33,7 +32,7 @@ def _get_chain_factory(arch: str) -> Callable:
|
||||
_map = {
|
||||
"chat": create_chain,
|
||||
"anthropic-iterative-search": lambda _: anthropic_agent_chain,
|
||||
"openai-functions-agent": lambda _: openai_functions_agent_chain,
|
||||
"openai-functions-agent": create_executor,
|
||||
"openai-assistant": lambda _: openai_assistant_chain,
|
||||
}
|
||||
if arch in _map:
|
||||
@@ -93,8 +92,7 @@ def main(
|
||||
run_on_dataset(
|
||||
client=ls_client,
|
||||
dataset_name=dataset_name,
|
||||
llm_or_chain_factory=partial(
|
||||
create_runnable,
|
||||
llm_or_chain_factory=lambda: create_runnable(
|
||||
arch=arch,
|
||||
model_config=model_config,
|
||||
retry_config=retry_config,
|
||||
|
||||
@@ -6,9 +6,14 @@ from run_evals import main
|
||||
|
||||
experiments = [
|
||||
{
|
||||
# "server_url": "http://localhost:1983/openai-functions-agent",
|
||||
"arch": "openai-functions-agent",
|
||||
"project_name": "openai-functions-agent",
|
||||
"model_config": {"model": "gpt-3.5-turbo-16k"},
|
||||
},
|
||||
{
|
||||
"arch": "openai-functions-agent",
|
||||
"project_name": "oaifunc-agent-gpt-4-1106",
|
||||
"model_config": {"model": "gpt-4-1106-preview"},
|
||||
},
|
||||
{
|
||||
# "server_url": "http://localhost:1983/anthropic_chat",
|
||||
@@ -41,9 +46,17 @@ experiments = [
|
||||
"arch": "chat",
|
||||
"model_config": {
|
||||
"chat_cls": "ChatFireworks",
|
||||
"model": "accounts/fireworks/models/llama-v2-34b-code-instruct-w8a16",
|
||||
"model": "accounts/fireworks/models/llama-v2-34b-code-instruct",
|
||||
},
|
||||
"project_name": "llama-v2-34b-code-instruct-w8a16",
|
||||
"project_name": "llama-v2-34b-code-instruct",
|
||||
},
|
||||
{
|
||||
"arch": "chat",
|
||||
"model_config": {
|
||||
"chat_cls": "ChatFireworks",
|
||||
"model": "accounts/fireworks/models/llama-v2-70b-chat",
|
||||
},
|
||||
"project_name": "llama-70b-chat",
|
||||
},
|
||||
{
|
||||
"arch": "chat",
|
||||
@@ -120,6 +133,7 @@ if __name__ == "__main__":
|
||||
]
|
||||
|
||||
for experiment in selected_experiments:
|
||||
print("Running experiment:", experiment)
|
||||
main(
|
||||
**experiment,
|
||||
dataset_name=args.dataset_name,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
.sql
|
||||
@@ -0,0 +1,6 @@
|
||||
"""RAG environments."""
|
||||
from langchain_benchmarks.rag.evaluators import RAG_EVALUATION
|
||||
from langchain_benchmarks.rag.registration import registry
|
||||
|
||||
# Please keep this list sorted!
|
||||
__all__ = ["registry", "RAG_EVALUATION"]
|
||||
@@ -0,0 +1,6 @@
|
||||
# LangChain Docs Environment
|
||||
|
||||
This code contains utilities to scrape the LangChain docs (already run) and index them
|
||||
using common techniques. The docs were scraped using the code in `_ingest_docs.py` and
|
||||
uploaded to gcs. To better compare retrieval techniques, we hold these constant and pull
|
||||
from that cache whenever generating different indices.
|
||||
@@ -0,0 +1 @@
|
||||
DATASET_ID = "452ccafc-18e1-4314-885b-edd735f17b9d" # ID of public LangChain Docs dataset
|
||||
@@ -0,0 +1,299 @@
|
||||
"""Load html from files, clean up, split, ingest."""
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Generator, Iterable, Optional
|
||||
|
||||
from bs4 import BeautifulSoup, Doctype, NavigableString, SoupStrainer, Tag
|
||||
from langchain.document_loaders import RecursiveUrlLoader, SitemapLoader
|
||||
from langchain.embeddings import OpenAIEmbeddings, VoyageEmbeddings
|
||||
from langchain.indexes import SQLRecordManager, index
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.utils.html import PREFIXES_TO_IGNORE_REGEX, SUFFIXES_TO_IGNORE_REGEX
|
||||
from langchain.vectorstores.chroma import Chroma
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
directory = os.path.dirname(os.path.realpath(__file__))
|
||||
db_directory = os.path.join(directory, "langchain_docs_retriever", "db")
|
||||
docs_cache_directory = os.path.join(directory, "langchain_docs_retriever", "db_docs")
|
||||
docs_cache_file = os.path.join(docs_cache_directory, "docs.parquet")
|
||||
|
||||
|
||||
def langchain_docs_extractor(soup: BeautifulSoup) -> str:
    """Render a parsed docs page as Markdown-flavoured plain text.

    Navigation/footers/scripts/styles are removed from ``soup`` in place, then
    the remaining tree is walked and each supported tag is emitted in a
    Markdown-like form (headings, links, images, emphasis, code fences, lists,
    tab containers and tables). Buttons are dropped; any other tag is
    flattened by recursing into its children.

    Args:
        soup: The parsed page. It is mutated: the tags listed in
            ``SCAPE_TAGS`` are decomposed.

    Returns:
        The extracted text with runs of blank lines collapsed and surrounding
        whitespace stripped.
    """
    # Remove all the tags that are not meaningful for the extraction.
    SCAPE_TAGS = ["nav", "footer", "aside", "script", "style"]
    for tag in soup.find_all(SCAPE_TAGS):
        tag.decompose()

    def get_text(tag: Tag) -> Generator[str, None, None]:
        """Yield text fragments for every child of ``tag``, depth first."""
        for child in tag.children:
            if isinstance(child, Doctype):
                continue

            if isinstance(child, NavigableString):
                yield child
            elif isinstance(child, Tag):
                if child.name in ["h1", "h2", "h3", "h4", "h5", "h6"]:
                    # "h3" -> "###": the digit after "h" is the heading depth.
                    yield f"{'#' * int(child.name[1:])} {child.get_text()}\n\n"
                elif child.name == "a":
                    yield f"[{child.get_text(strip=False)}]({child.get('href')})"
                elif child.name == "img":
                    # Markdown image syntax: ![alt text](source url).
                    yield f"![{child.get('alt', '')}]({child.get('src')})"
                elif child.name in ["strong", "b"]:
                    yield f"**{child.get_text(strip=False)}**"
                elif child.name in ["em", "i"]:
                    yield f"_{child.get_text(strip=False)}_"
                elif child.name == "br":
                    yield "\n"
                elif child.name == "code":
                    parent = child.find_parent()
                    if parent is not None and parent.name == "pre":
                        # Fenced block: the language comes from a
                        # "language-xxx" class on the enclosing <pre>.
                        classes = parent.attrs.get("class", "")

                        language = next(
                            filter(lambda x: re.match(r"language-\w+", x), classes),
                            None,
                        )
                        if language is None:
                            language = ""
                        else:
                            language = language.split("-")[1]

                        lines: list[str] = []
                        for span in child.find_all("span", class_="token-line"):
                            line_content = "".join(
                                token.get_text() for token in span.find_all("span")
                            )
                            lines.append(line_content)

                        code_content = "\n".join(lines)
                        yield f"```{language}\n{code_content}\n```\n\n"
                    else:
                        # Inline code.
                        yield f"`{child.get_text(strip=False)}`"

                elif child.name == "p":
                    yield from get_text(child)
                    yield "\n\n"
                elif child.name == "ul":
                    for li in child.find_all("li", recursive=False):
                        yield "- "
                        yield from get_text(li)
                        yield "\n\n"
                elif child.name == "ol":
                    for i, li in enumerate(child.find_all("li", recursive=False)):
                        yield f"{i + 1}. "
                        yield from get_text(li)
                        yield "\n\n"
                elif child.name == "div" and "tabs-container" in child.attrs.get(
                    "class", [""]
                ):
                    # Emit each tab label followed by the matching panel body.
                    tabs = child.find_all("li", {"role": "tab"})
                    tab_panels = child.find_all("div", {"role": "tabpanel"})
                    for tab, tab_panel in zip(tabs, tab_panels):
                        tab_name = tab.get_text(strip=True)
                        yield f"{tab_name}\n"
                        yield from get_text(tab_panel)
                elif child.name == "table":
                    thead = child.find("thead")
                    header_exists = isinstance(thead, Tag)
                    if header_exists:
                        headers = thead.find_all("th")
                        if headers:
                            yield "| "
                            yield " | ".join(header.get_text() for header in headers)
                            yield " |\n"
                            yield "| "
                            yield " | ".join("----" for _ in headers)
                            yield " |\n"

                    tbody = child.find("tbody")
                    tbody_exists = isinstance(tbody, Tag)
                    if tbody_exists:
                        for row in tbody.find_all("tr"):
                            yield "| "
                            yield " | ".join(
                                cell.get_text(strip=True) for cell in row.find_all("td")
                            )
                            yield " |\n"

                    yield "\n\n"
                elif child.name in ["button"]:
                    continue
                else:
                    yield from get_text(child)

    joined = "".join(get_text(soup))
    return re.sub(r"\n\n+", "\n\n", joined).strip()
|
||||
|
||||
|
||||
RECORD_MANAGER_DB_URL = (
|
||||
os.environ.get("RECORD_MANAGER_DB_URL") or "sqlite:///lcdocs_oai_record_manager.sql"
|
||||
)
|
||||
|
||||
|
||||
def metadata_extractor(meta: dict, soup: BeautifulSoup) -> dict:
    """Build the metadata dict for one page.

    Combines ``source``/``title``/``description``/``language`` read from the
    page with every entry of ``meta``; falsy values are replaced by ``""``.
    Entries of ``meta`` win over the derived keys when names collide.
    """
    title_tag = soup.find("title")
    description_tag = soup.find("meta", attrs={"name": "description"})
    html_tag = soup.find("html")

    extracted = {
        "source": meta["loc"] or "",
        "title": (title_tag.get_text() if title_tag else "") or "",
        "description": (description_tag.get("content") or "")
        if description_tag
        else "",
        "language": (html_tag.get("lang") or "") if html_tag else "",
    }
    # Sitemap-provided values are applied last so they take precedence.
    for key, value in meta.items():
        extracted[key] = value or ""
    return extracted
|
||||
|
||||
|
||||
def load_langchain_docs():
    """Load the pages listed in the python.langchain.com sitemap.

    Pages are parsed with ``langchain_docs_extractor`` and their metadata is
    produced by ``metadata_extractor``.
    """
    # Only parse the tags the extractor and the metadata function look at.
    strainer = SoupStrainer(name=("article", "title", "html", "lang", "content"))
    loader = SitemapLoader(
        "https://python.langchain.com/sitemap.xml",
        filter_urls=["https://python.langchain.com/"],
        parsing_function=langchain_docs_extractor,
        default_parser="lxml",
        bs_kwargs={"parse_only": strainer},
        meta_function=metadata_extractor,
    )
    return loader.load()
|
||||
|
||||
|
||||
def simple_extractor(html: str) -> str:
    """Return the text of ``html`` with blank-line runs collapsed and ends stripped."""
    page_text = BeautifulSoup(html, "lxml").text
    collapsed = re.sub(r"\n\n+", "\n\n", page_text)
    return collapsed.strip()
|
||||
|
||||
|
||||
def load_api_docs():
    """Recursively crawl the API reference site and return the loaded documents."""
    # Drop trailing / to avoid duplicate pages.
    link_pattern = (
        f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)"
        r"(?:[\#'\"]|\/[\#'\"])"
    )
    skipped_dirs = (
        "https://api.python.langchain.com/en/latest/_sources",
        "https://api.python.langchain.com/en/latest/_modules",
    )
    loader = RecursiveUrlLoader(
        url="https://api.python.langchain.com/en/latest/",
        max_depth=8,
        extractor=simple_extractor,
        prevent_outside=True,
        use_async=True,
        timeout=600,
        link_regex=link_pattern,
        check_response_status=True,
        exclude_dirs=skipped_dirs,
    )
    return loader.load()
|
||||
|
||||
|
||||
def get_embeddings_model() -> Embeddings:
    """Pick the embeddings backend from the environment.

    Voyage embeddings are used only when both ``VOYAGE_AI_URL`` and
    ``VOYAGE_AI_MODEL`` are set to non-empty values; otherwise OpenAI
    embeddings with ``chunk_size=200`` are returned.
    """
    voyage_url = os.environ.get("VOYAGE_AI_URL")
    if not (voyage_url and os.environ.get("VOYAGE_AI_MODEL")):
        return OpenAIEmbeddings(chunk_size=200)
    return VoyageEmbeddings()
|
||||
|
||||
|
||||
CHROMA_COLLECTION_NAME = "langchain-docs"
|
||||
|
||||
|
||||
def get_docs() -> Iterable[Document]:
    """Scrape the documentation and API sites and yield the non-empty documents.

    Every document gets ``source`` and ``title`` metadata keys, and ``None``
    metadata values are replaced by ``""`` before it is yielded.
    """
    # TODO: Make this function actually a generator
    # Import before loading because it's a bummer to fail after scraping.
    # we should have an incremental scrape cache.
    import pyarrow as pa  # type: ignore

    documentation_docs = load_langchain_docs()
    logger.info(f"Loaded {len(documentation_docs)} docs from documentation")
    api_docs = load_api_docs()
    logger.info(f"Loaded {len(api_docs)} docs from API")

    # We try to return 'source' and 'title' metadata when querying vector store and
    # Chroma will error at query time if one of the attributes is missing from a
    # retrieved document.
    for document in documentation_docs + api_docs:
        metadata = document.metadata
        metadata.setdefault("source", "")
        metadata.setdefault("title", "")
        missing_keys = [key for key, value in metadata.items() if value is None]
        for key in missing_keys:
            metadata[key] = ""
        if document.page_content.strip():
            yield document
|
||||
|
||||
|
||||
def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]:
    """Yield the non-empty documents stored in a parquet file.

    Falls back to ``docs_cache_file`` when ``filename`` is not given. ``None``
    metadata values are replaced by ``""``.
    """
    frame = pd.read_parquet(filename or docs_cache_file)
    records = frame.to_dict(orient="records")
    documents = [Document(**record) for record in records]
    for document in documents:
        metadata = document.metadata
        for key in [key for key, value in metadata.items() if value is None]:
            metadata[key] = ""
        if document.page_content.strip():
            yield document
|
||||
|
||||
|
||||
# default ingest function
|
||||
def ingest_docs(overwrite: bool = False):
    """Split the docs, embed them and index them into the Chroma collection.

    Args:
        overwrite: When False and the parquet cache exists, documents are read
            from the cache; otherwise they are scraped again via ``get_docs``.
    """
    if os.path.exists(docs_cache_file) and not overwrite:
        logger.info(f"Loading docs from {docs_cache_file}")
        documents = load_docs_from_parquet(docs_cache_file)
    else:
        documents = get_docs()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)
    docs_transformed = text_splitter.split_documents(documents)
    embedding = get_embeddings_model()
    vectorstore = Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=embedding,
        persist_directory=db_directory,
    )

    record_manager = SQLRecordManager(
        f"chroma/{CHROMA_COLLECTION_NAME}", db_url=RECORD_MANAGER_DB_URL
    )
    record_manager.create_schema()

    indexing_stats = index(
        docs_transformed,
        record_manager,
        vectorstore,
        cleanup="full",
        source_id_key="source",
    )

    # The stats need a %s placeholder: an extra positional argument without
    # one makes logging report a formatting error instead of the message.
    logger.info("Indexing stats: %s", indexing_stats)
|
||||
|
||||
|
||||
def download_docs(overwrite: bool = False):
    """Scrape the docs and cache them as a parquet file.

    Args:
        overwrite: When False and the cache file already exists, nothing is
            scraped or written.
    """
    if os.path.exists(docs_cache_file) and not overwrite:
        logger.info(
            f"Docs cache already exists at {docs_cache_file}; skipping download"
        )
        return
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(docs_cache_directory, exist_ok=True)
    docs = get_docs()
    # Write as parquet file
    df = pd.DataFrame.from_records([doc.dict() for doc in docs])
    df.to_parquet(docs_cache_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--action",
|
||||
choices=["ingest", "download"],
|
||||
default="download",
|
||||
)
|
||||
parser.add_argument("--overwrite", action="store_true")
|
||||
args = parser.parse_args()
|
||||
if args.action == "download":
|
||||
download_docs(args.overwrite)
|
||||
elif args.action == "ingest":
|
||||
ingest_docs(args.overwrite)
|
||||
else:
|
||||
raise ValueError(f"Unknown action {args.action}")
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_benchmarks.rag.environments.langchain_docs.architectures.chain_registry import (
|
||||
ARCH_FACTORIES,
|
||||
)
|
||||
|
||||
__all__ = ["ARCH_FACTORIES"]
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain_benchmarks.rag.environments.langchain_docs.architectures.crqa import (
|
||||
create_response_chain,
|
||||
get_default_response_synthesizer,
|
||||
)
|
||||
|
||||
|
||||
def default_response_chain(
    retriever: BaseRetriever,
    response_synthesizer: Optional[Runnable] = None,
    llm: Optional[BaseLanguageModel] = None,
) -> Runnable:
    """Get the chain responsible for generating a response based on the retrieved documents.

    Args:
        retriever: Retriever passed through to ``create_response_chain``.
        response_synthesizer: Runnable that writes the answer. When omitted,
            the default synthesizer is built around ``llm``.
        llm: Model for the default synthesizer. When omitted, a
            ``gpt-3.5-turbo-16k`` chat model with ``seed=42`` is created. It is
            ignored when ``response_synthesizer`` is given.

    Returns:
        The chain produced by ``create_response_chain``.
    """
    response_synthesizer = response_synthesizer or get_default_response_synthesizer(
        llm=llm or ChatOpenAI(model="gpt-3.5-turbo-16k", model_kwargs={"seed": 42})
    )
    return create_response_chain(
        response_synthesizer=response_synthesizer, retriever=retriever
    )
|
||||
|
||||
|
||||
ARCH_FACTORIES = {
|
||||
"conversational-retrieval-qa": default_response_chain,
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Chat langchain 'engine'."""
|
||||
# TODO: some simplified architectures that are
|
||||
# environment-agnostic
|
||||
from typing import Callable, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, HumanMessage
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import (
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
)
|
||||
from langchain.schema.runnable.passthrough import RunnableAssign
|
||||
from pydantic import BaseModel
|
||||
from operator import itemgetter
|
||||
|
||||
RESPONSE_TEMPLATE = """\
|
||||
You are an expert programmer and problem-solver, tasked with answering any question \
|
||||
about Langchain.
|
||||
|
||||
Generate a comprehensive and informative answer of 80 words or less for the \
|
||||
given question based solely on the provided search results (URL and content). You must \
|
||||
only use information from the provided search results. Use an unbiased and \
|
||||
journalistic tone. Combine search results together into a coherent answer. Do not \
|
||||
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
|
||||
relevant results that answer the question accurately. Place these citations at the end \
|
||||
of the sentence or paragraph that reference them - do not put them all at the end. If \
|
||||
different results refer to different entities within the same name, write separate \
|
||||
answers for each entity.
|
||||
|
||||
You should use bullet points in your answer for readability. Put citations where they apply
|
||||
rather than putting them all at the end.
|
||||
|
||||
If there is nothing in the context relevant to the question at hand, just say "Hmm, \
|
||||
I'm not sure." Don't try to make up an answer.
|
||||
|
||||
Anything between the following `context` html blocks is retrieved from a knowledge \
|
||||
bank, not part of the conversation with the user.
|
||||
|
||||
<context>
|
||||
{context}
|
||||
<context/>
|
||||
|
||||
REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
|
||||
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
|
||||
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
|
||||
user.\
|
||||
"""
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
question: str
|
||||
chat_history: Optional[List[Dict[str, str]]]
|
||||
|
||||
|
||||
def _format_docs(docs: Sequence[Document]) -> str:
    """Wrap each document's content in a numbered ``<doc>`` tag, one per line."""
    return "\n".join(
        f"<doc id='{position}'>{document.page_content}</doc>"
        for position, document in enumerate(docs)
    )
|
||||
|
||||
|
||||
def serialize_history(request: ChatRequest):
    """Convert the request's chat history into a flat list of chat messages.

    Args:
        request: Either a ``ChatRequest`` instance or a mapping with an
            optional ``"chat_history"`` entry. Each history item is a mapping
            that may carry a ``"human"`` and/or an ``"ai"`` string.

    Returns:
        ``HumanMessage``/``AIMessage`` objects in history order; within one
        item the human message comes first. Empty when there is no history.
    """
    # A pydantic model has no ``.get``; read the attribute in that case so the
    # function honours its own annotation as well as plain-dict inputs.
    if isinstance(request, ChatRequest):
        chat_history = request.chat_history or []
    else:
        chat_history = request.get("chat_history") or []
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    return converted_chat_history
|
||||
|
||||
|
||||
def get_default_response_synthesizer(llm: BaseLanguageModel) -> Runnable:
    """Build the prompt -> llm -> string-parser chain named "GenerateResponse"."""
    messages = [
        ("system", RESPONSE_TEMPLATE),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    synthesizer = prompt | llm | StrOutputParser()
    return synthesizer.with_config(run_name="GenerateResponse")
|
||||
|
||||
|
||||
def create_response_chain(
    response_synthesizer: Runnable,
    retriever: BaseRetriever,
    format_docs: Optional[Callable[[Sequence[Document]], str]] = None,
    format_chat_history: Optional[Callable[[ChatRequest], str]] = None,
) -> Runnable:
    """Assemble the history -> retrieval -> synthesis pipeline.

    The input first gains a ``chat_history`` entry, then a ``context`` entry
    built by retrieving on the ``question`` value and formatting the results,
    and is finally passed to ``response_synthesizer``. The formatters default
    to ``_format_docs`` and ``serialize_history``.
    """
    if not format_docs:
        format_docs = _format_docs
    if not format_chat_history:
        format_chat_history = serialize_history

    history_step = RunnableAssign(
        {
            "chat_history": RunnableLambda(format_chat_history).with_config(
                run_name="SerializeHistory"
            )
        }
    )
    context_step = RunnableAssign(
        {
            "context": (
                itemgetter("question") | retriever | format_docs
            ).with_config(run_name="FormatDocs")
        }
    )
    return history_step | context_step | response_synthesizer
|
||||
+4
@@ -0,0 +1,4 @@
|
||||
db/
|
||||
docs_db/
|
||||
.sql
|
||||
.bin
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
from langchain_benchmarks.rag.environments.langchain_docs.langchain_docs_retriever.retriever import (
|
||||
get_parent_document_retriever,
|
||||
get_vectorstore_retriever,
|
||||
get_hyde_retriever,
|
||||
create_index,
|
||||
)
|
||||
from langchain_benchmarks.rag.environments.langchain_docs.langchain_docs_retriever.retriever_registry import (
|
||||
RETRIEVER_FACTORIES,
|
||||
)
|
||||
|
||||
# Only names actually imported above are exported; listing a name the module
# does not define makes ``from <package> import *`` raise AttributeError.
__all__ = [
    "create_index",
    "get_hyde_retriever",
    "get_parent_document_retriever",
    "get_vectorstore_retriever",
    "RETRIEVER_FACTORIES",
]
|
||||
BIN
Binary file not shown.
+47
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from typing import Iterable
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
|
||||
chroma_remote_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/chroma_db.zip"
|
||||
raw_docs_file = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/docs.parquet"
|
||||
directory = os.path.dirname(os.path.realpath(__file__))
|
||||
db_directory = os.path.join(directory, "db")
|
||||
DOCS_FILE = os.path.join(directory, "db_docs/docs.parquet")
|
||||
|
||||
|
||||
def is_folder_populated(folder):
    """Return True when ``folder`` is an existing directory with at least one entry.

    A missing path, or a path that is not a directory, counts as not populated.
    """
    if not os.path.isdir(folder):
        return False
    # The context manager closes the scandir iterator; ``any`` stops at the
    # first entry, which would otherwise leave the directory handle open.
    with os.scandir(folder) as entries:
        return any(entries)
|
||||
|
||||
|
||||
def download_chroma_folder_from_gcs():
    """Download the Chroma archive and extract it into ``directory``.

    The archive is written to ``chroma_db.zip`` in the current working
    directory and removed afterwards, even when extraction fails.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    archive_path = "chroma_db.zip"
    r = requests.get(chroma_remote_url, allow_redirects=True)
    # Fail here rather than trying to unzip an HTTP error page.
    r.raise_for_status()
    try:
        with open(archive_path, "wb") as archive:
            archive.write(r.content)

        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(directory)
    finally:
        if os.path.exists(archive_path):
            os.remove(archive_path)
|
||||
|
||||
|
||||
def fetch_langchain_docs_db():
    """Download the Chroma database unless ``db_directory`` already has content."""
    if is_folder_populated(db_directory):
        return
    print(f"Folder {db_directory} is not populated. Downloading from GCS...")
    download_chroma_folder_from_gcs()
|
||||
|
||||
|
||||
def fetch_remote_parquet_file():
    """Download the raw docs parquet file to ``DOCS_FILE`` if it is missing.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    if os.path.exists(DOCS_FILE):
        return
    print(f"File {DOCS_FILE} does not exist. Downloading from GCS...")
    r = requests.get(raw_docs_file, allow_redirects=True)
    # Do not cache an HTTP error page as if it were the parquet file.
    r.raise_for_status()
    os.makedirs(os.path.dirname(DOCS_FILE), exist_ok=True)
    # Write to a sibling file and rename, so an interrupted write never leaves
    # a truncated DOCS_FILE that later calls would treat as already present.
    partial_path = DOCS_FILE + ".part"
    with open(partial_path, "wb") as target:
        target.write(r.content)
    os.replace(partial_path, DOCS_FILE)
    print(f"File {DOCS_FILE} downloaded.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fetch_remote_parquet_file()
|
||||
+205
@@ -0,0 +1,205 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import pandas as pd
|
||||
from langchain.indexes import SQLRecordManager, index
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.schema.storage import BaseStore
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.storage import InMemoryStore
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_benchmarks.rag.utils.indexing import (
|
||||
transform_docs_hyde,
|
||||
transform_docs_parent_child,
|
||||
)
|
||||
|
||||
from .download_db import DOCS_FILE, fetch_remote_parquet_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
|
||||
COLLECTION_NAME = "langchain-docs"
|
||||
RECORD_MANAGER_DB_URL = (
|
||||
os.environ.get("RECORD_MANAGER_DB_URL")
|
||||
or f"sqlite:///{_DIRECTORY}_record_manager.sql"
|
||||
)
|
||||
|
||||
_DEFAULT_SEARCH_KWARGS = {"k": 6}
|
||||
|
||||
|
||||
def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]:
    """Yield the non-empty documents stored in a parquet file.

    Args:
        filename: Path of the parquet file. Defaults to ``DOCS_FILE`` so the
            declared ``None`` default is usable.

    Yields:
        Documents whose ``page_content`` is not blank, with ``None`` metadata
        values replaced by ``""``.
    """
    df = pd.read_parquet(filename or DOCS_FILE)
    docs_transformed = [Document(**row) for row in df.to_dict(orient="records")]
    for doc in docs_transformed:
        for k, v in doc.metadata.items():
            if v is None:
                doc.metadata[k] = ""
        if not doc.page_content.strip():
            continue
        yield doc
|
||||
|
||||
|
||||
def create_index(
    embedding: Embeddings,
    vectorstore: VectorStore,
    *,
    transform_docs: Optional[Callable] = None,
    transformation_name: Optional[str] = None,
):
    """Index the cached docs into ``vectorstore`` and return the indexing result.

    The record-manager namespace is derived from the vector store class name,
    the embedding class name and the transformation name (``"raw"`` when no
    transform is given and no name was supplied).

    Raises:
        ValueError: If ``transform_docs`` is given without ``transformation_name``.
    """
    fetch_remote_parquet_file()
    raw_docs = load_docs_from_parquet(DOCS_FILE)

    if not transform_docs:
        docs_to_index = raw_docs
        transformation_name = transformation_name or "raw"
    elif transformation_name:
        docs_to_index = transform_docs(raw_docs)
    else:
        raise ValueError(
            "If you provide a transform function, you must also provide a "
            "transformation name to use for the record manager."
        )

    store_name = vectorstore.__class__.__name__
    embedder_name = embedding.__class__.__name__
    namespace = (
        f"{store_name}/{COLLECTION_NAME}_{store_name}_{embedder_name}_{transformation_name}"
    )
    record_manager = SQLRecordManager(namespace, db_url=RECORD_MANAGER_DB_URL)
    record_manager.create_schema()

    return index(
        docs_to_index,
        record_manager,
        vectorstore,
        cleanup="full",
        source_id_key="source",
    )
|
||||
|
||||
|
||||
def get_vectorstore_retriever(
    embedding: Embeddings,
    vectorstore: VectorStore,
    *,
    transform_docs: Optional[Callable] = None,
    transformation_name: Optional[str] = None,
    search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
    """Index the documents (with caching) and return a vector store retriever.

    Args:
        embedding: Embeddings object forwarded to ``create_index``.
        vectorstore: Vector store to index into and to build the retriever from.
        transform_docs: Optional document transformation forwarded to
            ``create_index``.
        transformation_name: Name of that transformation, forwarded to
            ``create_index``.
        search_kwargs: Keyword arguments for ``vectorstore.as_retriever``.
            ``_DEFAULT_SEARCH_KWARGS`` is used when this is ``None`` or empty.
            The mapping is never modified.

    Returns:
        The retriever produced by ``vectorstore.as_retriever``.
    """
    index_stats = create_index(
        embedding,
        vectorstore,
        transform_docs=transform_docs,
        transformation_name=transformation_name,
    )
    logger.info(f"Index stats: {index_stats}")

    # Work on copies. Mutating the argument in place used to write the
    # "metadata"/"tags" entries into the caller's dict or into the shared
    # module-level default, and appended one more tag on every call.
    kwargs = dict(search_kwargs or _DEFAULT_SEARCH_KWARGS)
    metadata = dict(kwargs.get("metadata") or {})
    metadata.setdefault("benchmark_environment", "langchain-docs")
    kwargs["metadata"] = metadata
    kwargs["tags"] = [*(kwargs.get("tags") or []), "langchain-benchmarks"]

    # NOTE(review): every key (including "k") is forwarded as a top-level
    # keyword rather than nested under ``search_kwargs`` -- confirm that the
    # vector store's ``as_retriever`` honours it that way.
    return vectorstore.as_retriever(**kwargs)
|
||||
|
||||
|
||||
def get_parent_document_retriever(
    embedding: Embeddings,
    vectorstore: VectorStore,
    *,
    child_splitter: Optional[TextSplitter] = None,
    transformation_name: Optional[str] = None,
    id_key: str = "source",
    docstore: Optional[BaseStore] = None,
    parent_splitter: Optional[TextSplitter] = None,
    search_kwargs: Optional[dict] = None,
):
    """Index parent/child chunks (with caching) and build a ParentDocumentRetriever.

    Args:
        embedding: Embeddings object forwarded to ``create_index``.
        vectorstore: Vector store that receives the indexed documents.
        child_splitter: Splitter for child chunks. When omitted, a
            ``RecursiveCharacterTextSplitter`` (chunk size 4000, overlap 200)
            is used and ``transformation_name`` is replaced by a fixed label.
        transformation_name: Record-manager label; mandatory together with a
            custom ``child_splitter``.
        id_key: Metadata key that links children to their parent.
        docstore: Store for the parent documents; an ``InMemoryStore`` is
            created when none (or a falsy one) is given.
        parent_splitter: Optional splitter applied to the parents first.
        search_kwargs: Search kwargs for the retriever; defaults to
            ``_DEFAULT_SEARCH_KWARGS``.

    Raises:
        ValueError: If a custom ``child_splitter`` comes without a
            ``transformation_name``.
    """
    if not docstore:
        docstore = InMemoryStore()

    if child_splitter is not None:
        if transformation_name is None:
            raise ValueError(
                "If you provide a custom child splitter, you must also provide a "
                "transformation name to use for the record manager."
            )
    else:
        child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=4000, chunk_overlap=200
        )
        transformation_name = "parent-document-recursive-cs4k_ol200"
        logger.info(f"Using default child splitter:\n{child_splitter}")

    # Bind everything except the documents so create_index can call it.
    split_into_children = partial(
        transform_docs_parent_child,
        child_splitter=child_splitter,
        docstore=docstore,
        parent_splitter=parent_splitter,
        id_key=id_key,
    )
    index_stats = create_index(
        embedding,
        vectorstore,
        transform_docs=split_into_children,
        transformation_name=transformation_name,
    )
    logger.info(f"Index stats: {index_stats}")

    retriever_options = {
        "tags": ["langchain-benchmarks"],
        "metadata": {"benchmark_environment": "langchain-docs"},
        "vectorstore": vectorstore,
        "docstore": docstore,
        "child_splitter": child_splitter,
        "parent_splitter": parent_splitter,
        "search_kwargs": search_kwargs or _DEFAULT_SEARCH_KWARGS,
        "id_key": id_key,
    }
    return ParentDocumentRetriever(**retriever_options)
|
||||
|
||||
|
||||
def get_hyde_retriever(
    embedding: Embeddings,
    vectorstore: VectorStore,
    *,
    docstore: Optional[BaseStore] = None,
    query_generator: Optional[Runnable] = None,
    id_key: str = "source",
    search_kwargs: Optional[dict] = None,
    transformation_name: Optional[str] = None,
):
    """Index HyDE-expanded documents (with caching) and build a MultiVectorRetriever.

    Args:
        embedding: Embeddings object forwarded to ``create_index``; its class
            name is also recorded in the retriever metadata.
        vectorstore: Vector store that receives the indexed documents.
        docstore: Store for the source documents; an ``InMemoryStore`` is
            created when none (or a falsy one) is given.
        query_generator: Optional runnable handed to ``transform_docs_hyde``.
        id_key: Metadata key used to link entries to their source document.
        search_kwargs: Search kwargs for the retriever; defaults to
            ``_DEFAULT_SEARCH_KWARGS``.
        transformation_name: Record-manager label; mandatory together with a
            custom ``query_generator``, otherwise ``"HyDE"``.

    Raises:
        ValueError: If a custom ``query_generator`` comes without a
            ``transformation_name``.
    """
    if not docstore:
        docstore = InMemoryStore()

    custom_generator = query_generator is not None
    if custom_generator and transformation_name is None:
        raise ValueError(
            "If you provide a custom query generator, you must also provide a "
            "transformation name to use for the record manager."
        )
    if not transformation_name:
        transformation_name = "HyDE"

    # Bind everything except the documents so create_index can call it.
    expand_with_queries = partial(
        transform_docs_hyde,
        docstore=docstore,
        query_generator=query_generator,
        id_key=id_key,
    )
    index_stats = create_index(
        embedding,
        vectorstore,
        transform_docs=expand_with_queries,
        transformation_name=transformation_name,
    )
    logger.info(f"Index stats: {index_stats}")

    # NOTE(review): "retriever_stragegy" is misspelled, but it is a runtime
    # metadata key, so it is kept exactly as it is.
    metadata = dict(
        benchmark_environment="langchain-docs",
        retriever_stragegy="HyDE",
        embedding=embedding.__class__.__name__,
        vectorstore=vectorstore.__class__.__name__,
    )

    return MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        id_key=id_key,
        search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
        metadata=metadata,
        tags=["langchain-benchmarks"],
    )
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.vectorstores.chroma import Chroma
|
||||
from .retriever import (
|
||||
get_vectorstore_retriever,
|
||||
get_parent_document_retriever,
|
||||
get_hyde_retriever,
|
||||
)
|
||||
|
||||
|
||||
def _chroma_retriever_factory(
    embedding: Embeddings,
    search_kwargs: Optional[dict] = None,
    transform_docs: Optional[Callable] = None,
    transformation_name: Optional[str] = None,
) -> BaseRetriever:
    """Build the plain vector store retriever on a persisted Chroma collection.

    The collection name embeds the embedding class name and the data is
    persisted under ``./chromadb``. All other arguments are forwarded to
    ``get_vectorstore_retriever``.
    """
    collection = f"langchain-benchmarks-classic-{embedding.__class__.__name__}"
    store = Chroma(
        collection_name=collection,
        embedding_function=embedding,
        persist_directory="./chromadb",
    )
    forwarded = {
        "transform_docs": transform_docs,
        "transformation_name": transformation_name,
        "search_kwargs": search_kwargs,
    }
    return get_vectorstore_retriever(embedding, store, **forwarded)
|
||||
|
||||
|
||||
def _chroma_parent_document_retriever_factory(
    embedding: Embeddings,
    search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
    """Build the parent-document retriever on a persisted Chroma collection.

    The collection name embeds the embedding class name and the data is
    persisted under ``./chromadb``. ``search_kwargs`` is forwarded to
    ``get_parent_document_retriever``.
    """
    collection = f"langchain-benchmarks-parent-doc-{embedding.__class__.__name__}"
    store = Chroma(
        collection_name=collection,
        embedding_function=embedding,
        persist_directory="./chromadb",
    )
    retriever = get_parent_document_retriever(
        embedding, store, search_kwargs=search_kwargs
    )
    return retriever
|
||||
|
||||
|
||||
def _chroma_hyde_retriever_factory(
    embedding: Embeddings,
    search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
    """Build the HyDE retriever on a persisted Chroma collection.

    The collection name embeds the embedding class name and the data is
    persisted under ``./chromadb``. ``search_kwargs`` is forwarded to
    ``get_hyde_retriever``.
    """
    collection = f"langchain-benchmarks-hyde-{embedding.__class__.__name__}"
    store = Chroma(
        collection_name=collection,
        embedding_function=embedding,
        persist_directory="./chromadb",
    )
    retriever = get_hyde_retriever(embedding, store, search_kwargs=search_kwargs)
    return retriever
|
||||
|
||||
|
||||
# Maps a retrieval strategy name to the Chroma-backed factory that builds it.
RETRIEVER_FACTORIES = {
    "basic": _chroma_retriever_factory,
    "parent-doc": _chroma_parent_document_retriever_factory,
    "hyde": _chroma_hyde_retriever_factory,
}
|
||||
@@ -0,0 +1,94 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.evaluation import load_evaluator
|
||||
from langchain.llms.base import BaseLanguageModel
|
||||
from langchain.smith import RunEvalConfig
|
||||
from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
|
||||
from langsmith.schemas import Example, Run
|
||||
|
||||
|
||||
# TODO: Split this into an assertion-by-assertion evaluator
|
||||
# TODO: Combine with a document relevance evaluator (to report retriever performance)
|
||||
class FaithfulnessEvaluator(RunEvaluator):
    """Run evaluator that grades an answer against the retrieved documents."""

    def __init__(self, llm: Optional[BaseLanguageModel] = None):
        """Load a ``labeled_score_string`` evaluator with the faithfulness rubric.

        Args:
            llm: Judge model forwarded to ``load_evaluator``.
        """
        rubric = {
            "faithfulness": """
Score 1: The answer directly contradicts the information provided in the reference docs.
Score 3: The answer contains a mix of correct information from the reference docs and incorrect or unverifiable information not found in the docs.
Score 5: The answer is mostly aligned with the reference docs but includes extra information that, while not contradictory, is not verified by the docs.
Score 7: The answer aligns well with the reference docs but includes minor, commonly accepted facts not found in the docs.
Score 10: The answer perfectly aligns with and is fully entailed by the reference docs, with no extra information."""
        }
        self.evaluator = load_evaluator(
            "labeled_score_string",
            criteria=rubric,
            llm=llm,
            normalize_by=10,
        )

    @staticmethod
    def _get_retrieved_docs(run: Run) -> str:
        """Return ``str()`` of the documents of the first retriever run found.

        The run tree is walked depth-first, visiting a run before its
        children and the children in their stored order. An empty string is
        returned when no run has ``run_type == "retriever"``.
        """
        # This assumes there is only one retriever in your chain.
        # To select more precisely, name your retrieval chain
        # using with_config(name="my_unique_name") and look up
        # by run.name
        pending = [run]
        while pending:
            current = pending.pop()
            if current.run_type == "retriever":
                return str(current.outputs["documents"])
            children = current.child_runs
            if children:
                # Reversed so that pop() hands back the first child first.
                pending.extend(reversed(children))
        return ""

    def evaluate_run(
        self, run: Run, example: Optional[Example] = None
    ) -> EvaluationResult:
        """Grade the run's prediction against its retrieved documents.

        The prediction is the only output value when the run has exactly one
        output, otherwise ``run.outputs["output"]``. Any exception raised on
        the way is reported as a result with ``score=None`` and the
        exception's ``repr`` as the comment.
        """
        try:
            docs_string = self._get_retrieved_docs(run)
            reference = f"Reference docs:\n<DOCS>\n{docs_string}\n</DOCS>\n\n"
            question = run.inputs["question"]
            outputs = run.outputs
            if outputs is not None and len(outputs) == 1:
                prediction = next(iter(outputs.values()))
            else:
                prediction = outputs["output"]
            graded = self.evaluator.evaluate_strings(
                input=question,
                prediction=prediction,
                reference=reference,
            )
            # Keys returned by the evaluator take precedence over the defaults.
            payload = {"key": "faithfulness", "comment": graded.get("reasoning")}
            payload.update(graded)
            return EvaluationResult(**payload)
        except Exception as e:
            return EvaluationResult(key="faithfulness", score=None, comment=repr(e))
|
||||
|
||||
|
||||
# Scoring rubric for the "accuracy" criterion used in RAG_EVALUATION below.
_ACCURACY_CRITERION = {
    "accuracy": """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer shows slight relevance to the question or reference document but is largely incorrect.
Score 5: The answer is partially correct but has significant errors or omissions.
Score 7: The answer is mostly correct with minor errors or omissions, and aligns with the reference document.
Score 10: The answer is correct, complete, and perfectly aligns with the reference document.

If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct.
"""  # noqa
}

# Judge model for the accuracy criterion (temperature 0 and a fixed seed).
eval_llm = ChatOpenAI(model="gpt-4", temperature=0.0, model_kwargs={"seed": 42})
# Use a longer-context LLM to check documents
faithfulness_eval_llm = ChatOpenAI(
    model="gpt-4-1106-preview", temperature=0.0, model_kwargs={"seed": 42}
)

# Evaluation config: accuracy score string and embedding distance as built-in
# evaluators, plus the custom faithfulness evaluator.
RAG_EVALUATION = RunEvalConfig(
    evaluators=[
        RunEvalConfig.LabeledScoreString(
            criteria=_ACCURACY_CRITERION, llm=eval_llm, normalize_by=10.0
        ),
        RunEvalConfig.EmbeddingDistance(),
    ],
    custom_evaluators=[FaithfulnessEvaluator(llm=faithfulness_eval_llm)],
)
|
||||
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
"""Registry of RAG environments for ease of access."""
|
||||
import dataclasses
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain_benchmarks.rag.environments import langchain_docs
|
||||
from langchain_benchmarks.rag.environments.langchain_docs import (
|
||||
architectures,
|
||||
langchain_docs_retriever,
|
||||
)
|
||||
from langchain_benchmarks.utils._registration import Environment, Registry
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RetrievalEnvironment(Environment):
    """Environment that additionally carries retriever and architecture factories."""

    retriever_factories: Dict[str, Callable[[Embeddings], BaseRetriever]]  # noqa: F821
    """Factories that index the docs using the specified strategy."""
    architecture_factories: Dict[
        str, Callable[[Embeddings], BaseRetriever]
    ]  # noqa: F821
    """Factory methods that help build some off-the-shelf architectures."""

    @property
    def _table(self) -> List[List[str]]:
        """Rows of the parent table followed by one row per factory mapping."""
        rows = list(super()._table)
        rows.append(["Retriever Factories", ", ".join(self.retriever_factories)])
        rows.append(
            ["Architecture Factories", ", ".join(self.architecture_factories)]
        )
        return rows
|
||||
|
||||
# Using lower case naming to make a bit prettier API when used in a notebook
# The registry currently holds a single retrieval environment (id 0).
registry = Registry(
    environments=[
        RetrievalEnvironment(
            id=0,
            name="LangChain Docs Q&A",
            dataset_id=langchain_docs.DATASET_ID,
            retriever_factories=langchain_docs_retriever.RETRIEVER_FACTORIES,
            architecture_factories=architectures.ARCH_FACTORIES,
            description=(
                """\
Questions and answers based on a snapshot of the LangChain python docs.

The environment provides the documents and the retriever information.

Each example is composed of a question and reference answer.

Success is measured based on the accuracy of the answer relative to the reference answer.
We also measure the faithfulness of the model's response relative to the retrieved documents (if any).
"""  # noqa: E501
            ),
        )
    ]
)
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.storage import BaseStore
|
||||
from langchain.text_splitter import TextSplitter
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def transform_docs_parent_child(
    documents: Iterable[Document],
    child_splitter: TextSplitter,
    docstore: BaseStore,
    id_key: str,
    *,
    parent_splitter: Optional[TextSplitter] = None,
) -> Iterable[Document]:
    """Yield every parent document followed by its child chunks.

    Args:
        documents: Source documents.
        child_splitter: Splitter that produces the child chunks of a parent.
        docstore: Store that receives ``(id, parent)`` pairs once the input
            has been fully consumed.
        id_key: Metadata key holding the parent's id; it is copied onto
            every child.
        parent_splitter: When given, the documents are split with it first
            and the resulting chunks act as the parents.

    Yields:
        Each parent, immediately followed by its children.
    """
    if parent_splitter is None:
        parents = documents
    else:
        parents = parent_splitter.split_documents(documents)

    stored_pairs = []
    for parent in parents:
        yield parent
        parent_id = parent.metadata[id_key]
        stored_pairs.append((parent_id, parent))
        for child in child_splitter.split_documents([parent]):
            # Link the child back to the parent it was cut from.
            child.metadata[id_key] = parent_id
            yield child
    docstore.mset(stored_pairs)
|
||||
|
||||
|
||||
def _default_hyde_embedder():
    """Build the default runnable that writes user queries for a document.

    The chain takes an object with a ``page_content`` attribute, prompts
    ``gpt-4-1106-preview`` through a forced function call and parses the
    ``questions`` key of the function arguments. The whole chain is retried
    up to three attempts and tagged for the benchmark.
    """

    class hypotheticalQuestions(BaseModel):
        """Write user queries that could be answered by the document."""

        questions: List[str]

    # The prompt itself does not fix how many queries are produced.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are an AI creating an inverted index for document retrieval. "
                "Analyze the content of the following document and generate relevant user queries "
                "that would likely retrieve this document. Document content:\n\n```document.txt\n{doc}\n```",
            ),
            (
                "user",
                "Based on the document's content, what specific technical queries or questions"
                " are users likely to search that this document can answer?",
            ),
        ]
    )
    # Client-side retries are switched off; with_retry below handles retries.
    llm = ChatOpenAI(max_retries=0, model="gpt-4-1106-preview").bind_functions(
        functions=[hypotheticalQuestions],
        function_call="hypotheticalQuestions",
    )
    parser = JsonKeyOutputFunctionsParser(key_name="questions")

    chain = {"doc": lambda x: x.page_content} | prompt | llm | parser
    retrying_chain = chain.with_retry(stop_after_attempt=3)
    return retrying_chain.with_config(
        run_name="HyDE",
        metadata={"benchmark_environment": "langchain-docs"},
        tags=["langchain-benchmarks"],
    )
|
||||
|
||||
|
||||
def transform_docs_hyde(
    documents: Iterable[Document],
    docstore: BaseStore,
    id_key: str,
    *,
    query_generator: Optional[Runnable] = None,
    runnable_config: Optional[RunnableConfig] = None,
) -> Iterable[Document]:
    """Yield each document followed by generated query documents for it.

    Args:
        documents: Source documents; any iterable, including a one-shot
            generator.
        docstore: Store that receives ``(id, document)`` pairs for every
            document whose queries were generated successfully, once all
            documents have been processed.
        id_key: Metadata key holding the document id.
        query_generator: Runnable whose ``batch`` returns a list of query
            strings per document. Defaults to ``_default_hyde_embedder()``.
        runnable_config: Config passed to ``batch``; defaults to
            ``{"max_concurrency": 5}``.

    Yields:
        Every source document, each followed by one ``Document`` per
        generated query (sharing the source document's metadata).
    """
    if query_generator is None:
        query_generator = _default_hyde_embedder()
        logger.info(f"Using default query generator\n{query_generator}")

    # Materialise once and reuse the list below. Zipping the original
    # iterable again yielded nothing whenever it was a generator, because
    # list() had already exhausted it.
    docs = list(documents)
    questions = query_generator.batch(
        docs, runnable_config or {"max_concurrency": 5}, return_exceptions=True
    )

    doc_ids = []
    for doc, expansions in zip(docs, questions):
        yield doc
        if isinstance(expansions, BaseException):
            # Best effort: keep going, but log what actually failed.
            logger.error(
                f"Error generating questions for document {doc.metadata[id_key]}",
                exc_info=expansions,
            )
            continue
        for question in expansions:
            yield Document(page_content=question, metadata=doc.metadata)
        doc_ids.append((doc.metadata[id_key], doc))
    docstore.mset(doc_ids)
|
||||
@@ -1,95 +1,22 @@
|
||||
"""Registry of environments for ease of access."""
|
||||
"""Registry of tool use environments for ease of access."""
|
||||
import dataclasses
|
||||
from typing import Callable, List, Sequence, Union
|
||||
from langchain_benchmarks.tool_usage.environments import alpha
|
||||
from langchain_benchmarks.utils._registration import Environment, Registry
|
||||
from typing import Callable, List
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from tabulate import tabulate
|
||||
|
||||
from langchain_benchmarks.tool_usage.environments import alpha
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Environment:
|
||||
id: int
|
||||
"""The ID of the environment."""
|
||||
name: str
|
||||
"""The name of the environment."""
|
||||
|
||||
dataset_id: str
|
||||
"""The ID of the langsmith public dataset.
|
||||
|
||||
This dataset contains expected inputs/outputs for the environment, and
|
||||
can be used to evaluate the performance of a model/agent etc.
|
||||
"""
|
||||
|
||||
class ToolEnvironment(Environment):
|
||||
tools_factory: Callable[[], List[BaseTool]]
|
||||
"""Factory that returns a list of tools that can be used in the environment."""
|
||||
|
||||
description: str
|
||||
"""Description of the environment."""
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return a HTML representation of the environment."""
|
||||
table = [
|
||||
["ID", self.id],
|
||||
["Name", self.name],
|
||||
["Dataset ID", self.dataset_id],
|
||||
["Description", self.description[:100] + "..."],
|
||||
]
|
||||
return tabulate(
|
||||
table,
|
||||
tablefmt="html",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Registry:
|
||||
environments: Sequence[Environment]
|
||||
|
||||
def get_environment(self, name_or_id: Union[int, str]) -> Environment:
|
||||
"""Get the environment with the given name."""
|
||||
for env in self.environments:
|
||||
if env.name == name_or_id or env.id == name_or_id:
|
||||
return env
|
||||
raise ValueError(
|
||||
f"Unknown environment {name_or_id}. Use list_environments() to see "
|
||||
f"available environments."
|
||||
)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return a HTML representation of the registry."""
|
||||
headers = [
|
||||
"ID",
|
||||
"Name",
|
||||
"Dataset ID",
|
||||
"Description",
|
||||
]
|
||||
table = [
|
||||
[
|
||||
env.id,
|
||||
env.name,
|
||||
env.dataset_id,
|
||||
env.description,
|
||||
]
|
||||
for env in self.environments
|
||||
]
|
||||
return tabulate(table, headers=headers, tablefmt="html")
|
||||
|
||||
def __getitem__(self, key: Union[int, str]) -> Environment:
|
||||
"""Get an environment from the registry."""
|
||||
if isinstance(key, slice):
|
||||
raise NotImplementedError("Slicing is not supported.")
|
||||
elif isinstance(key, (int, str)):
|
||||
# If key is an integer, return the corresponding environment
|
||||
return self.get_environment(key)
|
||||
else:
|
||||
raise TypeError("Key must be an integer or a slice.")
|
||||
|
||||
|
||||
# Using lower case naming to make a bit prettier API when used in a notebook
|
||||
registry = Registry(
|
||||
environments=[
|
||||
Environment(
|
||||
ToolEnvironment(
|
||||
id=0,
|
||||
name="Tool Usage - Alpha",
|
||||
dataset_id=alpha.DATASET_ID,
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import dataclasses
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Environment:
    """Immutable description of a benchmark environment."""

    id: int
    """The ID of the environment."""
    name: str
    """The name of the environment."""

    dataset_id: str
    """The ID of the langsmith public dataset.

    This dataset contains expected inputs/outputs for the environment, and
    can be used to evaluate the performance of a model/agent etc.
    """

    description: str
    """Description of the environment."""

    @property
    def _table(self) -> List[List[str]]:
        """Return the label/value rows shown in the HTML representation.

        The description is cut to its first 100 characters; the ellipsis is
        appended only when something was actually cut off (it used to be
        appended to short descriptions as well).
        """
        description = self.description
        if len(description) > 100:
            description = description[:100] + "..."
        return [
            ["ID", self.id],
            ["Name", self.name],
            ["Dataset ID", self.dataset_id],
            ["Description", description],
        ]

    def _repr_html_(self) -> str:
        """Return a HTML representation of the environment."""
        return tabulate(
            self._table,
            tablefmt="html",
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Registry:
    """Immutable collection of environments, looked up by ID or by name."""

    environments: Sequence[Environment]

    def get_environment(self, name_or_id: Union[int, str]) -> Environment:
        """Return the first environment whose name or ID equals the argument.

        Raises:
            ValueError: If no environment matches.
        """
        for env in self.environments:
            if env.name == name_or_id or env.id == name_or_id:
                return env
        # NOTE(review): the message mentions list_environments(), which is
        # not defined on this class -- confirm that it exists elsewhere.
        raise ValueError(
            f"Unknown environment {name_or_id}. Use list_environments() to see "
            f"available environments."
        )

    def _repr_html_(self) -> str:
        """Return a HTML representation of the registry."""
        headers = [
            "ID",
            "Name",
            "Dataset ID",
            "Description",
        ]
        table = [
            [
                env.id,
                env.name,
                env.dataset_id,
                env.description,
            ]
            for env in self.environments
        ]
        return tabulate(table, headers=headers, tablefmt="html")

    def __getitem__(self, key: Union[int, str]) -> Environment:
        """Get an environment from the registry by integer ID or string name.

        Raises:
            NotImplementedError: If ``key`` is a slice.
            TypeError: If ``key`` is neither an ``int`` nor a ``str``.
            ValueError: If no environment matches ``key``.
        """
        if isinstance(key, slice):
            raise NotImplementedError("Slicing is not supported.")
        if isinstance(key, (int, str)):
            return self.get_environment(key)
        # The old message asked for "an integer or a slice", although slices
        # are rejected above and string names are accepted.
        raise TypeError("Key must be an integer ID or a string name.")
|
||||
Reference in New Issue
Block a user