Compare commits

...

6 Commits

Author SHA1 Message Date
William Fu-Hinthorn 538533c2c5 hm 2023-11-20 15:06:04 -08:00
William Fu-Hinthorn 6845999d35 update 2023-11-20 08:18:23 -08:00
William Fu-Hinthorn 3ff94b90f9 add assistant example 2023-11-20 07:31:58 -08:00
William Fu-Hinthorn 9d5e6695bc nb 2023-11-20 05:55:19 -08:00
William Fu-Hinthorn 6b5a955078 Add some swappability and notebook 2023-11-18 15:33:26 -08:00
William Fu-Hinthorn 32d2a50ea9 Easier to clone 2023-11-17 15:49:11 -08:00
30 changed files with 2929 additions and 122 deletions
+1 -1
View File
@@ -159,4 +159,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
+1
View File
@@ -0,0 +1 @@
chromadb/
File diff suppressed because it is too large Load Diff
@@ -1,3 +1,3 @@
from openai_functions_agent.agent import agent_executor
from openai_functions_agent.agent import agent_executor, create_executor
__all__ = ["agent_executor"]
__all__ = ["agent_executor", "create_executor"]
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
@@ -25,7 +25,6 @@ def search(query, callbacks=None):
tools = [search]
llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
assistant_system_message = """You are a helpful assistant tasked with answering technical questions about LangChain. \
Use tools (only if necessary) to best answer the users questions. Do not make up information if you cannot find the answer using your tools."""
prompt = ChatPromptTemplate.from_messages(
@@ -37,8 +36,6 @@ prompt = ChatPromptTemplate.from_messages(
]
)
llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
def _format_chat_history(chat_history: List[Tuple[str, str]]):
buffer = []
@@ -48,30 +45,11 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
return buffer
agent = (
{
"input": lambda x: x["input"],
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
"agent_scratchpad": lambda x: format_to_openai_functions(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)
class AgentInput(BaseModel):
input: str
chat_history: List[Tuple[str, str]] = Field(..., extra={"widget": {"type": "chat"}})
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False).with_types(
input_type=AgentInput
)
class ChainInput(BaseModel):
question: str
@@ -80,6 +58,31 @@ def mapper(input: dict):
return {"input": input["question"], "chat_history": []}
agent_executor = (mapper | agent_executor | (lambda x: x["output"])).with_types(
input_type=ChainInput
)
def create_executor(model_config: Optional[dict] = None):
model = (model_config or {}).get("model", "gpt-3.5-turbo-16k")
llm = ChatOpenAI(model=model, temperature=0)
llm_with_tools = llm.bind(
functions=[format_tool_to_openai_function(t) for t in tools]
)
agent = (
{
"input": lambda x: x["input"],
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
"agent_scratchpad": lambda x: format_to_openai_functions(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False).with_types(
input_type=AgentInput
)
return (mapper | agent_executor | (lambda x: x["output"])).with_types(
input_type=ChainInput
)
agent_executor = create_executor()
+37 -6
View File
@@ -1,19 +1,30 @@
"""Copy the public dataset to your own langsmith tenant."""
from typing import Optional
from langsmith import Client
from tqdm import tqdm
DATASET_NAME = "LangChain Docs Q&A"
PUBLIC_DATASET_TOKEN = "452ccafc-18e1-4314-885b-edd735f17b9d"
client = Client()
def create_langchain_docs_dataset(
dataset_name: str = DATASET_NAME, public_dataset_token: str = PUBLIC_DATASET_TOKEN
dataset_name: str = DATASET_NAME,
public_dataset_token: str = PUBLIC_DATASET_TOKEN,
client: Optional[Client] = None,
):
shared_client = Client(
api_url="https://api.smith.langchain.com", api_key="placeholder"
)
examples = list(shared_client.list_shared_examples(public_dataset_token))
client = client or Client()
if client.has_dataset(dataset_name=dataset_name):
return
ds = client.create_dataset(dataset_name=dataset_name)
examples = tqdm(list(client.list_shared_examples(public_dataset_token)))
loaded_examples = list(client.list_examples(dataset_name=dataset_name))
if len(loaded_examples) == len(examples):
return
else:
ds = client.read_dataset(dataset_name=dataset_name)
else:
ds = client.create_dataset(dataset_name=dataset_name)
client.create_examples(
inputs=[e.inputs for e in examples],
outputs=[e.outputs for e in examples],
@@ -23,4 +34,24 @@ def create_langchain_docs_dataset(
if __name__ == "__main__":
create_langchain_docs_dataset()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--target-api-key", type=str, required=False)
parser.add_argument("--target-endpoint", type=str, required=False)
parser.add_argument("--dataset-name", type=str, default=DATASET_NAME)
parser.add_argument(
"--public-dataset-token", type=str, default=PUBLIC_DATASET_TOKEN
)
args = parser.parse_args()
client = None
if args.target_api_key or args.target_endpoint:
client = Client(
api_key=args.target_api_key,
api_url=args.target_endpoint,
)
create_langchain_docs_dataset(
dataset_name=args.dataset_name,
public_dataset_token=args.public_dataset_token,
client=client,
)
+3 -5
View File
@@ -1,6 +1,5 @@
import argparse
import importlib.util
import os
import sys
import uuid
from functools import partial
@@ -13,7 +12,7 @@ from langchain.schema.runnable import Runnable
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client
from oai_assistant.chain import agent_executor as openai_assistant_chain
from openai_functions_agent import agent_executor as openai_functions_agent_chain
from openai_functions_agent import create_executor
ls_client = Client()
@@ -33,7 +32,7 @@ def _get_chain_factory(arch: str) -> Callable:
_map = {
"chat": create_chain,
"anthropic-iterative-search": lambda _: anthropic_agent_chain,
"openai-functions-agent": lambda _: openai_functions_agent_chain,
"openai-functions-agent": create_executor,
"openai-assistant": lambda _: openai_assistant_chain,
}
if arch in _map:
@@ -93,8 +92,7 @@ def main(
run_on_dataset(
client=ls_client,
dataset_name=dataset_name,
llm_or_chain_factory=partial(
create_runnable,
llm_or_chain_factory=lambda: create_runnable(
arch=arch,
model_config=model_config,
retry_config=retry_config,
+17 -3
View File
@@ -6,9 +6,14 @@ from run_evals import main
experiments = [
{
# "server_url": "http://localhost:1983/openai-functions-agent",
"arch": "openai-functions-agent",
"project_name": "openai-functions-agent",
"model_config": {"model": "gpt-3.5-turbo-16k"},
},
{
"arch": "openai-functions-agent",
"project_name": "oaifunc-agent-gpt-4-1106",
"model_config": {"model": "gpt-4-1106-preview"},
},
{
# "server_url": "http://localhost:1983/anthropic_chat",
@@ -41,9 +46,17 @@ experiments = [
"arch": "chat",
"model_config": {
"chat_cls": "ChatFireworks",
"model": "accounts/fireworks/models/llama-v2-34b-code-instruct-w8a16",
"model": "accounts/fireworks/models/llama-v2-34b-code-instruct",
},
"project_name": "llama-v2-34b-code-instruct-w8a16",
"project_name": "llama-v2-34b-code-instruct",
},
{
"arch": "chat",
"model_config": {
"chat_cls": "ChatFireworks",
"model": "accounts/fireworks/models/llama-v2-70b-chat",
},
"project_name": "llama-70b-chat",
},
{
"arch": "chat",
@@ -120,6 +133,7 @@ if __name__ == "__main__":
]
for experiment in selected_experiments:
print("Running experiment:", experiment)
main(
**experiment,
dataset_name=args.dataset_name,
+1
View File
@@ -0,0 +1 @@
.sql
+6
View File
@@ -0,0 +1,6 @@
"""RAG environments."""
from langchain_benchmarks.rag.evaluators import RAG_EVALUATION
from langchain_benchmarks.rag.registration import registry
# Please keep this list sorted!
__all__ = ["registry", "RAG_EVALUATION"]
@@ -0,0 +1,6 @@
# LangChain Docs Environment
This code contains utilities to scrape the LangChain docs (already run) and index them
using common techniques. The docs were scraped using the code in `_ingest_docs.py` and
uploaded to gcs. To better compare retrieval techniques, we hold these constant and pull
from that cache whenever generating different indices.
@@ -0,0 +1 @@
DATASET_ID = "452ccafc-18e1-4314-885b-edd735f17b9d" # ID of public LangChain Docs dataset
@@ -0,0 +1,299 @@
"""Load html from files, clean up, split, ingest."""
import logging
import os
import re
from typing import Generator, Iterable, Optional
from bs4 import BeautifulSoup, Doctype, NavigableString, SoupStrainer, Tag
from langchain.document_loaders import RecursiveUrlLoader, SitemapLoader
from langchain.embeddings import OpenAIEmbeddings, VoyageEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.schema.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utils.html import PREFIXES_TO_IGNORE_REGEX, SUFFIXES_TO_IGNORE_REGEX
from langchain.vectorstores.chroma import Chroma
import pandas as pd
logger = logging.getLogger(__name__)
directory = os.path.dirname(os.path.realpath(__file__))
db_directory = os.path.join(directory, "langchain_docs_retriever", "db")
docs_cache_directory = os.path.join(directory, "langchain_docs_retriever", "db_docs")
docs_cache_file = os.path.join(docs_cache_directory, "docs.parquet")
def langchain_docs_extractor(soup: BeautifulSoup) -> str:
# Remove all the tags that are not meaningful for the extraction.
SCAPE_TAGS = ["nav", "footer", "aside", "script", "style"]
[tag.decompose() for tag in soup.find_all(SCAPE_TAGS)]
def get_text(tag: Tag) -> Generator[str, None, None]:
for child in tag.children:
if isinstance(child, Doctype):
continue
if isinstance(child, NavigableString):
yield child
elif isinstance(child, Tag):
if child.name in ["h1", "h2", "h3", "h4", "h5", "h6"]:
yield f"{'#' * int(child.name[1:])} {child.get_text()}\n\n"
elif child.name == "a":
yield f"[{child.get_text(strip=False)}]({child.get('href')})"
elif child.name == "img":
yield f"![{child.get('alt', '')}]({child.get('src')})"
elif child.name in ["strong", "b"]:
yield f"**{child.get_text(strip=False)}**"
elif child.name in ["em", "i"]:
yield f"_{child.get_text(strip=False)}_"
elif child.name == "br":
yield "\n"
elif child.name == "code":
parent = child.find_parent()
if parent is not None and parent.name == "pre":
classes = parent.attrs.get("class", "")
language = next(
filter(lambda x: re.match(r"language-\w+", x), classes),
None,
)
if language is None:
language = ""
else:
language = language.split("-")[1]
lines: list[str] = []
for span in child.find_all("span", class_="token-line"):
line_content = "".join(
token.get_text() for token in span.find_all("span")
)
lines.append(line_content)
code_content = "\n".join(lines)
yield f"```{language}\n{code_content}\n```\n\n"
else:
yield f"`{child.get_text(strip=False)}`"
elif child.name == "p":
yield from get_text(child)
yield "\n\n"
elif child.name == "ul":
for li in child.find_all("li", recursive=False):
yield "- "
yield from get_text(li)
yield "\n\n"
elif child.name == "ol":
for i, li in enumerate(child.find_all("li", recursive=False)):
yield f"{i + 1}. "
yield from get_text(li)
yield "\n\n"
elif child.name == "div" and "tabs-container" in child.attrs.get(
"class", [""]
):
tabs = child.find_all("li", {"role": "tab"})
tab_panels = child.find_all("div", {"role": "tabpanel"})
for tab, tab_panel in zip(tabs, tab_panels):
tab_name = tab.get_text(strip=True)
yield f"{tab_name}\n"
yield from get_text(tab_panel)
elif child.name == "table":
thead = child.find("thead")
header_exists = isinstance(thead, Tag)
if header_exists:
headers = thead.find_all("th")
if headers:
yield "| "
yield " | ".join(header.get_text() for header in headers)
yield " |\n"
yield "| "
yield " | ".join("----" for _ in headers)
yield " |\n"
tbody = child.find("tbody")
tbody_exists = isinstance(tbody, Tag)
if tbody_exists:
for row in tbody.find_all("tr"):
yield "| "
yield " | ".join(
cell.get_text(strip=True) for cell in row.find_all("td")
)
yield " |\n"
yield "\n\n"
elif child.name in ["button"]:
continue
else:
yield from get_text(child)
joined = "".join(get_text(soup))
return re.sub(r"\n\n+", "\n\n", joined).strip()
RECORD_MANAGER_DB_URL = (
os.environ.get("RECORD_MANAGER_DB_URL") or "sqlite:///lcdocs_oai_record_manager.sql"
)
def metadata_extractor(meta: dict, soup: BeautifulSoup) -> dict:
title = soup.find("title")
description = soup.find("meta", attrs={"name": "description"})
html = soup.find("html")
return {
"source": meta["loc"] or "",
"title": (title.get_text() if title else "") or "",
"description": description.get("content") or "" if description else "",
"language": html.get("lang") or "" if html else "",
**{k: v or "" for k, v in meta.items()},
}
def load_langchain_docs():
return SitemapLoader(
"https://python.langchain.com/sitemap.xml",
filter_urls=["https://python.langchain.com/"],
parsing_function=langchain_docs_extractor,
default_parser="lxml",
bs_kwargs={
"parse_only": SoupStrainer(
name=("article", "title", "html", "lang", "content")
),
},
meta_function=metadata_extractor,
).load()
def simple_extractor(html: str) -> str:
soup = BeautifulSoup(html, "lxml")
return re.sub(r"\n\n+", "\n\n", soup.text).strip()
def load_api_docs():
return RecursiveUrlLoader(
url="https://api.python.langchain.com/en/latest/",
max_depth=8,
extractor=simple_extractor,
prevent_outside=True,
use_async=True,
timeout=600,
# Drop trailing / to avoid duplicate pages.
link_regex=(
f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)"
r"(?:[\#'\"]|\/[\#'\"])"
),
check_response_status=True,
exclude_dirs=(
"https://api.python.langchain.com/en/latest/_sources",
"https://api.python.langchain.com/en/latest/_modules",
),
).load()
def get_embeddings_model() -> Embeddings:
if os.environ.get("VOYAGE_AI_URL") and os.environ.get("VOYAGE_AI_MODEL"):
return VoyageEmbeddings()
return OpenAIEmbeddings(chunk_size=200)
CHROMA_COLLECTION_NAME = "langchain-docs"
def get_docs() -> Iterable[Document]:
# TODO: Make this function actually a generator
# Import before loading because it's a bummer to fail after scraping.
# we should have an incremental scrape cache.
import pyarrow as pa # type: ignore
docs_from_documentation = load_langchain_docs()
logger.info(f"Loaded {len(docs_from_documentation)} docs from documentation")
docs_from_api = load_api_docs()
logger.info(f"Loaded {len(docs_from_api)} docs from API")
# We try to return 'source' and 'title' metadata when querying vector store and
# Chroma will error at query time if one of the attributes is missing from a
# retrieved document.
for doc in docs_from_documentation + docs_from_api:
if "source" not in doc.metadata:
doc.metadata["source"] = ""
if "title" not in doc.metadata:
doc.metadata["title"] = ""
for k, v in doc.metadata.items():
if v is None:
doc.metadata[k] = ""
if not doc.page_content.strip():
continue
yield doc
def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]:
df = pd.read_parquet(filename or docs_cache_file)
docs_transformed = [Document(**row) for row in df.to_dict(orient="records")]
for doc in docs_transformed:
for k, v in doc.metadata.items():
if v is None:
doc.metadata[k] = ""
if not doc.page_content.strip():
continue
yield doc
# default ingest function
def ingest_docs(overwrite: bool = False):
if os.path.exists(docs_cache_file) and not overwrite:
logger.info(f"Loading docs from {docs_cache_file}")
documents = load_docs_from_parquet(docs_cache_file)
else:
documents = get_docs()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)
docs_transformed = text_splitter.split_documents(documents)
embedding = get_embeddings_model()
vectorstore = Chroma(
collection_name=CHROMA_COLLECTION_NAME,
embedding_function=embedding,
persist_directory=db_directory,
)
record_manager = SQLRecordManager(
f"chroma/{CHROMA_COLLECTION_NAME}", db_url=RECORD_MANAGER_DB_URL
)
record_manager.create_schema()
indexing_stats = index(
docs_transformed,
record_manager,
vectorstore,
cleanup="full",
source_id_key="source",
)
logger.info("Indexing stats: ", indexing_stats)
def download_docs(overwrite: bool = False):
if os.path.exists(docs_cache_file) and not overwrite:
logger.info(f"Loading docs from {docs_cache_file}")
return
if not os.path.exists(docs_cache_directory):
os.makedirs(docs_cache_directory)
docs = get_docs()
# Write as parquet file
df = pd.DataFrame.from_records([doc.dict() for doc in docs])
df.to_parquet(docs_cache_file)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--action",
choices=["ingest", "download"],
default="download",
)
parser.add_argument("--overwrite", action="store_true")
args = parser.parse_args()
if args.action == "download":
download_docs(args.overwrite)
elif args.action == "ingest":
ingest_docs(args.overwrite)
else:
raise ValueError(f"Unknown action {args.action}")
@@ -0,0 +1,5 @@
from langchain_benchmarks.rag.environments.langchain_docs.architectures.chain_registry import (
ARCH_FACTORIES,
)
__all__ = ["ARCH_FACTORIES"]
@@ -0,0 +1,29 @@
from typing import Optional
from langchain.chat_models import ChatOpenAI
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import Runnable
from langchain.base_language import BaseLanguageModel
from langchain_benchmarks.rag.environments.langchain_docs.architectures.crqa import (
create_response_chain,
get_default_response_synthesizer,
)
def default_response_chain(
retriever: BaseRetriever,
response_synthesizer: Optional[Runnable] = None,
llm: Optional[BaseLanguageModel] = None,
) -> None:
"""Get the chain responsible for generating a response based on the retrieved documents."""
response_synthesizer = response_synthesizer or get_default_response_synthesizer(
llm=llm or ChatOpenAI(model="gpt-3.5-turbo-16k", model_kwargs={"seed": 42})
)
return create_response_chain(
response_synthesizer=response_synthesizer, retriever=retriever
)
ARCH_FACTORIES = {
"conversational-retrieval-qa": default_response_chain,
}
@@ -0,0 +1,116 @@
"""Chat langchain 'engine'."""
# TODO: some simplified architectures that are
# environment-agnostic
from typing import Callable, Dict, List, Optional, Sequence
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
Runnable,
RunnableLambda,
)
from langchain.schema.runnable.passthrough import RunnableAssign
from pydantic import BaseModel
from operator import itemgetter
RESPONSE_TEMPLATE = """\
You are an expert programmer and problem-solver, tasked with answering any question \
about Langchain.
Generate a comprehensive and informative answer of 80 words or less for the \
given question based solely on the provided search results (URL and content). You must \
only use information from the provided search results. Use an unbiased and \
journalistic tone. Combine search results together into a coherent answer. Do not \
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
relevant results that answer the question accurately. Place these citations at the end \
of the sentence or paragraph that reference them - do not put them all at the end. If \
different results refer to different entities within the same name, write separate \
answers for each entity.
You should use bullet points in your answer for readability. Put citations where they apply
rather than putting them all at the end.
If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.
Anything between the following `context` html blocks is retrieved from a knowledge \
bank, not part of the conversation with the user.
<context>
{context}
<context/>
REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
user.\
"""
class ChatRequest(BaseModel):
question: str
chat_history: Optional[List[Dict[str, str]]]
def _format_docs(docs: Sequence[Document]) -> str:
formatted_docs = []
for i, doc in enumerate(docs):
doc_string = f"<doc id='{i}'>{doc.page_content}</doc>"
formatted_docs.append(doc_string)
return "\n".join(formatted_docs)
def serialize_history(request: ChatRequest):
chat_history = request.get("chat_history") or []
converted_chat_history = []
for message in chat_history:
if message.get("human") is not None:
converted_chat_history.append(HumanMessage(content=message["human"]))
if message.get("ai") is not None:
converted_chat_history.append(AIMessage(content=message["ai"]))
return converted_chat_history
def get_default_response_synthesizer(llm: BaseLanguageModel) -> Runnable:
prompt = ChatPromptTemplate.from_messages(
[
("system", RESPONSE_TEMPLATE),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
return (prompt | llm | StrOutputParser()).with_config(
run_name="GenerateResponse",
)
def create_response_chain(
response_synthesizer: Runnable,
retriever: BaseRetriever,
format_docs: Optional[Callable[[Sequence[Document]], str]] = None,
format_chat_history: Optional[Callable[[ChatRequest], str]] = None,
) -> Runnable:
format_docs = format_docs or _format_docs
format_chat_history = format_chat_history or serialize_history
return (
RunnableAssign(
{
"chat_history": RunnableLambda(format_chat_history).with_config(
run_name="SerializeHistory"
)
}
)
| RunnableAssign(
{
"context": (
itemgetter("question") | retriever | format_docs
).with_config(run_name="FormatDocs")
}
)
| response_synthesizer
)
@@ -0,0 +1,4 @@
db/
docs_db/
.sql
.bin
@@ -0,0 +1,18 @@
from langchain_benchmarks.rag.environments.langchain_docs.langchain_docs_retriever.retriever import (
get_parent_document_retriever,
get_vectorstore_retriever,
get_hyde_retriever,
create_index,
)
from langchain_benchmarks.rag.environments.langchain_docs.langchain_docs_retriever.retriever_registry import (
RETRIEVER_FACTORIES,
)
__all__ = [
"create_index",
"get_hyde_retriever",
"get_parent_document_retriever",
"get_retriever",
"get_vectorstore_retriever",
"RETRIEVER_FACTORIES",
]
@@ -0,0 +1,47 @@
import os
from typing import Iterable
import zipfile
import requests
chroma_remote_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/chroma_db.zip"
raw_docs_file = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/docs.parquet"
directory = os.path.dirname(os.path.realpath(__file__))
db_directory = os.path.join(directory, "db")
DOCS_FILE = os.path.join(directory, "db_docs/docs.parquet")
def is_folder_populated(folder):
if os.path.exists(folder):
return any(os.scandir(folder))
return False
def download_chroma_folder_from_gcs():
r = requests.get(chroma_remote_url, allow_redirects=True)
open("chroma_db.zip", "wb").write(r.content)
with zipfile.ZipFile("chroma_db.zip", "r") as zip_ref:
zip_ref.extractall(directory)
os.remove("chroma_db.zip")
def fetch_langchain_docs_db():
if not is_folder_populated(db_directory):
print(f"Folder {db_directory} is not populated. Downloading from GCS...")
download_chroma_folder_from_gcs()
def fetch_remote_parquet_file():
if not os.path.exists(DOCS_FILE):
print(f"File {DOCS_FILE} does not exist. Downloading from GCS...")
r = requests.get(raw_docs_file, allow_redirects=True)
if not os.path.exists(os.path.dirname(DOCS_FILE)):
os.makedirs(os.path.dirname(DOCS_FILE))
open(DOCS_FILE, "wb").write(r.content)
print(f"File {DOCS_FILE} downloaded.")
if __name__ == "__main__":
fetch_remote_parquet_file()
@@ -0,0 +1,205 @@
import logging
import os
from functools import partial
from typing import Callable, Iterable, Optional
import pandas as pd
from langchain.indexes import SQLRecordManager, index
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain.schema.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import Runnable
from langchain.schema.storage import BaseStore
from langchain.schema.vectorstore import VectorStore
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain_benchmarks.rag.utils.indexing import (
transform_docs_hyde,
transform_docs_parent_child,
)
from .download_db import DOCS_FILE, fetch_remote_parquet_file
logger = logging.getLogger(__name__)
_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
COLLECTION_NAME = "langchain-docs"
RECORD_MANAGER_DB_URL = (
os.environ.get("RECORD_MANAGER_DB_URL")
or f"sqlite:///{_DIRECTORY}_record_manager.sql"
)
_DEFAULT_SEARCH_KWARGS = {"k": 6}
def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]:
df = pd.read_parquet(filename)
docs_transformed = [Document(**row) for row in df.to_dict(orient="records")]
for doc in docs_transformed:
for k, v in doc.metadata.items():
if v is None:
doc.metadata[k] = ""
if not doc.page_content.strip():
continue
yield doc
def create_index(
embedding: Embeddings,
vectorstore: VectorStore,
*,
transform_docs: Optional[Callable] = None,
transformation_name: Optional[str] = None,
):
fetch_remote_parquet_file()
docs = load_docs_from_parquet(DOCS_FILE)
if transform_docs:
if not transformation_name:
raise ValueError(
"If you provide a transform function, you must also provide a "
"transformation name to use for the record manager."
)
transformed_docs = transform_docs(docs)
else:
transformed_docs = docs
transformation_name = transformation_name or "raw"
vectorstore_name = vectorstore.__class__.__name__
embedding_name = embedding.__class__.__name__
record_manager = SQLRecordManager(
f"{vectorstore_name}/{COLLECTION_NAME}_{vectorstore_name}_{embedding_name}_{transformation_name}",
db_url=RECORD_MANAGER_DB_URL,
)
record_manager.create_schema()
return index(
transformed_docs,
record_manager,
vectorstore,
cleanup="full",
source_id_key="source",
)
def get_vectorstore_retriever(
embedding: Embeddings,
vectorstore: VectorStore,
*,
transform_docs: Optional[Callable] = None,
transformation_name: Optional[str] = None,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
"""Index the documents (with caching) and return a vector store retriever."""
index_stats = create_index(
embedding,
vectorstore,
transform_docs=transform_docs,
transformation_name=transformation_name,
)
logger.info(f"Index stats: {index_stats}")
kwargs = search_kwargs or _DEFAULT_SEARCH_KWARGS
kwargs.setdefault("metadata", {}).setdefault(
"benchmark_environment", "langchain-docs"
)
kwargs.setdefault("tags", []).append("langchain-benchmarks")
return vectorstore.as_retriever(**kwargs)
def get_parent_document_retriever(
embedding: Embeddings,
vectorstore: VectorStore,
*,
child_splitter: Optional[TextSplitter] = None,
transformation_name: Optional[str] = None,
id_key: str = "source",
docstore: Optional[BaseStore] = None,
parent_splitter: Optional[TextSplitter] = None,
search_kwargs: Optional[dict] = None,
):
"""Index the documents (with caching) and return a parent document retriever."""
docstore = docstore or InMemoryStore()
if child_splitter is None:
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=4000, chunk_overlap=200
)
transformation_name = "parent-document-recursive-cs4k_ol200"
logger.info(f"Using default child splitter:\n{child_splitter}")
else:
if transformation_name is None:
raise ValueError(
"If you provide a custom child splitter, you must also provide a "
"transformation name to use for the record manager."
)
transformation = partial(
transform_docs_parent_child,
child_splitter=child_splitter,
docstore=docstore,
parent_splitter=parent_splitter,
id_key=id_key,
)
index_stats = create_index(
embedding,
vectorstore,
transform_docs=transformation,
transformation_name=transformation_name,
)
logger.info(f"Index stats: {index_stats}")
return ParentDocumentRetriever(
tags=["langchain-benchmarks"],
metadata={"benchmark_environment": "langchain-docs"},
vectorstore=vectorstore,
docstore=docstore,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
id_key=id_key,
)
def get_hyde_retriever(
embedding: Embeddings,
vectorstore: VectorStore,
*,
docstore: Optional[BaseStore] = None,
query_generator: Optional[Runnable] = None,
id_key: str = "source",
search_kwargs: Optional[dict] = None,
transformation_name: Optional[str] = None,
):
"""Index the documents (with caching) and return a parent document retriever."""
docstore = docstore or InMemoryStore()
if query_generator is not None and transformation_name is None:
raise ValueError(
"If you provide a custom query generator, you must also provide a "
"transformation name to use for the record manager."
)
transformation_name = transformation_name or "HyDE"
transformation = partial(
transform_docs_hyde,
docstore=docstore,
query_generator=query_generator,
id_key=id_key,
)
index_stats = create_index(
embedding,
vectorstore,
transform_docs=transformation,
transformation_name=transformation_name,
)
logger.info(f"Index stats: {index_stats}")
metadata = {
"benchmark_environment": "langchain-docs",
"retriever_stragegy": "HyDE",
"embedding": embedding.__class__.__name__,
"vectorstore": vectorstore.__class__.__name__,
}
return MultiVectorRetriever(
vectorstore=vectorstore,
docstore=docstore,
id_key=id_key,
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
metadata=metadata,
tags=["langchain-benchmarks"],
)
@@ -0,0 +1,66 @@
from typing import Callable, Optional
from langchain.schema.embeddings import Embeddings
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.chroma import Chroma
from .retriever import (
get_vectorstore_retriever,
get_parent_document_retriever,
get_hyde_retriever,
)
def _chroma_retriever_factory(
embedding: Embeddings,
search_kwargs: Optional[dict] = None,
transform_docs: Optional[Callable] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-classic-{embedding_name}",
embedding_function=embedding,
persist_directory="./chromadb",
)
return get_vectorstore_retriever(
embedding,
vectorstore,
transform_docs=transform_docs,
transformation_name=transformation_name,
search_kwargs=search_kwargs,
)
def _chroma_parent_document_retriever_factory(
embedding: Embeddings,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-parent-doc-{embedding_name}",
embedding_function=embedding,
persist_directory="./chromadb",
)
return get_parent_document_retriever(
embedding, vectorstore, search_kwargs=search_kwargs
)
def _chroma_hyde_retriever_factory(
embedding: Embeddings,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-hyde-{embedding_name}",
embedding_function=embedding,
persist_directory="./chromadb",
)
return get_hyde_retriever(embedding, vectorstore, search_kwargs=search_kwargs)
RETRIEVER_FACTORIES = {
"basic": _chroma_retriever_factory,
"parent-doc": _chroma_parent_document_retriever_factory,
"hyde": _chroma_hyde_retriever_factory,
}
+94
View File
@@ -0,0 +1,94 @@
from typing import Optional
from langchain.chat_models import ChatOpenAI
from langchain.evaluation import load_evaluator
from langchain.llms.base import BaseLanguageModel
from langchain.smith import RunEvalConfig
from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
from langsmith.schemas import Example, Run
# TODO: Split this into an assertion-by-assertion evaluator
# TODO: Combine with a document relevance evaluator (to report retriever performance)
class FaithfulnessEvaluator(RunEvaluator):
def __init__(self, llm: Optional[BaseLanguageModel] = None):
self.evaluator = load_evaluator(
"labeled_score_string",
criteria={
"faithfulness": """
Score 1: The answer directly contradicts the information provided in the reference docs.
Score 3: The answer contains a mix of correct information from the reference docs and incorrect or unverifiable information not found in the docs.
Score 5: The answer is mostly aligned with the reference docs but includes extra information that, while not contradictory, is not verified by the docs.
Score 7: The answer aligns well with the reference docs but includes minor, commonly accepted facts not found in the docs.
Score 10: The answer perfectly aligns with and is fully entailed by the reference docs, with no extra information."""
},
llm=llm,
normalize_by=10,
)
@staticmethod
def _get_retrieved_docs(run: Run) -> str:
# This assumes there is only one retriever in your chain.
# To select more precisely, name your retrieval chain
# using with_config(name="my_unique_name") and look up
# by run.name
runs = [run]
while runs:
run = runs.pop()
if run.run_type == "retriever":
return str(run.outputs["documents"])
if run.child_runs:
runs.extend(run.child_runs[::-1])
return ""
def evaluate_run(
self, run: Run, example: Optional[Example] = None
) -> EvaluationResult:
try:
docs_string = self._get_retrieved_docs(run)
docs_string = f"Reference docs:\n<DOCS>\n{docs_string}\n</DOCS>\n\n"
input_query = run.inputs["question"]
if run.outputs is not None and len(run.outputs) == 1:
prediction = next(iter(run.outputs.values()))
else:
prediction = run.outputs["output"]
result = self.evaluator.evaluate_strings(
input=input_query,
prediction=prediction,
reference=docs_string,
)
return EvaluationResult(
**{"key": "faithfulness", "comment": result.get("reasoning"), **result}
)
except Exception as e:
return EvaluationResult(key="faithfulness", score=None, comment=repr(e))
_ACCURACY_CRITERION = {
"accuracy": """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer shows slight relevance to the question or reference document but is largely incorrect.
Score 5: The answer is partially correct but has significant errors or omissions.
Score 7: The answer is mostly correct with minor errors or omissions, and aligns with the reference document.
Score 10: The answer is correct, complete, and perfectly aligns with the reference document.
If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct.
""" # noqa
}
eval_llm = ChatOpenAI(model="gpt-4", temperature=0.0, model_kwargs={"seed": 42})
# Use a longer-context LLM to check documents
faithfulness_eval_llm = ChatOpenAI(
model="gpt-4-1106-preview", temperature=0.0, model_kwargs={"seed": 42}
)
RAG_EVALUATION = RunEvalConfig(
evaluators=[
RunEvalConfig.LabeledScoreString(
criteria=_ACCURACY_CRITERION, llm=eval_llm, normalize_by=10.0
),
RunEvalConfig.EmbeddingDistance(),
],
custom_evaluators=[FaithfulnessEvaluator(llm=faithfulness_eval_llm)],
)
+55
View File
@@ -0,0 +1,55 @@
"""Registry of RAG environments for ease of access."""
import dataclasses
from typing import Callable, Dict, List
from langchain.schema.embeddings import Embeddings
from langchain.schema.retriever import BaseRetriever
from langchain_benchmarks.rag.environments import langchain_docs
from langchain_benchmarks.rag.environments.langchain_docs import (
architectures,
langchain_docs_retriever,
)
from langchain_benchmarks.utils._registration import Environment, Registry
@dataclasses.dataclass(frozen=True)
class RetrievalEnvironment(Environment):
retriever_factories: Dict[str, Callable[[Embeddings], BaseRetriever]] # noqa: F821
"""Factories that index the docs using the specified strategy."""
architecture_factories: Dict[
str, Callable[[Embeddings], BaseRetriever]
] # noqa: F821
"""Factories methods that help build some off-the-shelf architectures。"""
@property
def _table(self) -> List[List[str]]:
table = super()._table
return table + [
["Retriever Factories", ", ".join(self.retriever_factories.keys())],
["Architecture Factories", ", ".join(self.architecture_factories.keys())],
]
# Using lower case naming to make a bit prettier API when used in a notebook
registry = Registry(
environments=[
RetrievalEnvironment(
id=0,
name="LangChain Docs Q&A",
dataset_id=langchain_docs.DATASET_ID,
retriever_factories=langchain_docs_retriever.RETRIEVER_FACTORIES,
architecture_factories=architectures.ARCH_FACTORIES,
description=(
"""\
Questions and answers based on a snapshot of the LangChain python docs.
The environment provides the documents and the retriever information.
Each example is composed of a question and reference answer.
Success is measured based on the accuracy of the answer relative to the reference answer.
We also measure the faithfulness of the model's response relative to the retrieved documents (if any).
""" # noqa: E501
),
)
]
)
+109
View File
@@ -0,0 +1,109 @@
from typing import Iterable, List, Optional
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.storage import BaseStore
from langchain.text_splitter import TextSplitter
import logging
logger = logging.getLogger(__name__)
def transform_docs_parent_child(
documents: Iterable[Document],
child_splitter: TextSplitter,
docstore: BaseStore,
id_key: str,
*,
parent_splitter: Optional[TextSplitter] = None,
) -> Iterable[Document]:
"""Transforms documents into child <-> parent documents."""
if parent_splitter is not None:
documents = parent_splitter.split_documents(documents)
doc_ids = []
for doc in documents:
yield doc
_id = doc.metadata[id_key]
doc_ids.append((_id, doc))
sub_docs = child_splitter.split_documents([doc])
for _doc in sub_docs:
_doc.metadata[id_key] = _id
yield _doc
docstore.mset(doc_ids)
def _default_hyde_embedder():
class hypotheticalQuestions(BaseModel):
"""Write user queries that could be answered by the document."""
questions: List[str]
return (
(
{"doc": lambda x: x.page_content}
# Only asking for 3 hypothetical questions, but this could be adjusted
| ChatPromptTemplate.from_messages(
[
(
"system",
"You are an AI creating an inverted index for document retrieval. "
"Analyze the content of the following document and generate relevant user queries "
"that would likely retrieve this document. Document content:\n\n```document.txt\n{doc}\n```",
),
(
"user",
"Based on the document's content, what specific technical queries or questions"
" are users likely to search that this document can answer?",
),
]
)
| ChatOpenAI(max_retries=0, model="gpt-4-1106-preview").bind_functions(
functions=[hypotheticalQuestions],
function_call="hypotheticalQuestions",
)
| JsonKeyOutputFunctionsParser(key_name="questions")
)
.with_retry(stop_after_attempt=3)
.with_config(
run_name="HyDE",
metadata={"benchmark_environment": "langchain-docs"},
tags=["langchain-benchmarks"],
)
)
def transform_docs_hyde(
documents: Iterable[Document],
docstore: BaseStore,
id_key: str,
*,
query_generator: Optional[Runnable] = None,
runnable_config: Optional[RunnableConfig] = None,
) -> Iterable[Document]:
"""Generates hypothetical document embeddings."""
if query_generator is None:
query_generator = _default_hyde_embedder()
logger.info(f"Using default query generator\n{query_generator}")
generator = query_generator or _default_hyde_embedder()
docs = list(documents)
questions = generator.batch(
docs, runnable_config or {"max_concurrency": 5}, return_exceptions=True
)
doc_ids = []
for doc, expansions in zip(documents, questions):
yield doc
if isinstance(expansions, BaseException):
logger.error(
f"Error generating questions for document {doc.metadata[id_key]}"
)
continue
expansion_docs = [
Document(page_content=s, metadata=doc.metadata) for s in expansions
]
yield from expansion_docs
doc_ids.append((doc.metadata[id_key], doc))
docstore.mset(doc_ids)
@@ -1,95 +1,22 @@
"""Registry of environments for ease of access."""
"""Registry of tool use environments for ease of access."""
import dataclasses
from typing import Callable, List, Sequence, Union
from langchain_benchmarks.tool_usage.environments import alpha
from langchain_benchmarks.utils._registration import Environment, Registry
from typing import Callable, List
from langchain.tools import BaseTool
from tabulate import tabulate
from langchain_benchmarks.tool_usage.environments import alpha
@dataclasses.dataclass(frozen=True)
class Environment:
id: int
"""The ID of the environment."""
name: str
"""The name of the environment."""
dataset_id: str
"""The ID of the langsmith public dataset.
This dataset contains expected inputs/outputs for the environment, and
can be used to evaluate the performance of a model/agent etc.
"""
class ToolEnvironment(Environment):
tools_factory: Callable[[], List[BaseTool]]
"""Factory that returns a list of tools that can be used in the environment."""
description: str
"""Description of the environment."""
def _repr_html_(self) -> str:
"""Return a HTML representation of the environment."""
table = [
["ID", self.id],
["Name", self.name],
["Dataset ID", self.dataset_id],
["Description", self.description[:100] + "..."],
]
return tabulate(
table,
tablefmt="html",
)
@dataclasses.dataclass(frozen=True)
class Registry:
environments: Sequence[Environment]
def get_environment(self, name_or_id: Union[int, str]) -> Environment:
"""Get the environment with the given name."""
for env in self.environments:
if env.name == name_or_id or env.id == name_or_id:
return env
raise ValueError(
f"Unknown environment {name_or_id}. Use list_environments() to see "
f"available environments."
)
def _repr_html_(self) -> str:
"""Return a HTML representation of the registry."""
headers = [
"ID",
"Name",
"Dataset ID",
"Description",
]
table = [
[
env.id,
env.name,
env.dataset_id,
env.description,
]
for env in self.environments
]
return tabulate(table, headers=headers, tablefmt="html")
def __getitem__(self, key: Union[int, str]) -> Environment:
"""Get an environment from the registry."""
if isinstance(key, slice):
raise NotImplementedError("Slicing is not supported.")
elif isinstance(key, (int, str)):
# If key is an integer, return the corresponding environment
return self.get_environment(key)
else:
raise TypeError("Key must be an integer or a slice.")
# Using lower case naming to make a bit prettier API when used in a notebook
registry = Registry(
environments=[
Environment(
ToolEnvironment(
id=0,
name="Tool Usage - Alpha",
dataset_id=alpha.DATASET_ID,
@@ -0,0 +1,83 @@
import dataclasses
from typing import List, Sequence, Union
from tabulate import tabulate
@dataclasses.dataclass(frozen=True)
class Environment:
id: int
"""The ID of the environment."""
name: str
"""The name of the environment."""
dataset_id: str
"""The ID of the langsmith public dataset.
This dataset contains expected inputs/outputs for the environment, and
can be used to evaluate the performance of a model/agent etc.
"""
description: str
"""Description of the environment."""
@property
def _table(self) -> List[List[str]]:
return [
["ID", self.id],
["Name", self.name],
["Dataset ID", self.dataset_id],
["Description", self.description[:100] + "..."],
]
def _repr_html_(self) -> str:
"""Return a HTML representation of the environment."""
return tabulate(
self._table,
tablefmt="html",
)
@dataclasses.dataclass(frozen=True)
class Registry:
environments: Sequence[Environment]
def get_environment(self, name_or_id: Union[int, str]) -> Environment:
"""Get the environment with the given name."""
for env in self.environments:
if env.name == name_or_id or env.id == name_or_id:
return env
raise ValueError(
f"Unknown environment {name_or_id}. Use list_environments() to see "
f"available environments."
)
def _repr_html_(self) -> str:
"""Return a HTML representation of the registry."""
headers = [
"ID",
"Name",
"Dataset ID",
"Description",
]
table = [
[
env.id,
env.name,
env.dataset_id,
env.description,
]
for env in self.environments
]
return tabulate(table, headers=headers, tablefmt="html")
def __getitem__(self, key: Union[int, str]) -> Environment:
"""Get an environment from the registry."""
if isinstance(key, slice):
raise NotImplementedError("Slicing is not supported.")
elif isinstance(key, (int, str)):
# If key is an integer, return the corresponding environment
return self.get_environment(key)
else:
raise TypeError("Key must be an integer or a slice.")