mirror of
https://github.com/langchain-ai/langchain-benchmarks.git
synced 2026-07-01 22:34:02 -04:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eb2d9e2b63 | |||
| 09d214522f | |||
| 8798735ea4 | |||
| 7ed859c068 | |||
| 417e6faccf | |||
| aeae13ba63 | |||
| 825d8ec9bb | |||
| 44a5c3530a | |||
| 14de11a420 | |||
| b15620ee9c | |||
| 13e7f2df0a | |||
| 888fce5060 | |||
| 148a3e4f89 | |||
| 0e10f3227f | |||
| b0667043ea | |||
| bd5eac5abd | |||
| dbb85200ac | |||
| c1023a14b8 | |||
| 8899acc989 | |||
| c0e7f51626 | |||
| 9f827eaca5 | |||
| d9fc08b05c | |||
| 8a5ba6d575 | |||
| 8204930f2b | |||
| 013fe6a153 |
@@ -1,6 +1,4 @@
|
||||
🚧 Under Active Development 🚧
|
||||
|
||||
# 🦜💪 LangChain Benchmarks
|
||||
# 🦜💯 LangChain Benchmarks
|
||||
|
||||
[](https://github.com/langchain-ai/langchain-benchmarks/releases)
|
||||
[](https://github.com/langchain-ai/langchain-benchmarks/actions/workflows/ci.yml)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from chat_langchain.chain import chain
|
||||
from fastapi import FastAPI
|
||||
from openai_functions_agent import agent_executor as openai_functions_agent_chain
|
||||
|
||||
from langserve import add_routes
|
||||
from openai_functions_agent import agent_executor as openai_functions_agent_chain
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
chromadb/
|
||||
index.md
|
||||
Untitled.ipynb
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+610
@@ -0,0 +1,610 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9fa3470d-9448-4792-9f65-6978fc58cf84",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi-modal eval: Baseline\n",
|
||||
"\n",
|
||||
"`Multi-modal slide decks` is a public dataset that contains a dataset of question-answer pairs from slide decks with visual content.\n",
|
||||
"\n",
|
||||
"The question-answer pairs are derived from the visual content in the decks, testing the ability of RAG to perform visual reasoning.\n",
|
||||
"\n",
|
||||
"As a baseline, we evaluate this dataset using text-based RAG pipeline, below.\n",
|
||||
"\n",
|
||||
"This will not reason about visual content and will simply load the text from the slides. \n",
|
||||
"\n",
|
||||
"## Pre-requisites"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47220461-d4e9-4f1d-9c57-672ca947ca0d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install -U langchain langsmith langchain_benchmarks\n",
|
||||
"# %pip install --quiet chromadb openai pypdf pandas"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "196de967-6de6-40da-aa75-e836923ab5e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
|
||||
"env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\"]\n",
|
||||
"for var in env_vars:\n",
|
||||
" if var not in os.environ:\n",
|
||||
" os.environ[var] = getpass.getpass(prompt=f\"Enter your {var}: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "10da8e11-6288-4131-bd60-d5aa86928acc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset\n",
|
||||
"\n",
|
||||
"We can browse the available LangChain benchmark datasets for retrieval."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "2ff97905-14a6-413c-99be-58b7a9c8d4c1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Dataset ID </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>LangChain Docs Q&A </td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d\" target=\"_blank\" rel=\"noopener\">452ccafc-18e1-4314-885b-edd735f17b9d</a></td><td>Questions and answers based on a snapshot of the LangChain python docs.\n",
|
||||
"\n",
|
||||
"The environment provides the documents and the retriever information.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer.\n",
|
||||
"We also measure the faithfulness of the model's response relative to the retrieved documents (if any). </td></tr>\n",
|
||||
"<tr><td>Semi-structured Reports</td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d\" target=\"_blank\" rel=\"noopener\">c47d9617-ab99-4d6e-a6e6-92b8daf85a7d</a></td><td>Questions and answers based on PDFs containing tables and charts.\n",
|
||||
"\n",
|
||||
"The task provides the raw documents as well as factory methods to easily index them\n",
|
||||
"and create a retriever.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer.\n",
|
||||
"We also measure the faithfulness of the model's response relative to the retrieved documents (if any). </td></tr>\n",
|
||||
"<tr><td>Multi-modal slide decks</td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d\" target=\"_blank\" rel=\"noopener\">40afc8e7-9d7e-44ed-8971-2cae1eb59731</a></td><td>This public dataset is a work-in-progress and will be extended over time.\n",
|
||||
" \n",
|
||||
"Questions and answers based on slide decks containing visual tables and charts.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer. </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"Registry(tasks=[RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_cached_docs at 0x104485800>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1360289a0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x136028a40>, 'hyde': <function _chroma_hyde_retriever_factory at 0x136028ae0>}, architecture_factories={'conversational-retrieval-qa': <function default_response_chain at 0x126ba2660>}), RetrievalTask(name='Semi-structured Reports', dataset_id='https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d', description=\"Questions and answers based on PDFs containing tables and charts.\\n\\nThe task provides the raw documents as well as factory methods to easily index them\\nand create a retriever.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_docs at 0x136029620>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1360296c0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x136029760>, 'hyde': <function _chroma_hyde_retriever_factory at 0x136029800>}, architecture_factories={}), RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})])"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, registry\n",
|
||||
"\n",
|
||||
"registry = registry.filter(Type=\"RetrievalTask\")\n",
|
||||
"registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2fb7dc3d-28f1-4c28-b0d0-3784d04b81ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`Multi-modal slide decks` is the relevant dataset for our task."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "219a4141-4a5f-48e4-ae05-5a824e2193fd",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>Name </td><td>Multi-modal slide decks </td></tr>\n",
|
||||
"<tr><td>Type </td><td>RetrievalTask </td></tr>\n",
|
||||
"<tr><td>Dataset ID </td><td><a href=\"https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d\" target=\"_blank\" rel=\"noopener\">40afc8e7-9d7e-44ed-8971-2cae1eb59731</a></td></tr>\n",
|
||||
"<tr><td>Description </td><td>This public dataset is a work-in-progress and will be extended over time.\n",
|
||||
" \n",
|
||||
"Questions and answers based on slide decks containing visual tables and charts.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer. </td></tr>\n",
|
||||
"<tr><td>Retriever Factories </td><td> </td></tr>\n",
|
||||
"<tr><td>Architecture Factories</td><td> </td></tr>\n",
|
||||
"<tr><td>get_docs </td><td>{} </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"task = registry[\"Multi-modal slide decks\"]\n",
|
||||
"task"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d6569b5-e79a-41b7-9745-c2f8a1dd704e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Clone the dataset so that it's available in our LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d2caa086-9549-4c74-bba9-ba80d5a7b218",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset Multi-modal slide decks already exists. Skipping.\n",
|
||||
"You can access the dataset at https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clone_public_dataset(task.dataset_id, dataset_name=task.name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf350917-a1e5-46f4-81cd-c1678ab9220f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Fetch the associated PDFs from remote cache for the dataset so that we can perform ingestion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "99ce6afb-2317-4bc1-9faf-4f828095ad91",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks.rag.tasks.multi_modal_slide_decks import get_file_names\n",
|
||||
"\n",
|
||||
"file_names = list(get_file_names()) # PosixPath"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "848a4cdb-6c08-4c01-81ce-16ab83a7fdff",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load\n",
|
||||
"\n",
|
||||
"Load and split the files for indexing."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6ce85810-98a7-406e-b44e-ce860ac35986",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"There are 98 text elements in DDOG_Q3_earnings_deck.pdf\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.document_loaders import PyPDFLoader\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_and_split(file):\n",
|
||||
" \"\"\"\n",
|
||||
" Load and split PDF files\n",
|
||||
" :param file: PosixPath path for pdf\n",
|
||||
" :return: A list of text chunks\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" loader = PyPDFLoader(str(file))\n",
|
||||
" pdf_pages = loader.load()\n",
|
||||
"\n",
|
||||
" text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
|
||||
" chunk_size=100, chunk_overlap=50\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Get chunks\n",
|
||||
" docs = text_splitter.split_documents(pdf_pages)\n",
|
||||
" texts = [d.page_content for d in docs]\n",
|
||||
" print(f\"There are {len(texts)} text elements in {file.name}\")\n",
|
||||
" return texts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"texts = []\n",
|
||||
"for fi in file_names:\n",
|
||||
" texts.extend(load_and_split(fi))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb01925d-b7d1-47a1-bd90-805178d3c4a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Index\n",
|
||||
"\n",
|
||||
"Embed (OpenAIEmbeddings) and store splits in a vectorstore (Chroma)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ceb31f71-45fb-4b12-bc1c-31981de334bb",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"\n",
|
||||
"vectorstore_baseline = Chroma.from_texts(\n",
|
||||
" texts=texts, collection_name=\"baseline-multi-modal\", embedding=OpenAIEmbeddings()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"retriever_baseline = vectorstore_baseline.as_retriever()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6dcbb01-f480-456d-b972-c732eb26c393",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## RAG\n",
|
||||
"\n",
|
||||
"Create a pipeline for retrieval of relevant chunks based on semantic similarity to the input question.\n",
|
||||
"\n",
|
||||
"Pass the images to GPT-4 for answer synthesis."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "ea233664-e527-42f1-a820-0c2271e16c20",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rag_chain(retriever):\n",
|
||||
" \"\"\"\n",
|
||||
" RAG pipeline for the indexed presentations\n",
|
||||
" :param retriever: PosixPath path for pdf\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt template\n",
|
||||
" template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
|
||||
" {context}\n",
|
||||
" Question: {question}\n",
|
||||
" \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
" # LLM\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"\n",
|
||||
" # RAG pipeline\n",
|
||||
" chain = (\n",
|
||||
" {\n",
|
||||
" \"context\": retriever | (lambda x: \"\\n\\n\".join([i.page_content for i in x])),\n",
|
||||
" \"question\": RunnablePassthrough(),\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
" )\n",
|
||||
" return chain\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create RAG chain\n",
|
||||
"chain = rag_chain(retriever_baseline)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "95df1446-143d-4f4c-a15b-2a379266d8bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Eval\n",
|
||||
"\n",
|
||||
"Run evaluation on our dataset:\n",
|
||||
"\n",
|
||||
"* `task.name` is the dataset of QA pairs that we cloned\n",
|
||||
"* `eval_config` specifies the [LangSmith evaluator](https://docs.smith.langchain.com/evaluation/evaluator-implementations#correctness-qa-evaluation) for our dataset, which will use GPT-4 as a grader\n",
|
||||
"* The grader will evaluate the chain-generated answer to each question relative to ground truth"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "479ce09d-642e-4b3b-9e4e-e9c2b7f0e9ca",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"View the evaluation results for project '866f-baseline' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306/compare?selectedSessions=30199d47-50d7-4c5c-a55a-e74157e05951\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Multi-modal slide decks at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306\n",
|
||||
"[------------------------------------------------->] 10/10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<h3>Experiment Results:</h3>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>output</th>\n",
|
||||
" <th>feedback.COT Contextual Accuracy</th>\n",
|
||||
" <th>error</th>\n",
|
||||
" <th>execution_time</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>count</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>10.000000</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>10.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>unique</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>top</th>\n",
|
||||
" <td>Datadog has 20 total customers.</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>freq</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.200000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.674478</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>std</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.421637</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.864273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>min</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>3.307960</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>25%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.113816</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>50%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.700962</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>75%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>5.018359</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>max</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>1.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>6.188082</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" output feedback.COT Contextual Accuracy \\\n",
|
||||
"count 10 10.000000 \n",
|
||||
"unique 10 NaN \n",
|
||||
"top Datadog has 20 total customers. NaN \n",
|
||||
"freq 1 NaN \n",
|
||||
"mean NaN 0.200000 \n",
|
||||
"std NaN 0.421637 \n",
|
||||
"min NaN 0.000000 \n",
|
||||
"25% NaN 0.000000 \n",
|
||||
"50% NaN 0.000000 \n",
|
||||
"75% NaN 0.000000 \n",
|
||||
"max NaN 1.000000 \n",
|
||||
"\n",
|
||||
" error execution_time \n",
|
||||
"count 0 10.000000 \n",
|
||||
"unique 0 NaN \n",
|
||||
"top NaN NaN \n",
|
||||
"freq NaN NaN \n",
|
||||
"mean NaN 4.674478 \n",
|
||||
"std NaN 0.864273 \n",
|
||||
"min NaN 3.307960 \n",
|
||||
"25% NaN 4.113816 \n",
|
||||
"50% NaN 4.700962 \n",
|
||||
"75% NaN 5.018359 \n",
|
||||
"max NaN 6.188082 "
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.smith import RunEvalConfig\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"# Evaluator configuration\n",
|
||||
"client = Client()\n",
|
||||
"eval_config = RunEvalConfig(\n",
|
||||
" evaluators=[\"cot_qa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Experiments\n",
|
||||
"chain_map = {\n",
|
||||
" \"baseline\": chain,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Run evaluation\n",
|
||||
"run_id = uuid.uuid4().hex[:4]\n",
|
||||
"test_runs = {}\n",
|
||||
"for project_name, chain in chain_map.items():\n",
|
||||
" test_runs[project_name] = client.run_on_dataset(\n",
|
||||
" dataset_name=task.name,\n",
|
||||
" llm_or_chain_factory=lambda: (lambda x: x[\"Question\"]) | chain,\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" verbose=True,\n",
|
||||
" project_name=f\"{run_id}-{project_name}\",\n",
|
||||
" project_metadata={\"chain\": project_name},\n",
|
||||
" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
+319
@@ -0,0 +1,319 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b6856d11-40d5-48e5-9eb3-423f479933a1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Semi-structured eval: Chunk size tuning\n",
|
||||
"\n",
|
||||
"`Semi-structured Reports` is a public dataset that contains question-answer pairs from documents with text and tables.\n",
|
||||
"\n",
|
||||
"The question-answer pairs are derived from the tables as well as some of the paragraphs in the docs.\n",
|
||||
"\n",
|
||||
"We evaluation performance of various chunk sizes with RAG. \n",
|
||||
"\n",
|
||||
"## Pre-requisites"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c387b660-967d-4d2f-8c38-af125f7b7a8b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install -U langchain langsmith langchain_benchmarks\n",
|
||||
"# %pip install --quiet chromadb openai pypdf tiktoken fireworks-ai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9e332b1-7da4-47fc-8d9a-4d65fbfc6953",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
|
||||
"env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\", \"FIREWORKS_API_KEY\"]\n",
|
||||
"for var in env_vars:\n",
|
||||
" if var not in os.environ:\n",
|
||||
" os.environ[var] = getpass.getpass(prompt=f\"Enter your {var}: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b1a19f23-468c-4aeb-a0e9-0765a85f3f0b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset\n",
|
||||
"\n",
|
||||
"Fetch the associated PDFs from remote cache for the dataset so that we can perform ingestion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a94d9aa5-acd8-4032-ad8f-f995dec4d13c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain_benchmarks import clone_public_dataset, registry\n",
|
||||
"from langchain_benchmarks.rag.tasks.semi_structured_reports import get_file_names\n",
|
||||
"\n",
|
||||
"# Task\n",
|
||||
"task = registry[\"Semi-structured Reports\"]\n",
|
||||
"\n",
|
||||
"# Files used\n",
|
||||
"paths = list(get_file_names())\n",
|
||||
"files = [str(p) for p in paths]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12b52285-358c-4752-ad6b-25ffb629e309",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Clone the dataset so that it's available in our LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "1ecca7af-c3e7-42d1-97dd-c7d9777207cb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset Semi-structured Reports already exists. Skipping.\n",
|
||||
"You can access the dataset at https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/datasets/6549a3a5-1cb9-463f-951d-0166cb9cf45c.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clone_public_dataset(task.dataset_id, dataset_name=task.name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "64f37705-0190-4b7a-9d88-63bfd904fbd9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load and index\n",
|
||||
"\n",
|
||||
"We load each file, split it, embed with `OpenAIEmbeddings`, and create an index with `Chroma` vectorstore."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7eb9e333-77e6-48f9-b221-9bded023b978",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks.manager import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain.chat_models import ChatFireworks, ChatOpenAI\n",
|
||||
"from langchain.document_loaders import PyPDFLoader\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_and_split(file, token_count, split_document=True):\n",
|
||||
" \"\"\"\n",
|
||||
" Load and optionally split PDF files.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" file (str): File path.\n",
|
||||
" token_count (int): Token count for splitting.\n",
|
||||
" split_document (bool): Flag for splitting or returning pages.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" loader = PyPDFLoader(file)\n",
|
||||
" pdf_pages = loader.load()\n",
|
||||
"\n",
|
||||
" if split_document:\n",
|
||||
" text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
|
||||
" chunk_size=token_count, chunk_overlap=50\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" docs = text_splitter.split_documents(pdf_pages)\n",
|
||||
" texts = [d.page_content for d in docs]\n",
|
||||
" else:\n",
|
||||
" texts = [d.page_content for d in pdf_pages]\n",
|
||||
"\n",
|
||||
" print(f\"There are {len(texts)} text elements\")\n",
|
||||
" return texts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_files(files, token_count, split_document):\n",
|
||||
" \"\"\"\n",
|
||||
" Load files.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" files (list): List of file names.\n",
|
||||
" dir (str): Directory path.\n",
|
||||
" token_count (int): Token count for splitting.\n",
|
||||
" split_document (bool): Flag for splitting documents.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" texts = []\n",
|
||||
" for fi in files:\n",
|
||||
" texts.extend(load_and_split(fi, token_count, split_document))\n",
|
||||
" return texts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_retriever(texts, expt):\n",
|
||||
" \"\"\"\n",
|
||||
" Make vector store.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" texts (list): List of texts.\n",
|
||||
" expt (str): Experiment name.\n",
|
||||
" \"\"\"\n",
|
||||
" vectorstore = Chroma.from_texts(\n",
|
||||
" texts=texts, collection_name=expt, embedding=OpenAIEmbeddings()\n",
|
||||
" )\n",
|
||||
" retriever = vectorstore.as_retriever()\n",
|
||||
" return retriever\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rag_chain(retriever, llm):\n",
|
||||
" \"\"\"\n",
|
||||
" RAG chain.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" retriever: The retriever to use.\n",
|
||||
" llm: The llm to use.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt template\n",
|
||||
" template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
|
||||
" {context}\n",
|
||||
" Question: {question}\n",
|
||||
" \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
" # LLM\n",
|
||||
" if llm == \"mixtral\":\n",
|
||||
" model = ChatFireworks(\n",
|
||||
" model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", temperature=0\n",
|
||||
" )\n",
|
||||
" else:\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"\n",
|
||||
" # RAG pipeline\n",
|
||||
" chain = (\n",
|
||||
" {\n",
|
||||
" \"context\": retriever | (lambda x: \"\\n\\n\".join([i.page_content for i in x])),\n",
|
||||
" \"question\": RunnablePassthrough(),\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
" )\n",
|
||||
" return chain\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Experiment configurations\n",
|
||||
"experiments = [\n",
|
||||
" (None, False, \"page_split-oai\", \"oai\"),\n",
|
||||
" (50, True, \"50_tok_split-oai\", \"oai\"),\n",
|
||||
" (100, True, \"100_tok_split-oai\", \"oai\"),\n",
|
||||
" (250, True, \"250_tok_split-oai\", \"oai\"),\n",
|
||||
" (250, True, \"250_tok_split-mixtral\", \"mixtral\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Run\n",
|
||||
"stor_chain = {}\n",
|
||||
"for token_count, split_document, expt, llm in experiments:\n",
|
||||
" texts = load_files(files, token_count, split_document)\n",
|
||||
" retriever = make_retriever(texts, expt)\n",
|
||||
" stor_chain[expt] = rag_chain(retriever, llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29515a91-3cb1-41bd-a2d4-6cf6ce7806c2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Eval\n",
|
||||
"\n",
|
||||
"Run eval onm our dataset, `Semi-structured Reports`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "edd2e7f9-b3f6-4885-bf05-96f1c1758b20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.smith import RunEvalConfig\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"# Config\n",
|
||||
"client = Client()\n",
|
||||
"eval_config = RunEvalConfig(\n",
|
||||
" evaluators=[\"cot_qa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Experiments\n",
|
||||
"chain_map = {\n",
|
||||
" \"page_split\": stor_chain[\"page_split-oai\"],\n",
|
||||
" \"baseline-50-tok\": stor_chain[\"50_tok_split-oai\"],\n",
|
||||
" \"baseline-100-tok\": stor_chain[\"100_tok_split-oai\"],\n",
|
||||
" \"baseline-250-tok\": stor_chain[\"250_tok_split-oai\"],\n",
|
||||
" \"baseline-250-tok-mixtral\": stor_chain[\"250_tok_split-mixtral\"],\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Run evaluation\n",
|
||||
"run_id = uuid.uuid4().hex[:4]\n",
|
||||
"test_runs = {}\n",
|
||||
"for project_name, chain in chain_map.items():\n",
|
||||
" test_runs[project_name] = client.run_on_dataset(\n",
|
||||
" dataset_name=task.name,\n",
|
||||
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"]) | chain,\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" verbose=True,\n",
|
||||
" project_name=f\"{run_id}-{project_name}\",\n",
|
||||
" project_metadata={\"chain\": project_name},\n",
|
||||
" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
+365
File diff suppressed because one or more lines are too long
+434
@@ -0,0 +1,434 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7cd0617a-4d00-4c4c-a5df-abc3430e7897",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Semi-structured eval: Multi vector\n",
|
||||
"\n",
|
||||
"`Semi-structured Reports` is a public dataset that contains question-answer pairs from documents with text and tables.\n",
|
||||
"\n",
|
||||
"The question-answer pairs are derived from the tables as well as some of the paragraphs in the docs.\n",
|
||||
"\n",
|
||||
"We evaluation performance using multi-vector retriever for RAG. \n",
|
||||
"\n",
|
||||
"## Pre-requisites"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4edd540d-705f-4042-9ed0-aee42d29f37d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install -U langchain langsmith langchain_benchmarks\n",
|
||||
"# %pip install --quiet chromadb openai pypdf tiktoken"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29031433-53db-43bb-ab1a-8ac1721661e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
|
||||
"env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\"]\n",
|
||||
"for var in env_vars:\n",
|
||||
" if var not in os.environ:\n",
|
||||
" os.environ[var] = getpass.getpass(prompt=f\"Enter your {var}: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b560e044-f5ac-418b-b3d6-164b423ab23b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset\n",
|
||||
"\n",
|
||||
"Fetch the associated PDFs from remote cache for the dataset so that we can perform ingestion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "76f8b0e3-693a-4eed-98e7-c0fa9ba02ff9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain_benchmarks import clone_public_dataset, registry\n",
|
||||
"from langchain_benchmarks.rag.tasks.semi_structured_reports import get_file_names\n",
|
||||
"\n",
|
||||
"# Task\n",
|
||||
"task = registry[\"Semi-structured Reports\"]\n",
|
||||
"\n",
|
||||
"# Files used\n",
|
||||
"paths = list(get_file_names())\n",
|
||||
"files = [str(p) for p in paths]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "720016d6-9206-4560-9b12-5881dbcabeb3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Clone the dataset so that it's available in our LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1e2309e4-0b35-477b-80a6-d4cb06ca4310",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clone_public_dataset(task.dataset_id, dataset_name=task.name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fb1db618-05c4-4253-a54b-1c554dd0dc78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load and index\n",
|
||||
"\n",
|
||||
"We build a retriever that focuses on tables. \n",
|
||||
"\n",
|
||||
"To do this, we use an LLM to scan each page and summarize any tables within the page. \n",
|
||||
"\n",
|
||||
"We then index those summaries for retrieval and store the raw page text containing the table with [multi-vector retriever](https://blog.langchain.dev/semi-structured-multi-modal-rag/). \n",
|
||||
"\n",
|
||||
"Finally, we use [ensemble retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/ensemble) to mix retrieved table chunks with the raw text chunks: \n",
|
||||
"\n",
|
||||
"* Combines the rankings from different retrievers into a single, unified ranking.\n",
|
||||
"* Each retriever provides a list of documents (or search results) ranked based on their relevance to the query.\n",
|
||||
"* The weights represent the relative importance or trust you place in each retriever's results.\n",
|
||||
"* The weights are used to scale the contribution of each retriever to the final combined ranking.\n",
|
||||
"* The RRF method uses the rank of each item in the lists provided by the retrievers.\n",
|
||||
"* The basic idea is to give higher scores to items that are ranked higher (i.e., have a lower rank number) in the lists."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3d14be7d-30c8-4084-afad-3e82c3fbf9e0",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.document_loaders import PyPDFLoader\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.retrievers import EnsembleRetriever\n",
|
||||
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
|
||||
"from langchain.schema.document import Document\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnableLambda, RunnablePassthrough\n",
|
||||
"from langchain.storage import InMemoryStore\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def prepare_documents(docs):\n",
|
||||
" \"\"\"\n",
|
||||
" Prepare documents for prompt. Concatenates Document objects (after extracting their page_content)\n",
|
||||
" and strings into a single string, separated by two newlines.\n",
|
||||
"\n",
|
||||
" :param docs: A list of str or Document objects.\n",
|
||||
" :return: A single string containing all documents.\n",
|
||||
" \"\"\"\n",
|
||||
" # Process each document and append it to the list\n",
|
||||
" processed_docs = [\n",
|
||||
" doc.page_content if isinstance(doc, Document) else doc for doc in docs\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # Join all processed documents into a single string\n",
|
||||
" return \"\\n\\n\".join(processed_docs)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_multi_vector_retriever(vectorstore, text_summaries, texts):\n",
|
||||
" \"\"\"\n",
|
||||
" Create retriever that indexes summaries, but returns raw images or texts\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Initialize the storage layer\n",
|
||||
" store = InMemoryStore()\n",
|
||||
" id_key = \"doc_id\"\n",
|
||||
"\n",
|
||||
" # Create the multi-vector retriever\n",
|
||||
" retriever = MultiVectorRetriever(\n",
|
||||
" vectorstore=vectorstore,\n",
|
||||
" docstore=store,\n",
|
||||
" id_key=id_key,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Helper function to add documents to the vectorstore and docstore\n",
|
||||
" def add_documents(retriever, doc_summaries, doc_contents):\n",
|
||||
" doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n",
|
||||
" summary_docs = [\n",
|
||||
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
|
||||
" for i, s in enumerate(doc_summaries)\n",
|
||||
" ]\n",
|
||||
" retriever.vectorstore.add_documents(summary_docs)\n",
|
||||
" retriever.docstore.mset(list(zip(doc_ids, doc_contents)))\n",
|
||||
"\n",
|
||||
" # Add texts, tables, and images\n",
|
||||
" add_documents(retriever, text_summaries, texts)\n",
|
||||
" return retriever\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def generate_doc_summary(file):\n",
|
||||
" \"\"\"\n",
|
||||
" Create a doc summary\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt\n",
|
||||
" prompt_text = \"\"\"You are an assistant tasked extracting two attributes \\\n",
|
||||
" from financial documents. (1) Tell me the company that the document is \\\n",
|
||||
" focused on. (2) Look at any tables in the document and tell me the units \\ \n",
|
||||
" of the table. Many table will have '(In thousands)' or '(in millions)' prior \\\n",
|
||||
" to the table text. Provide these two for the document: \\n\\n {document} \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(prompt_text)\n",
|
||||
"\n",
|
||||
" # Text summary chain\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
|
||||
" summarize_chain = {\"document\": lambda x: x} | prompt | model | StrOutputParser()\n",
|
||||
"\n",
|
||||
" # Load doc\n",
|
||||
" loader = PyPDFLoader(file)\n",
|
||||
" pdf_pages = loader.load()\n",
|
||||
" texts = [t.page_content for t in pdf_pages]\n",
|
||||
" text_string = \" \".join(texts)\n",
|
||||
" summary = summarize_chain.invoke({\"document\": text_string})\n",
|
||||
" return summary\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def generate_table_summaries(texts):\n",
|
||||
" \"\"\"\n",
|
||||
" Summarize text elements\n",
|
||||
" texts: List of str\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt\n",
|
||||
" prompt_text = \"\"\"You are an assistant tasked with summarizing tables within a provided text chunk. \\\n",
|
||||
" If the text chunk contains tables, then give a brief summary of the table and list the row and column \\\n",
|
||||
" names to identify what is captured in the table. Do not sumnmarize quantitative results in the table. \\ \n",
|
||||
" If there is no table present, then just return \"No table\". \\n\\n Text: {element} \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(prompt_text)\n",
|
||||
"\n",
|
||||
" # Text summary chain\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
" summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n",
|
||||
"\n",
|
||||
" # Initialize empty summaries\n",
|
||||
" text_summaries = []\n",
|
||||
" text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})\n",
|
||||
"\n",
|
||||
" return text_summaries\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_and_split(file, token_count, split_document=True):\n",
|
||||
" \"\"\"\n",
|
||||
" Load and optionally split PDF files.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" file (str): File path.\n",
|
||||
" token_count (int): Token count for splitting.\n",
|
||||
" split_document (bool): Flag for splitting or returning pages.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" loader = PyPDFLoader(file)\n",
|
||||
" pdf_pages = loader.load()\n",
|
||||
"\n",
|
||||
" if split_document:\n",
|
||||
" text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
|
||||
" chunk_size=token_count, chunk_overlap=50\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" docs = text_splitter.split_documents(pdf_pages)\n",
|
||||
" texts = [d.page_content for d in docs]\n",
|
||||
" else:\n",
|
||||
" texts = [d.page_content for d in pdf_pages]\n",
|
||||
"\n",
|
||||
" print(f\"There are {len(texts)} text elements\")\n",
|
||||
" return texts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_files(files, token_count, split_document):\n",
|
||||
" \"\"\"\n",
|
||||
" Load files.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" files (list): List of file names.\n",
|
||||
" dir (str): Directory path.\n",
|
||||
" token_count (int): Token count for splitting.\n",
|
||||
" split_document (bool): Flag for splitting documents.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" texts = []\n",
|
||||
" for fi in files:\n",
|
||||
" doc_summary = generate_doc_summary(fi)\n",
|
||||
" texts.extend(load_and_split(fi, token_count, split_document))\n",
|
||||
" return texts, doc_summary\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rag_chain(retriever):\n",
|
||||
" \"\"\"\n",
|
||||
" RAG chain.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" retriever: The retriever to use.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt template\n",
|
||||
" template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
|
||||
" {context}\n",
|
||||
" Question: {question}\n",
|
||||
" \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
" # LLM\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"\n",
|
||||
" # RAG pipeline\n",
|
||||
" chain = (\n",
|
||||
" {\n",
|
||||
" \"context\": retriever | RunnableLambda(prepare_documents),\n",
|
||||
" \"question\": RunnablePassthrough(),\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
" )\n",
|
||||
" return chain\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Experiment configurations\n",
|
||||
"experiments = [\n",
|
||||
" (None, False, \"page_split_multivector\"),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Run\n",
|
||||
"stor_chain = {}\n",
|
||||
"for token_count, split_document, expt in experiments:\n",
|
||||
" # Get texts and doc summary\n",
|
||||
" doc_texts, doc_summary = load_files(files, token_count, split_document)\n",
|
||||
"\n",
|
||||
" # Get table summaries\n",
|
||||
" doc_table_summaries = generate_table_summaries(doc_texts)\n",
|
||||
"\n",
|
||||
" # Add doc summary to table summary to preserve context\n",
|
||||
" doc_text_summaries = [\n",
|
||||
" \"Here is a summary of the doc: \\n\\n\"\n",
|
||||
" + doc_summary\n",
|
||||
" + \"\\n\\n Here is a summary of a table within this doc: \\n\\n\"\n",
|
||||
" + t\n",
|
||||
" for t in doc_table_summaries\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # The vectorstore to use to index the summaries\n",
|
||||
" vectorstore = Chroma(collection_name=expt, embedding_function=OpenAIEmbeddings())\n",
|
||||
"\n",
|
||||
" # Create our table retriever\n",
|
||||
" table_retriever = create_multi_vector_retriever(\n",
|
||||
" vectorstore, doc_table_summaries, doc_texts\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Create our docs retriever\n",
|
||||
" vectorstore_docs = Chroma.from_texts(\n",
|
||||
" texts=doc_texts, collection_name=expt + \"docs\", embedding=OpenAIEmbeddings()\n",
|
||||
" )\n",
|
||||
" docs_retriever = vectorstore_docs.as_retriever()\n",
|
||||
"\n",
|
||||
" # Initialize ensemble retriever\n",
|
||||
" ensemble_retriever = EnsembleRetriever(\n",
|
||||
" retrievers=[table_retriever, docs_retriever], weights=[0.75, 0.25]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Chain\n",
|
||||
" stor_chain[expt] = rag_chain(ensemble_retriever)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "77aeb2e2-156d-4a39-be93-4f401f1df455",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Eval\n",
|
||||
"\n",
|
||||
"Run eval onm our dataset, `Semi-structured Reports`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "55fd91b5-6b8e-4bb5-b97a-42ccc5dd53dd",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.smith import RunEvalConfig\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"# Config\n",
|
||||
"client = Client()\n",
|
||||
"eval_config = RunEvalConfig(\n",
|
||||
" evaluators=[\"cot_qa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Experiments\n",
|
||||
"chain_map = {\n",
|
||||
" \"page_split_multivector_emsemble\": stor_chain[\"page_split_multivector\"],\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Run evaluation\n",
|
||||
"run_id = uuid.uuid4().hex[:4]\n",
|
||||
"test_runs = {}\n",
|
||||
"for project_name, chain in chain_map.items():\n",
|
||||
" test_runs[project_name] = client.run_on_dataset(\n",
|
||||
" dataset_name=task.name,\n",
|
||||
" llm_or_chain_factory=lambda: (lambda x: x[\"Question\"]) | chain,\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" verbose=True,\n",
|
||||
" project_name=f\"{run_id}-{project_name}\",\n",
|
||||
" project_metadata={\"chain\": project_name},\n",
|
||||
" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ba9f105-c48f-4d8c-8253-355ef13156b0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Benchmark All\n",
|
||||
"\n",
|
||||
"Here, we'll run benchmarking against all tool usage task.\n",
|
||||
"\n",
|
||||
"Expand the models list to benchmark against different models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "13a7483b-d08f-49fa-83da-619863171e5b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import datetime\n",
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.globals import set_verbose\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"from langchain_benchmarks import (\n",
|
||||
" __version__,\n",
|
||||
" clone_public_dataset,\n",
|
||||
" model_registry,\n",
|
||||
" registry,\n",
|
||||
")\n",
|
||||
"from langchain_benchmarks.rate_limiting import RateLimiter\n",
|
||||
"from langchain_benchmarks.tool_usage.agents import (\n",
|
||||
" CustomAgentFactory,\n",
|
||||
" OpenAIAgentFactory,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50bbe23b-a3b1-4607-929d-ea6e88b7085e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prior to starting the tests, you may want to verify\n",
|
||||
"that the task that you're working with and the models are propelry defined."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "adfbcaa9-349c-4223-89be-4abff9cf76ff",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': \"Repeat the given string using the provided tools. Do not write anything else or provide any explanations. For example, if the string is 'abc', you must print the letters 'a', 'b', and 'c' one at a time and in that order. \\nWrite down your answer, but do not explain it. Input: `abc`\",\n",
|
||||
" 'output': ' Thank you for the input and for confirming the output of each letter I printed. I simply followed the instructions to repeat the given string \"abc\" by printing one letter at a time using the provided \"type_letter\" tool without any additional explanations. Please let me know if you need me to repeat this process with a different input string.',\n",
|
||||
" 'intermediate_steps': [(AgentActionMessageLog(tool='type_letter', tool_input={'letter': 'a'}, log=\"\\nInvoking type_letter: {'letter': 'a'}\\n\\t\", message_log=[AIMessage(content='<tool>{\\n \"tool_name\": \"type_letter\",\\n \"arguments\": {\\n \"letter\": \"a\"\\n }\\n}</tool>\\n')]),\n",
|
||||
" 'OK'),\n",
|
||||
" (AgentActionMessageLog(tool='type_letter', tool_input={'letter': 'b'}, log=\"\\nInvoking type_letter: {'letter': 'b'}\\n\\t\", message_log=[AIMessage(content='<tool>{\\n \"tool_name\": \"type_letter\",\\n \"arguments\": {\\n \"letter\": \"b\"\\n }\\n}</tool>\\n')]),\n",
|
||||
" 'OK'),\n",
|
||||
" (AgentActionMessageLog(tool='type_letter', tool_input={'letter': 'c'}, log=\"\\nInvoking type_letter: {'letter': 'c'}\\n\\t\", message_log=[AIMessage(content='<tool>{\\n \"tool_name\": \"type_letter\",\\n \"arguments\": {\\n \"letter\": \"c\"\\n }\\n}</tool>\\n')]),\n",
|
||||
" 'OK')],\n",
|
||||
" 'state': 'abc'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"task = registry[\"Tool Usage - Typewriter (1 tool)\"]\n",
|
||||
"agent_factory = CustomAgentFactory(task, \"claude-2.1\")\n",
|
||||
"\n",
|
||||
"agent_factory().invoke({\"question\": \"abc\"})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "65b32e7d-3986-4461-8a3b-8e9b6d4008cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define the test cases"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "26d390b6-9ade-424c-aabb-d450f52ed121",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tests = [\n",
|
||||
" # 2-tuple of (architecture, model name)\n",
|
||||
" (\"xml\", \"mixtral-8x7b-instruct-fw\"),\n",
|
||||
" (\"xml\", \"claude-2.1\"),\n",
|
||||
" (\"xml\", \"claude-2\"),\n",
|
||||
" (\"xml\", \"yi-34b-200k-fw\"),\n",
|
||||
" (\"xml\", \"llama-v2-70b-chat-fw\"),\n",
|
||||
" (\"xml\", \"llama-v2-13b-chat-fw\"),\n",
|
||||
" (\"openai_functions\", \"gpt-3.5-turbo-1106\"),\n",
|
||||
" (\"openai_functions\", \"gpt-3.5-turbo-0613\"),\n",
|
||||
" (\"openai_functions\", \"gpt-4-1106-preview\")(\"openai_functions\", \"gpt-4-0613\"),\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b55b7c24-8b4d-4bd7-8b00-365fbe61897f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "a415dd82-2e70-4173-a3f3-8e1aac60db9e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"experiment_uuid = uuid.uuid4().hex[:4]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6fbc3ef-7a3f-430f-8b79-45af5861b3ee",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"client = Client() # Launch langsmith client for cloning datasets\n",
|
||||
"today = datetime.date.today().isoformat()\n",
|
||||
"rate_limiter = RateLimiter(requests_per_second=1)\n",
|
||||
"\n",
|
||||
"for task in registry:\n",
|
||||
" dataset_name = task.name + f\"_benchmarking_{today}\"\n",
|
||||
" clone_public_dataset(task.dataset_id, dataset_name=dataset_name)\n",
|
||||
"\n",
|
||||
" if task.type != \"ToolUsageTask\":\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" for arch, model in tests:\n",
|
||||
" print()\n",
|
||||
" print(f\"Benchmarking {task.name} with model: {model} and arch: {arch}\")\n",
|
||||
" eval_config = task.get_eval_config()\n",
|
||||
"\n",
|
||||
" if arch == \"openai_functions\":\n",
|
||||
" agent_factory = OpenAIAgentFactory(\n",
|
||||
" task, model=model, rate_limiter=rate_limiter\n",
|
||||
" )\n",
|
||||
" elif arch == \"xml\":\n",
|
||||
" agent_factory = CustomAgentFactory(\n",
|
||||
" task, model=model, rate_limiter=rate_limiter\n",
|
||||
" )\n",
|
||||
" else:\n",
|
||||
" raise ValueError()\n",
|
||||
"\n",
|
||||
" client.run_on_dataset(\n",
|
||||
" dataset_name=dataset_name,\n",
|
||||
" llm_or_chain_factory=agent_factory,\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" verbose=False,\n",
|
||||
" project_name=f\"{model}{experiment_uuid}\",\n",
|
||||
" tags=[model],\n",
|
||||
" concurrency_level=5,\n",
|
||||
" project_metadata={\n",
|
||||
" \"model\": model,\n",
|
||||
" \"id\": experiment_uuid,\n",
|
||||
" \"task\": task.name,\n",
|
||||
" \"date\": today,\n",
|
||||
" \"langchain_benchmarks_version\": __version__,\n",
|
||||
" \"arch\": arch,\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
" break"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "60bb467d-861d-4b07-a48d-8e5aa177c969",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"# Evaluating OSS Models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8e3b729e-b851-4ab8-a3a9-be34b329b985",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"For this code to work, please configure LangSmith environment variables with your credentials.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"ls_..\" # Your LangSmith API key\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "666a8246-b1a9-47ce-b159-d950692fc06b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"keys = [\"LANGCHAIN_API_KEY\", \"FIREWORKS_API_KEY\"]\n",
|
||||
"for key in keys:\n",
|
||||
" if not os.environ.get(key):\n",
|
||||
" os.environ[key] = getpass(f\"Set {key}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "92d65770-6a4f-4029-beba-5fa9aeb18809",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Agent Factory\n",
|
||||
"\n",
|
||||
"For evaluation, we need an agent factory that will create a new instance of an agent executor for every evaluation run.\n",
|
||||
"\n",
|
||||
"We'll use an custom AgentFactory provided with LangChain Benchmarks -- look at the `intro` section to see how to define your own.\n",
|
||||
"\n",
|
||||
"We will use the Fireworks API for this."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "a35cbf20-7632-4116-9c6c-cee6e4a98068",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from functools import partial\n",
|
||||
"from typing import Sequence, Tuple\n",
|
||||
"\n",
|
||||
"from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent\n",
|
||||
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
|
||||
"from langchain.agents.structured_chat.output_parser import (\n",
|
||||
" AgentAction,\n",
|
||||
" AgentFinish,\n",
|
||||
" StructuredChatOutputParser,\n",
|
||||
")\n",
|
||||
"from langchain.chains.openai_functions.base import convert_to_openai_function\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.output_parsers.json import SimpleJsonOutputParser, parse_json_markdown\n",
|
||||
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.tools import tool\n",
|
||||
"from langchain_core.runnables import RunnableLambda\n",
|
||||
"\n",
|
||||
"from langchain_benchmarks import clone_public_dataset, registry\n",
|
||||
"from langchain_benchmarks.schema import BaseTask, RegisteredModel\n",
|
||||
"from langchain_benchmarks.tool_usage import apply_agent_executor_adapter\n",
|
||||
"from langchain_benchmarks.tool_usage.agents import apply_agent_executor_adapter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def final_answer(answer: str) -> str:\n",
|
||||
" \"\"\"The final answer to the question.\"\"\"\n",
|
||||
" return answer\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def extract_first_json_object(text):\n",
|
||||
" # A hacky FSM to get the first JSON object across newlines\n",
|
||||
" OUTSIDE, INSIDE, IN_STRING = range(3)\n",
|
||||
"\n",
|
||||
" state = OUTSIDE\n",
|
||||
" nested_level = 0\n",
|
||||
" start_index = None\n",
|
||||
"\n",
|
||||
" def is_escaped(index):\n",
|
||||
" escape = False\n",
|
||||
" while index > 0 and text[index - 1] == \"\\\\\":\n",
|
||||
" escape = not escape\n",
|
||||
" index -= 1\n",
|
||||
" return escape\n",
|
||||
"\n",
|
||||
" for i, char in enumerate(text):\n",
|
||||
" if state == OUTSIDE:\n",
|
||||
" if char == \"{\":\n",
|
||||
" state = INSIDE\n",
|
||||
" nested_level = 1\n",
|
||||
" start_index = i\n",
|
||||
"\n",
|
||||
" elif state == INSIDE:\n",
|
||||
" if char == '\"' and not is_escaped(i):\n",
|
||||
" state = IN_STRING\n",
|
||||
" elif char == \"{\":\n",
|
||||
" nested_level += 1\n",
|
||||
" elif char == \"}\":\n",
|
||||
" nested_level -= 1\n",
|
||||
" if nested_level == 0:\n",
|
||||
" return text[start_index : i + 1]\n",
|
||||
"\n",
|
||||
" elif state == IN_STRING:\n",
|
||||
" if char == '\"' and not is_escaped(i):\n",
|
||||
" state = INSIDE\n",
|
||||
"\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def parse(message, prefix: str = \"\") -> dict:\n",
|
||||
" content = prefix + message.content.replace(\"\\_\", \"_\")\n",
|
||||
" content = extract_first_json_object(content)\n",
|
||||
" try:\n",
|
||||
" response = json.loads(content)\n",
|
||||
" except json.JSONDecodeError:\n",
|
||||
" response = parse_json_markdown(content)\n",
|
||||
" if response[\"action\"] == \"final_answer\":\n",
|
||||
" return AgentFinish({\"output\": response[\"action_input\"]}, content)\n",
|
||||
" else:\n",
|
||||
" return AgentAction(\n",
|
||||
" response[\"action\"],\n",
|
||||
" response.get(\"action_input\", {}),\n",
|
||||
" content,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def format_intermediate_steps(\n",
|
||||
" intermediate_steps: Sequence[Tuple[AgentAction, str]],\n",
|
||||
") -> str:\n",
|
||||
" if not intermediate_steps:\n",
|
||||
" return \"\"\n",
|
||||
"\n",
|
||||
" # response_tmpl = \"{action}\\n{{\\\"response\\\": \\\"{observation}\\\"}}\"\n",
|
||||
" response_tmpl = \"{action}\\n# Returned {observation}\"\n",
|
||||
" serialized = \"\\n\".join(\n",
|
||||
" [\n",
|
||||
" # f\"{agent_action.log.strip()}\\n{{\\\"response\\\": \\\"{observation}\\\"}}\"\n",
|
||||
" response_tmpl.format(\n",
|
||||
" action=agent_action.log.strip(), observation=observation\n",
|
||||
" )\n",
|
||||
" for agent_action, observation in intermediate_steps\n",
|
||||
" ]\n",
|
||||
" )\n",
|
||||
" return f\"\"\"\n",
|
||||
"```log.txt\n",
|
||||
"{serialized}\n",
|
||||
"```\n",
|
||||
"Consider previous steps above. What's your next step?\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def format_scratchpad(x):\n",
|
||||
" intermediate_steps = x[\"intermediate_steps\"]\n",
|
||||
" return format_intermediate_steps(intermediate_steps)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class AgentFactory:\n",
|
||||
" def __init__(\n",
|
||||
" self, task: BaseTask, model: RegisteredModel, num_retries: int = 5\n",
|
||||
" ) -> None:\n",
|
||||
" self.task = task\n",
|
||||
" self.model = model\n",
|
||||
" self.num_retries = num_retries\n",
|
||||
"\n",
|
||||
" def create_this_ugly_thing(self, env):\n",
|
||||
" tools = env.tools\n",
|
||||
"\n",
|
||||
" # schemas = []\n",
|
||||
" # for tool in tools + [final_answer]:\n",
|
||||
" # function_def = convert_to_openai_function(tool.args_schema)\n",
|
||||
" # function_def[\"name\"] = tool.name\n",
|
||||
" # schemas.append(function_def)\n",
|
||||
" # tools_str = \"\\n\".join([json.dumps(sc) for sc in schemas])\n",
|
||||
" tools_str = \"\\n\".join([tool.description for tool in tools + [final_answer]])\n",
|
||||
" messages = [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" f\"Task Instructions: {self.task.instructions}\\n\\n\"\n",
|
||||
" \"The following tools are exposed via an API:\\n\"\n",
|
||||
" \"{tools}\\n\\n\"\n",
|
||||
" \"Respond with one JSONL line to make your next action and call the API of a single tool.\"\n",
|
||||
" \"\"\" Format invocations like this:\n",
|
||||
"{{\"action\": \"tool name\",\"action_input\": {{TOOL BODY}}}}\n",
|
||||
"\\n\\nUse the final_answer tool only once you know the correct answer and have called the tools required for the task.\"\"\",\n",
|
||||
" ),\n",
|
||||
" (\n",
|
||||
" \"user\",\n",
|
||||
" \"{input}{agent_scratchpad}\\n\\nNote: Remember to respond in 1 JSONL line.\",\n",
|
||||
" ),\n",
|
||||
" ]\n",
|
||||
" parse_fn = parse\n",
|
||||
" if self.model.type == \"llm\":\n",
|
||||
" messages += [(\"assistant\", \"{{\")]\n",
|
||||
" # Fill it back in\n",
|
||||
" parse_fn = partial(parse_fn, prefix=\"{\")\n",
|
||||
" prompt = ChatPromptTemplate.from_messages(messages)\n",
|
||||
" prompt = prompt.partial(tools=tools_str)\n",
|
||||
"\n",
|
||||
" llm = self.model.get_model(model_params={\"temperature\": 0}).bind(stop=[\"\\n\\n\"])\n",
|
||||
" if self.num_retries:\n",
|
||||
" llm = llm.with_retry(stop_after_attempt=self.num_retries)\n",
|
||||
"\n",
|
||||
" @RunnableLambda\n",
|
||||
" def empty_fallback(x):\n",
|
||||
" \"\"\"Return an empty response to avoid misleading metrics.\"\"\"\n",
|
||||
" return {\n",
|
||||
" \"intermediate_steps\": [],\n",
|
||||
" \"state\": None,\n",
|
||||
" \"output\": \"ERROR\",\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" agent = (\n",
|
||||
" {\n",
|
||||
" \"input\": lambda x: x[\"input\"],\n",
|
||||
" \"agent_scratchpad\": format_scratchpad,\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | llm\n",
|
||||
" | parse_fn\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return AgentExecutor(\n",
|
||||
" agent=agent, tools=tools, return_intermediate_steps=True\n",
|
||||
" ).with_fallbacks([empty_fallback])\n",
|
||||
"\n",
|
||||
" def __call__(self):\n",
|
||||
" # This factory creates a new environment for every agent run.\n",
|
||||
" # The reason is that the environment may be associated with an environment state (e.g., typewriter)\n",
|
||||
" # which is changed by the actions of the agent.\n",
|
||||
" # At the end of the run, the environment state will be read.\n",
|
||||
" env = self.task.create_environment()\n",
|
||||
" executor = self.create_this_ugly_thing(env)\n",
|
||||
" # Apply the adapters so that inputs and outputs match dataset schema\n",
|
||||
" # state_reader automatically adds the state of the environment at the end of the run.\n",
|
||||
" return apply_agent_executor_adapter(executor, state_reader=env.read_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3821e4b0-8e67-418a-840c-470fcde42df0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Eval\n",
|
||||
"\n",
|
||||
"Let's evaluate an agent now"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "fd6cead0-3c37-4a73-8795-7819220797ee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks.model_registration import model_registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "fb32763c-79ab-426a-8fc6-bf8ebb0dd432",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"[-------> ] 3/20View the evaluation results for project 'mixtral-8x7b-fw-chat-ece3-Tool Usage - Typewriter (1 tool)' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/82ca6840-cf23-4bb0-a9be-55237ebbe9d3/compare?selectedSessions=2b92de52-2830-40cb-a396-4c08e0bf1c9b\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Typewriter (1 tool) at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/82ca6840-cf23-4bb0-a9be-55237ebbe9d3\n",
|
||||
"[------------------------------------------------->] 20/20\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-ece3-Tool Usage - Typewriter (1 tool)' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/82ca6840-cf23-4bb0-a9be-55237ebbe9d3/compare?selectedSessions=ff797831-aee8-43db-a814-7727f9240006\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Typewriter (1 tool) at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/82ca6840-cf23-4bb0-a9be-55237ebbe9d3\n",
|
||||
"[------------------------------------------------->] 20/20\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-fw-chat-ece3-Tool Usage - Typewriter (26 tools)' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/2f462c7a-f9b9-46e7-b96b-7469e965f478/compare?selectedSessions=1adbc135-93d9-46b2-a33a-e5470eded263\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Typewriter (26 tools) at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/2f462c7a-f9b9-46e7-b96b-7469e965f478\n",
|
||||
"[------------------------------------------------->] 20/20\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-ece3-Tool Usage - Typewriter (26 tools)' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/2f462c7a-f9b9-46e7-b96b-7469e965f478/compare?selectedSessions=a8548cef-4afd-4f7e-9d21-7bd2fb3f9033\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Typewriter (26 tools) at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/2f462c7a-f9b9-46e7-b96b-7469e965f478\n",
|
||||
"[------------------------------------------------->] 20/20\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-fw-chat-ece3-Tool Usage - Relational Data' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/df6be6c9-05b3-445e-8836-ebb4aba63826/compare?selectedSessions=685df1fb-605d-40e3-b645-ae132a0a6229\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Relational Data at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/df6be6c9-05b3-445e-8836-ebb4aba63826\n",
|
||||
"[------------------------------------------------->] 21/21\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-ece3-Tool Usage - Relational Data' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/df6be6c9-05b3-445e-8836-ebb4aba63826/compare?selectedSessions=bb4d1ee4-bbc8-4969-a4f0-2b0732444785\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Tool Usage - Relational Data at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/df6be6c9-05b3-445e-8836-ebb4aba63826\n",
|
||||
"[------------------------------------------------->] 21/21\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-fw-chat-ece3-Multiverse Math' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/108bdc68-1808-4b60-92ef-fbd9bd7e1ad0/compare?selectedSessions=ac7ec5aa-108d-4c5b-9c30-8e954fa132aa\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Multiverse Math at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/108bdc68-1808-4b60-92ef-fbd9bd7e1ad0\n",
|
||||
"[------------------------------------------------->] 10/10\n",
|
||||
"View the evaluation results for project 'mixtral-8x7b-ece3-Multiverse Math' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/108bdc68-1808-4b60-92ef-fbd9bd7e1ad0/compare?selectedSessions=9d8573ee-847f-400a-8894-2e77c62e76ab\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Multiverse Math at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/108bdc68-1808-4b60-92ef-fbd9bd7e1ad0\n",
|
||||
"[------------------------------------------------->] 10/10"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"from langchain_benchmarks.tool_usage import get_eval_config\n",
|
||||
"\n",
|
||||
"experiment_uuid = uuid.uuid4().hex[:4]\n",
|
||||
"\n",
|
||||
"client = Client()\n",
|
||||
"\n",
|
||||
"task_names = [task.name for task in registry.filter(Type=\"ToolUsageTask\")]\n",
|
||||
"models = [\"mixtral-8x7b-fw-chat\", \"mixtral-8x7b\"]\n",
|
||||
"\n",
|
||||
"for task_name in task_names:\n",
|
||||
" for model_name in models:\n",
|
||||
" print()\n",
|
||||
" model = model_registry[model_name]\n",
|
||||
" task = registry[task_name]\n",
|
||||
" clone_public_dataset(task.dataset_id, dataset_name=task.name)\n",
|
||||
" eval_config = task.get_eval_config()\n",
|
||||
" test_run = client.run_on_dataset(\n",
|
||||
" dataset_name=task.name,\n",
|
||||
" llm_or_chain_factory=AgentFactory(task, model),\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" project_name=f\"{model.name}-{experiment_uuid}-{task.name}\",\n",
|
||||
" tags=[model.name],\n",
|
||||
" project_metadata={\"id\": experiment_uuid, **model.params},\n",
|
||||
" verbose=True,\n",
|
||||
" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
:caption: Introduction
|
||||
|
||||
./notebooks/getting_started
|
||||
./notebooks/models
|
||||
./notebooks/datasets
|
||||
```
|
||||
|
||||
@@ -33,6 +34,11 @@
|
||||
|
||||
./notebooks/retrieval/intro
|
||||
./notebooks/retrieval/langchain_docs_qa
|
||||
./notebooks/retrieval/semi_structured
|
||||
./notebooks/retrieval/semi_structured_benchmarking/semi_structured
|
||||
./notebooks/retrieval/semi_structured_benchmarking/ss_eval_chunk_sizes
|
||||
./notebooks/retrieval/semi_structured_benchmarking/ss_eval_long_context
|
||||
./notebooks/retrieval/semi_structured_benchmarking/ss_eval_multi_vector
|
||||
./notebooks/retrieval/multi_modal_benchmarking/multi_modal_eval_baseline
|
||||
./notebooks/retrieval/multi_modal_benchmarking/multi_modal_eval
|
||||
./notebooks/retrieval/comparing_techniques
|
||||
```
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_benchmarks.model_registration import model_registry
|
||||
from langchain_benchmarks.rate_limiting import RateLimiter
|
||||
from langchain_benchmarks.registration import registry
|
||||
from langchain_benchmarks.utils._langsmith import (
|
||||
clone_public_dataset,
|
||||
download_public_dataset,
|
||||
)
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
|
||||
# Please keep this list sorted!
|
||||
__all__ = ["clone_public_dataset", "download_public_dataset", "registry"]
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"clone_public_dataset",
|
||||
"download_public_dataset",
|
||||
"model_registry",
|
||||
"RateLimiter",
|
||||
"registry",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_benchmarks.schema import ModelRegistry, RegisteredModel
|
||||
|
||||
_OPEN_AI_MODELS = [
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-1106",
|
||||
type="chat",
|
||||
description=(
|
||||
"The latest GPT-3.5 Turbo model with improved instruction following, "
|
||||
"JSON mode, reproducible outputs, parallel function calling, and more. "
|
||||
"Returns a maximum of 4,096 output tokens."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-instruct",
|
||||
type="llm",
|
||||
description=(
|
||||
"Similar capabilities as text-davinci-003 but compatible with legacy "
|
||||
"Completions endpoint and not Chat Completions."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0301",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. "
|
||||
"Will be deprecated on June 13th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0301",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-003",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Can do language tasks with better quality and consistency than "
|
||||
"the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-003",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-002",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Similar capabilities to text-davinci-003 but trained with "
|
||||
"supervised fine-tuning instead of reinforcement learning. "
|
||||
"Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-002",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="code-davinci-002",
|
||||
type="llm",
|
||||
description="Legacy Optimized for code-completion tasks. Will be deprecated "
|
||||
"on Jan 4th 2024.",
|
||||
params={
|
||||
"model": "code-davinci-002",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-4-1106-preview",
|
||||
type="chat",
|
||||
description="GPT-4 TurboNew - The latest GPT-4 model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. This preview model is not yet suited for production traffic.",
|
||||
params={
|
||||
"model": "gpt-4-1106-preview",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-4-0613",
|
||||
type="chat",
|
||||
description="Snapshot of gpt-4 from June 13th 2023 with improved function calling support.",
|
||||
params={
|
||||
"model": "gpt-4-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-4-32k-0613",
|
||||
type="chat",
|
||||
description="Snapshot of gpt-4-32k from June 13th 2023 with improved function calling support.",
|
||||
params={
|
||||
"model": "gpt-4-32k-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-4-0314",
|
||||
description="Snapshot of gpt-4 from March 14th 2023 with function calling support. This model version will be deprecated on June 13th 2024.",
|
||||
type="chat",
|
||||
params={
|
||||
"model": "gpt-4-0314",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-4-32k-0314",
|
||||
description="Snapshot of gpt-4-32k from March 14th 2023 with function calling support. This model version will be deprecated on June 13th 2024.",
|
||||
type="chat",
|
||||
params={
|
||||
"model": "gpt-4-32k-0314",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
_FIREWORKS_MODELS = [
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-7b-chat-fw",
|
||||
type="chat",
|
||||
description="7b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-7b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-13b-chat-fw",
|
||||
type="chat",
|
||||
description="13b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-13b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-70b-chat-fw",
|
||||
type="chat",
|
||||
description="70b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-70b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="yi-34b-200k-fw",
|
||||
type="llm",
|
||||
description=" 4B LLM model from 01.ai, with context window 200k.",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/yi-34b-200k",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="mixtral-8x7b-instruct-fw",
|
||||
description="Mistral MoE 8x7B Instruct v0.1 model with Sparse "
|
||||
"Mixture of Experts. Fine tuned for instruction following",
|
||||
type="llm",
|
||||
params={"model": "accounts/fireworks/models/mixtral-8x7b-instruct"},
|
||||
),
|
||||
]
|
||||
|
||||
_ANTHROPIC_MODELS = [
|
||||
RegisteredModel(
|
||||
provider="anthropic",
|
||||
name="claude-2",
|
||||
description=("Superior performance on tasks that require complex reasoning"),
|
||||
type="chat",
|
||||
params={
|
||||
"model": "claude-2",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="anthropic",
|
||||
name="claude-2.1",
|
||||
description=(
|
||||
"Same performance as Claude 2, plus significant reduction in model "
|
||||
"hallucination rates"
|
||||
),
|
||||
type="chat",
|
||||
params={
|
||||
"model": "claude-2.1",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="anthropic",
|
||||
name="claude-instant-1.2",
|
||||
description="low-latency, high throughput.",
|
||||
type="chat",
|
||||
params={
|
||||
"model": "claude-instant-1.2",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="anthropic",
|
||||
name="claude-instant-1",
|
||||
description="low-latency, high throughput.",
|
||||
type="chat",
|
||||
params={
|
||||
"model": "claude-instant-1",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
model_registry = ModelRegistry(
|
||||
registered_models=_OPEN_AI_MODELS + _FIREWORKS_MODELS + _ANTHROPIC_MODELS
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
pdfs/
|
||||
@@ -1,7 +1,14 @@
|
||||
from langchain_benchmarks.rag.tasks.langchain_docs.task import LANGCHAIN_DOCS_TASK
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.task import (
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
)
|
||||
from langchain_benchmarks.rag.tasks.semi_structured_reports.task import (
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
)
|
||||
|
||||
# Please keep this sorted
|
||||
__all__ = ["LANGCHAIN_DOCS_TASK", "SEMI_STRUCTURED_REPORTS_TASK"]
|
||||
__all__ = [
|
||||
"LANGCHAIN_DOCS_TASK",
|
||||
"SEMI_STRUCTURED_REPORTS_TASK",
|
||||
"MULTI_MODAL_SLIDE_DECKS_TASK",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.indexing.retriever_registry import (
|
||||
get_file_names,
|
||||
)
|
||||
|
||||
__all__ = ["get_file_names"]
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.indexing.retriever_registry import (
|
||||
get_file_names,
|
||||
)
|
||||
|
||||
__all__ = ["get_file_names"]
|
||||
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from langchain_benchmarks.rag.utils._downloading import (
|
||||
fetch_remote_file,
|
||||
is_folder_populated,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DIRECTORY = Path(os.path.abspath(__file__)).parent
|
||||
# Stores the zipped pdfs for this dataset
|
||||
REMOTE_DOCS_FILE = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/multi_modal_slide_decks.zip"
|
||||
DOCS_DIR = _DIRECTORY / "pdfs"
|
||||
|
||||
|
||||
def fetch_raw_docs(
|
||||
filename: Optional[str] = None, docs_dir: Optional[str] = None
|
||||
) -> None:
|
||||
filename = filename or _DIRECTORY / Path(REMOTE_DOCS_FILE).name
|
||||
docs_dir = docs_dir or DOCS_DIR
|
||||
if not is_folder_populated(docs_dir):
|
||||
fetch_remote_file(REMOTE_DOCS_FILE, filename)
|
||||
with zipfile.ZipFile(filename, "r") as zip_ref:
|
||||
zip_ref.extractall(docs_dir)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def get_file_names() -> Iterable[Path]:
|
||||
fetch_raw_docs()
|
||||
# Traverse the directory and partition the pdfs
|
||||
for path in DOCS_DIR.rglob("*.pdf"):
|
||||
# Ignore __MACOSX
|
||||
if "__MACOSX" in str(path):
|
||||
continue
|
||||
yield path
|
||||
@@ -0,0 +1,23 @@
|
||||
from langchain_benchmarks.schema import RetrievalTask
|
||||
|
||||
# ID of public Multi Modal Slide Decks dataset
|
||||
DATASET_ID = "https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d"
|
||||
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK = RetrievalTask(
|
||||
name="Multi-modal slide decks",
|
||||
dataset_id=DATASET_ID,
|
||||
retriever_factories={},
|
||||
architecture_factories={},
|
||||
get_docs={},
|
||||
description=(
|
||||
"""\
|
||||
This public dataset is a work-in-progress and will be extended over time.
|
||||
|
||||
Questions and answers based on slide decks containing visual tables and charts.
|
||||
|
||||
Each example is composed of a question and reference answer.
|
||||
|
||||
Success is measured based on the accuracy of the answer relative to the reference answer.
|
||||
""" # noqa: E501
|
||||
),
|
||||
)
|
||||
+3
-4
@@ -24,7 +24,6 @@ _DIRECTORY = Path(os.path.abspath(__file__)).parent
|
||||
# Stores the zipped pdfs for this dataset
|
||||
REMOTE_DOCS_FILE = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/semi_structured_earnings.zip"
|
||||
DOCS_DIR = _DIRECTORY / "pdfs"
|
||||
LOCAL_FILE = _DIRECTORY / "chroma_db.zip"
|
||||
|
||||
_DEFAULT_SEARCH_KWARGS = {"k": 6}
|
||||
|
||||
@@ -32,17 +31,17 @@ _DEFAULT_SEARCH_KWARGS = {"k": 6}
|
||||
def fetch_raw_docs(
|
||||
filename: Optional[str] = None, docs_dir: Optional[str] = None
|
||||
) -> None:
|
||||
filename = filename or LOCAL_FILE
|
||||
filename = filename or _DIRECTORY / Path(REMOTE_DOCS_FILE).name
|
||||
docs_dir = docs_dir or DOCS_DIR
|
||||
if not is_folder_populated(docs_dir):
|
||||
fetch_remote_file(REMOTE_DOCS_FILE, filename)
|
||||
with zipfile.ZipFile(filename, "r") as zip_ref:
|
||||
zip_ref.extractall(docs_dir)
|
||||
|
||||
os.remove(LOCAL_FILE)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def get_file_names():
|
||||
def get_file_names() -> Iterable[Path]:
|
||||
fetch_raw_docs()
|
||||
# Traverse the directory and partition the pdfs
|
||||
for path in DOCS_DIR.glob("*.pdf"):
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Implementation of a rate limiter based on a token bucket."""
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema.runnable import Runnable, RunnableLambda
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
requests_per_second: float = 1,
|
||||
check_every_n_seconds: float = 0.1,
|
||||
max_bucket_size: float = 1,
|
||||
) -> None:
|
||||
"""A rate limiter based on a token bucket.
|
||||
|
||||
These *tokens* have NOTHING to do with LLM tokens. They are just
|
||||
a way to keep track of how many requests can be made at a given time.
|
||||
|
||||
This rate limiter is designed to work in a threaded environment.
|
||||
|
||||
It works by filling up a bucket with tokens at a given rate. Each
|
||||
request consumes a given number of tokens. If there are not enough
|
||||
tokens in the bucket, the request is blocked until there are enough
|
||||
tokens.
|
||||
|
||||
Args:
|
||||
requests_per_second: The number of tokens to add per second to the bucket.
|
||||
Must be at least 1. The tokens represent "credit" that can be used
|
||||
to make requests.
|
||||
check_every_n_seconds: check whether the tokens are available
|
||||
every this many seconds. Can be a float to represent
|
||||
fractions of a second.
|
||||
max_bucket_size: The maximum number of tokens that can be in the bucket.
|
||||
This is used to prevent bursts of requests.
|
||||
"""
|
||||
# Number of requests that we can make per second.
|
||||
self.requests_per_second = requests_per_second
|
||||
# Number of tokens in the bucket.
|
||||
self.available_tokens = 0.0
|
||||
self.max_bucket_size = max_bucket_size
|
||||
# A lock to ensure that tokens can only be consumed by one thread
|
||||
# at a given time.
|
||||
self._consume_lock = threading.Lock()
|
||||
# The last time we tried to consume tokens.
|
||||
self.last: Optional[time.time] = None
|
||||
self.check_every_n_seconds = check_every_n_seconds
|
||||
|
||||
def _consume(self) -> bool:
|
||||
"""Consume the given amount of tokens if possible.
|
||||
|
||||
Returns:
|
||||
True means that the tokens were consumed, and the caller can proceed to
|
||||
make the request. A False means that the tokens were not consumed, and
|
||||
the caller should try again later.
|
||||
"""
|
||||
with self._consume_lock:
|
||||
now = time.time()
|
||||
|
||||
# initialize on first call to avoid a burst
|
||||
if self.last is None:
|
||||
self.last = now
|
||||
|
||||
elapsed = now - self.last
|
||||
|
||||
if elapsed * self.requests_per_second >= 1:
|
||||
self.available_tokens += elapsed * self.requests_per_second
|
||||
self.last = now
|
||||
|
||||
# Make sure that we don't exceed the bucket size.
|
||||
# This is used to prevent bursts of requests.
|
||||
self.available_tokens = min(self.available_tokens, self.max_bucket_size)
|
||||
|
||||
# As long as we have at least one token, we can proceed.
|
||||
if self.available_tokens >= 1:
|
||||
self.available_tokens -= 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def wait(self) -> None:
|
||||
"""Blocking call to wait until the given number of tokens are available."""
|
||||
while not self._consume():
|
||||
time.sleep(self.check_every_n_seconds)
|
||||
|
||||
|
||||
def with_rate_limit(
|
||||
runnable: Runnable[Input, Output],
|
||||
rate_limiter: RateLimiter,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""Add a rate limiter to the runnable.
|
||||
|
||||
Args:
|
||||
runnable: The runnable to throttle.
|
||||
rate_limiter: The throttle to use.
|
||||
|
||||
Returns:
|
||||
A runnable lambda that acts as a throttled passthrough.
|
||||
"""
|
||||
|
||||
def _wait(input: dict, **kwargs: Any) -> dict:
|
||||
"""Wait for the rate limiter to allow the request to proceed."""
|
||||
rate_limiter.wait()
|
||||
return input
|
||||
|
||||
return RunnableLambda(_wait).with_config({"name": "Wait"}) | runnable
|
||||
@@ -3,6 +3,7 @@
|
||||
from langchain_benchmarks.extraction.tasks import chat_extraction, email_task
|
||||
from langchain_benchmarks.rag.tasks import (
|
||||
LANGCHAIN_DOCS_TASK,
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
)
|
||||
from langchain_benchmarks.schema import Registry
|
||||
@@ -24,5 +25,6 @@ registry = Registry(
|
||||
chat_extraction.CHAT_EXTRACTION_TASK,
|
||||
LANGCHAIN_DOCS_TASK,
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2,16 +2,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import urllib
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.smith import RunEvalConfig
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
from tabulate import tabulate
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@@ -67,7 +71,7 @@ class BaseTask:
|
||||
"""Return a table representation of the environment."""
|
||||
return [
|
||||
["Name", self.name],
|
||||
["Type", self.__class__.__name__],
|
||||
["Type", self.type],
|
||||
["Dataset ID", self._dataset_link],
|
||||
["Description", self.description],
|
||||
]
|
||||
@@ -79,6 +83,11 @@ class BaseTask:
|
||||
tablefmt="unsafehtml",
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Return the type of the task."""
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ToolUsageTask(BaseTask):
|
||||
@@ -90,6 +99,27 @@ class ToolUsageTask(BaseTask):
|
||||
instructions: str
|
||||
"""Instructions for the agent/chain/llm."""
|
||||
|
||||
eval_params: Dict[str, Any]
|
||||
"""Used to parameterize differences in the evaluation of the task.
|
||||
|
||||
These are passed to the standard factory method for creating an evaluator
|
||||
for tool usage.
|
||||
|
||||
An example, for MultiVerse math the `output_evaluation` parameter is set to
|
||||
`qa_math` to use a different prompt for evaluating the output of the agent.
|
||||
|
||||
This prompt performs better at comparing the output of the agent against
|
||||
the reference output.
|
||||
"""
|
||||
|
||||
def get_eval_config(self, **params: Any) -> RunEvalConfig:
|
||||
"""Get the default evaluator for the environment."""
|
||||
# Import locally to avoid potential circular imports in the future.
|
||||
from langchain_benchmarks.tool_usage.evaluators import get_eval_config
|
||||
|
||||
finalized_params = {**self.eval_params, **params}
|
||||
return get_eval_config(**finalized_params)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ExtractionTask(BaseTask):
|
||||
@@ -109,12 +139,16 @@ class ExtractionTask(BaseTask):
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RetrievalTask(BaseTask):
|
||||
retriever_factories: Dict[str, Callable[[Embeddings], BaseRetriever]] # noqa: F821
|
||||
"""Factories that index the docs using the specified strategy."""
|
||||
architecture_factories: Dict[str, Callable[[Embeddings], BaseRetriever]] # noqa: F821
|
||||
"""Factories methods that help build some off-the-shelf architectures。"""
|
||||
get_docs: Callable[..., Iterable[Document]]
|
||||
get_docs: Optional[Callable[..., Iterable[Document]]] = None
|
||||
"""A function that returns the documents to be indexed."""
|
||||
retriever_factories: Dict[
|
||||
str, Callable[[Embeddings], BaseRetriever]
|
||||
] = dataclasses.field(default_factory=dict) # noqa: F821
|
||||
"""Factories that index the docs using the specified strategy."""
|
||||
architecture_factories: Dict[
|
||||
str, Callable[[Embeddings], BaseRetriever]
|
||||
] = dataclasses.field(default_factory=dict) # noqa: F821
|
||||
"""Factories methods that help build some off-the-shelf architectures。"""
|
||||
|
||||
@property
|
||||
def _table(self) -> List[List[str]]:
|
||||
@@ -149,6 +183,15 @@ class Registry:
|
||||
raise ValueError(
|
||||
f"Duplicate task name {task.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(task.name)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of tasks in the registry."""
|
||||
return len(self.tasks)
|
||||
|
||||
def __iter__(self) -> Iterable[BaseTask]:
|
||||
"""Iterate over the tasks in the registry."""
|
||||
return iter(self.tasks)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
@@ -192,10 +235,10 @@ class Registry:
|
||||
]
|
||||
return Registry(tasks=tasks)
|
||||
|
||||
def __getitem__(self, key: Union[int, str]) -> BaseTask:
|
||||
def __getitem__(self, key: Union[int, str, slice]) -> Union[BaseTask, Registry]:
|
||||
"""Get an environment from the registry."""
|
||||
if isinstance(key, slice):
|
||||
raise NotImplementedError("Slicing is not supported.")
|
||||
return Registry(tasks=self.tasks[key])
|
||||
elif isinstance(key, (int, str)):
|
||||
# If key is an integer, return the corresponding environment
|
||||
return self.get_task(key)
|
||||
@@ -206,3 +249,222 @@ class Registry:
|
||||
if not isinstance(task, BaseTask):
|
||||
raise TypeError("Only tasks can be added to the registry.")
|
||||
self.tasks.append(task)
|
||||
|
||||
|
||||
Provider = Literal["fireworks", "openai", "anthropic"]
|
||||
ModelType = Literal["chat", "llm"]
|
||||
AUTHORIZED_NAMESPACES = {"langchain"}
|
||||
|
||||
|
||||
def _get_model_class_from_path(
|
||||
path: str,
|
||||
) -> Union[Type[BaseChatModel], Type[BaseLanguageModel]]:
|
||||
"""Get the class of the model."""
|
||||
module_name, attribute_name = path.rsplit(".", 1)
|
||||
top_namespace = path.split(".")[0]
|
||||
|
||||
if top_namespace not in AUTHORIZED_NAMESPACES:
|
||||
raise ValueError(
|
||||
f"Unauthorized namespace {top_namespace}. "
|
||||
f"Authorized namespaces are: {AUTHORIZED_NAMESPACES}"
|
||||
)
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(module_name)
|
||||
model_class = getattr(module, attribute_name)
|
||||
if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)):
|
||||
raise ValueError(
|
||||
f"Model class {model_class} is not a subclass of BaseLanguageModel"
|
||||
)
|
||||
return model_class
|
||||
|
||||
|
||||
def _get_default_path(provider: str, type_: ModelType) -> str:
|
||||
"""Get the default path for a model."""
|
||||
paths = {
|
||||
("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks",
|
||||
("fireworks", "llm"): "langchain.llms.fireworks.Fireworks",
|
||||
("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI",
|
||||
("openai", "llm"): "langchain.llms.openai.OpenAI",
|
||||
("anthropic", "chat"): "langchain.chat_models.anthropic.ChatAnthropic",
|
||||
}
|
||||
|
||||
if (provider, type_) not in paths:
|
||||
raise ValueError(f"Unknown provider {provider} and type {type_}")
|
||||
|
||||
return paths[(provider, type_)]
|
||||
|
||||
|
||||
def _get_default_url(provider: str, type_: ModelType) -> Optional[str]:
|
||||
"""Get default URL to API page for model."""
|
||||
if provider == "fireworks":
|
||||
return "https://app.fireworks.ai/models"
|
||||
elif provider == "openai":
|
||||
return "https://platform.openai.com/docs/models"
|
||||
elif provider == "anthropic":
|
||||
return "https://docs.anthropic.com/claude/reference/selecting-a-model"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RegisteredModel:
|
||||
"""Descriptive information about a model.
|
||||
|
||||
This information can be used to instantiate the underlying model.
|
||||
"""
|
||||
|
||||
name: str
|
||||
provider: Provider
|
||||
description: str
|
||||
params: Dict[str, Any]
|
||||
type: ModelType
|
||||
# Path to the model class.
|
||||
# For example, "langchain.chat_models.anthropic import ChatAnthropicModel"
|
||||
path: Optional[str] = None # If not provided, will use default path
|
||||
url: Optional[str] = None # If not provided, will use default URL
|
||||
|
||||
def get_model(
|
||||
self, *, model_params: Optional[Dict[str, Any]] = None
|
||||
) -> Union[BaseChatModel, BaseLanguageModel]:
|
||||
"""Get the class of the model."""
|
||||
all_params = {**self.params, **(model_params or {})}
|
||||
model_class = _get_model_class_from_path(self.model_path)
|
||||
return model_class(**all_params)
|
||||
|
||||
@property
|
||||
def model_path(self) -> str:
|
||||
"""Get the path of the model."""
|
||||
return self.path or _get_default_path(self.provider, self.type)
|
||||
|
||||
@property
|
||||
def model_url(self) -> Optional[str]:
|
||||
"""Get the URL of the model."""
|
||||
return self.url or _get_default_url(self.provider, self.type)
|
||||
|
||||
@property
|
||||
def _table(self) -> List[List[str]]:
|
||||
"""Return a table representation of the environment."""
|
||||
if self.model_path:
|
||||
url = (
|
||||
f'<a href="{self.model_path}" target="_blank" rel="noopener">'
|
||||
"ModelPage"
|
||||
"</a>"
|
||||
)
|
||||
else:
|
||||
url = ""
|
||||
return [
|
||||
["name", self.name],
|
||||
["type", self.type],
|
||||
["provider", self.provider],
|
||||
["description", self.description],
|
||||
["model_path", self.model_path],
|
||||
["url", url],
|
||||
]
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the environment."""
|
||||
return tabulate(
|
||||
self._table,
|
||||
tablefmt="unsafehtml",
|
||||
)
|
||||
|
||||
|
||||
StrFilter = Union[None, str, Sequence[str]]
|
||||
|
||||
|
||||
def _is_in_filter(actual_value: str, filter_value: StrFilter) -> bool:
|
||||
"""Filter for a string attribute."""
|
||||
if filter_value is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter_value, str):
|
||||
return actual_value == filter_value
|
||||
|
||||
return actual_value in filter_value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=False)
|
||||
class ModelRegistry:
|
||||
registered_models: Sequence[RegisteredModel]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that all the tasks have unique names and IDs."""
|
||||
seen_names = set()
|
||||
for model in self.registered_models:
|
||||
if model.name in seen_names:
|
||||
raise ValueError(
|
||||
f"Duplicate model name {model.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(model.name)
|
||||
|
||||
def get_model(self, name: str) -> Optional[RegisteredModel]:
|
||||
"""Get model info."""
|
||||
return next(model for model in self.registered_models if model.name == name)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
*,
|
||||
type: StrFilter = None,
|
||||
name: StrFilter = None,
|
||||
provider: StrFilter = None,
|
||||
) -> ModelRegistry:
|
||||
"""Filter the tasks in the registry."""
|
||||
models = self.registered_models
|
||||
selected_models = []
|
||||
|
||||
for model in models:
|
||||
if not _is_in_filter(model.type, type):
|
||||
continue
|
||||
if not _is_in_filter(model.name, name):
|
||||
continue
|
||||
if not _is_in_filter(model.provider, provider):
|
||||
continue
|
||||
selected_models.append(model)
|
||||
return ModelRegistry(registered_models=selected_models)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
headers = [
|
||||
"Name",
|
||||
"Type",
|
||||
"Provider",
|
||||
"Description",
|
||||
]
|
||||
table = [
|
||||
[
|
||||
model.name,
|
||||
model.type,
|
||||
model.provider,
|
||||
model.description,
|
||||
]
|
||||
for model in self.registered_models
|
||||
]
|
||||
return tabulate(table, headers=headers, tablefmt="unsafehtml")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of tasks in the registry."""
|
||||
return len(self.registered_models)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
"""Return whether the registry contains the given model."""
|
||||
return self.get_model(item) is not None
|
||||
|
||||
def __iter__(self) -> Iterable[RegisteredModel]:
|
||||
"""Iterate over the tasks in the registry."""
|
||||
return iter(self.registered_models)
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[int, str, slice]
|
||||
) -> Union[RegisteredModel, ModelRegistry]:
|
||||
"""Get an environment from the registry."""
|
||||
if isinstance(key, slice):
|
||||
return ModelRegistry(registered_models=self.registered_models[key])
|
||||
elif isinstance(key, (int, str)):
|
||||
# If key is an integer, return the corresponding environment
|
||||
if isinstance(key, str):
|
||||
return self.get_model(key)
|
||||
else:
|
||||
return self.registered_models[key]
|
||||
else:
|
||||
raise TypeError("Key must be an integer or a slice.")
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Code for creating an agent factory for evaluating tool usage tasks."""
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.format_scratchpad import format_to_openai_functions
|
||||
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema.runnable import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from langchain.tools.render import format_tool_to_openai_function
|
||||
|
||||
from langchain_benchmarks.schema import ToolUsageTask
|
||||
|
||||
|
||||
def _ensure_output_exists(inputs: dict) -> dict:
|
||||
"""Make sure that the output key is always present."""
|
||||
if "output" not in inputs:
|
||||
return {"output": "", **inputs}
|
||||
return inputs
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
class OpenAIAgentFactory:
|
||||
def __init__(
|
||||
self, task: ToolUsageTask, *, model: str = "gpt-3.5-turbo-16k"
|
||||
) -> None:
|
||||
"""Create an OpenAI agent factory for the given task.
|
||||
|
||||
Args:
|
||||
task: The task to create an agent factory for.
|
||||
model: The model to use -- this must be an open AI model.
|
||||
"""
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
def create(self) -> Runnable:
|
||||
"""Agent Executor"""
|
||||
# For backwards compatibility
|
||||
return self()
|
||||
|
||||
def __call__(self) -> Runnable:
|
||||
llm = ChatOpenAI(
|
||||
model=self.model,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
env = self.task.create_environment()
|
||||
|
||||
llm_with_tools = llm.bind(
|
||||
functions=[format_tool_to_openai_function(t) for t in env.tools]
|
||||
)
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
self.task.instructions,
|
||||
),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
)
|
||||
|
||||
runnable_agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"agent_scratchpad": lambda x: format_to_openai_functions(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
}
|
||||
| prompt
|
||||
| llm_with_tools
|
||||
| OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
|
||||
runnable = AgentExecutor(
|
||||
agent=runnable_agent,
|
||||
tools=env.tools,
|
||||
handle_parsing_errors=True,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
# Returns `state` in the output if the environment has a state reader
|
||||
# makes sure that `output` is always in the output
|
||||
return apply_agent_executor_adapter(runnable, state_reader=env.read_state)
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def apply_agent_executor_adapter(
|
||||
agent_executor: AgentExecutor,
|
||||
*,
|
||||
state_reader: Optional[Callable[[], Any]] = None,
|
||||
) -> Runnable:
|
||||
"""An adapter for the agent executor to standardize its input and output.
|
||||
|
||||
1) Map `question` to `input` (`question` is used in the datasets,
|
||||
but `input` is used in the agent executor)
|
||||
2) Ensure that `output` is always returned (will be set to "" if missing) --
|
||||
note that this may be relaxed after more updates in the eval config.
|
||||
3) Populate `state` key in the response of the agent with the system state
|
||||
if a state reader is provided.
|
||||
|
||||
Args:
|
||||
agent_executor: the agent executor
|
||||
state_reader: A callable without parameters that if invoked will return
|
||||
the state of the environment. Used to populate the 'state' key.
|
||||
|
||||
Returns:
|
||||
a new runnable with a standardized output.
|
||||
"""
|
||||
|
||||
def _read_state(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Read the state of the environment."""
|
||||
if state_reader is not None:
|
||||
return state_reader()
|
||||
else:
|
||||
return None
|
||||
|
||||
def _format_input(inputs: dict) -> dict:
|
||||
"""Make sure that the input is always called `input`."""
|
||||
|
||||
if "question" not in inputs:
|
||||
raise ValueError(
|
||||
"Expected 'question' to be in the inputs. Found only the following "
|
||||
f"keys {sorted(inputs.keys())}."
|
||||
)
|
||||
|
||||
inputs = inputs.copy() # Because 'question' is popped below
|
||||
|
||||
if "input" not in inputs:
|
||||
return {"input": inputs.pop("question"), **inputs}
|
||||
return inputs
|
||||
|
||||
runnable = (
|
||||
RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
|
||||
| agent_executor
|
||||
| RunnableLambda(_ensure_output_exists).with_config(
|
||||
{"run_name": "Ensure Output"}
|
||||
)
|
||||
)
|
||||
|
||||
if state_reader is not None:
|
||||
runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
|
||||
{"run_name": "Read Env State"}
|
||||
)
|
||||
return runnable
|
||||
@@ -0,0 +1,7 @@
|
||||
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.factory import (
|
||||
CustomAgentFactory,
|
||||
)
|
||||
from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory
|
||||
|
||||
__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
|
||||
|
||||
def _ensure_output_exists(inputs: dict) -> dict:
|
||||
"""Make sure that the output key is always present."""
|
||||
if "output" not in inputs:
|
||||
return {"output": "", **inputs}
|
||||
return inputs
|
||||
|
||||
|
||||
def apply_agent_executor_adapter(
|
||||
agent_executor: AgentExecutor,
|
||||
*,
|
||||
state_reader: Optional[Callable[[], Any]] = None,
|
||||
) -> Runnable:
|
||||
"""An adapter for the agent executor to standardize its input and output.
|
||||
|
||||
1) Map `question` to `input` (`question` is used in the datasets,
|
||||
but `input` is used in the agent executor)
|
||||
2) Ensure that `output` is always returned (will be set to "" if missing) --
|
||||
note that this may be relaxed after more updates in the eval config.
|
||||
3) Populate `state` key in the response of the agent with the system state
|
||||
if a state reader is provided.
|
||||
|
||||
Args:
|
||||
agent_executor: the agent executor
|
||||
state_reader: A callable without parameters that if invoked will return
|
||||
the state of the environment. Used to populate the 'state' key.
|
||||
|
||||
Returns:
|
||||
a new runnable with a standardized output.
|
||||
"""
|
||||
|
||||
def _read_state(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Read the state of the environment."""
|
||||
if state_reader is not None:
|
||||
return state_reader()
|
||||
else:
|
||||
return None
|
||||
|
||||
def _format_input(inputs: dict) -> dict:
|
||||
"""Make sure that the input is always called `input`."""
|
||||
|
||||
if "question" not in inputs:
|
||||
raise ValueError(
|
||||
"Expected 'question' to be in the inputs. Found only the following "
|
||||
f"keys {sorted(inputs.keys())}."
|
||||
)
|
||||
|
||||
inputs = inputs.copy() # Because 'question' is popped below
|
||||
|
||||
if "input" not in inputs:
|
||||
return {"input": inputs.pop("question"), **inputs}
|
||||
return inputs
|
||||
|
||||
runnable = (
|
||||
RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
|
||||
| agent_executor
|
||||
| RunnableLambda(_ensure_output_exists).with_config(
|
||||
{"run_name": "Ensure Output"}
|
||||
)
|
||||
)
|
||||
|
||||
if state_reader is not None:
|
||||
runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
|
||||
{"run_name": "Read Env State"}
|
||||
)
|
||||
return runnable
|
||||
@@ -0,0 +1,133 @@
|
||||
from typing import List, Literal, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_benchmarks import RateLimiter
|
||||
from langchain_benchmarks.rate_limiting import with_rate_limit
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
|
||||
AstPrinter,
|
||||
FunctionResult,
|
||||
TypeScriptEncoder,
|
||||
XMLEncoder,
|
||||
)
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
|
||||
_AGENT_INSTRUCTIONS_BLOB_STYLE,
|
||||
)
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
|
||||
convert_tool_to_function_definition,
|
||||
)
|
||||
|
||||
|
||||
def format_steps_for_chat(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
ast_printer: AstPrinter,
|
||||
) -> List[BaseMessage]:
|
||||
"""Format the steps."""
|
||||
messages = []
|
||||
for action, observation in intermediate_steps:
|
||||
# Action messages contains the tool invocation request from the LLM
|
||||
# Now add the result of the tool invocation.
|
||||
|
||||
if action.tool == "_Exception":
|
||||
messages.append(
|
||||
AIMessage(
|
||||
content=action.log,
|
||||
)
|
||||
)
|
||||
messages.append(
|
||||
# Tool input is the error message for the exception
|
||||
HumanMessage(content=action.tool_input)
|
||||
)
|
||||
else:
|
||||
messages.extend(action.messages)
|
||||
function_result: FunctionResult = {
|
||||
"name": action.tool,
|
||||
"error": None,
|
||||
"result": observation,
|
||||
}
|
||||
messages.append(
|
||||
HumanMessage(
|
||||
content=ast_printer.visit_function_result(function_result),
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
class AgentInput(TypedDict):
|
||||
"""The input to the agent."""
|
||||
|
||||
input: str
|
||||
"""The input to the agent."""
|
||||
intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
"""The intermediate steps taken by the agent."""
|
||||
examples: NotRequired[List[BaseMessage]]
|
||||
"""A list of messages that can be used to form example traces."""
|
||||
|
||||
|
||||
def create_agent(
|
||||
model: Union[BaseChatModel, BaseLanguageModel],
|
||||
tools: Sequence[StructuredTool],
|
||||
parser: AgentOutputParser,
|
||||
*,
|
||||
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
|
||||
"""Create an agent for a chat model."""
|
||||
if isinstance(ast_printer, str):
|
||||
if ast_printer == "xml":
|
||||
ast_printer_ = XMLEncoder()
|
||||
elif ast_printer == "typescript":
|
||||
ast_printer_ = TypeScriptEncoder()
|
||||
else:
|
||||
raise ValueError(f"Unknown ast printer: {ast_printer}")
|
||||
elif isinstance(ast_printer, AstPrinter):
|
||||
ast_printer_ = ast_printer
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
|
||||
)
|
||||
|
||||
function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
|
||||
tool_description = ast_printer_.visit_function_definitions(function_definitions)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", _AGENT_INSTRUCTIONS_BLOB_STYLE),
|
||||
MessagesPlaceholder("examples"), # Can use to add example traces
|
||||
("human", "{input}"),
|
||||
MessagesPlaceholder("history"),
|
||||
]
|
||||
).partial(tool_description=tool_description)
|
||||
|
||||
# For the time being, hard-coding the fact that we're using a <tool> tag.
|
||||
model = model.bind(stop=["</tool>"])
|
||||
|
||||
if rate_limiter:
|
||||
# Apply a rate limiter if it was provided
|
||||
model = with_rate_limit(model, rate_limiter)
|
||||
|
||||
agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"history": lambda x: format_steps_for_chat(
|
||||
x["intermediate_steps"], ast_printer_
|
||||
),
|
||||
"examples": lambda x: x.get("examples", []),
|
||||
}
|
||||
| template
|
||||
| model
|
||||
| parser
|
||||
)
|
||||
return agent
|
||||
@@ -0,0 +1,240 @@
|
||||
"""Prototyping code for rendering function definitions, invocations, and results.
|
||||
|
||||
Types are simplified for now to `str`.
|
||||
|
||||
We should actually support something like pydantic or jsonschema for the types, so
|
||||
we can expand them recursively for nested types.
|
||||
"""
|
||||
import abc
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
|
||||
class Parameter(TypedDict):
|
||||
"""Representation for a parameter."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
class Arguments(TypedDict):
|
||||
"""Arguments are passed to a function during function invocation."""
|
||||
|
||||
name: Optional[str]
|
||||
value: Any
|
||||
|
||||
|
||||
class ReturnValue(TypedDict):
|
||||
"""Representation for a return value of a function call."""
|
||||
|
||||
type: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class FunctionDefinition(TypedDict):
|
||||
"""Representation for a function."""
|
||||
|
||||
name: str
|
||||
description: str # Function description
|
||||
parameters: List[Parameter]
|
||||
return_value: ReturnValue
|
||||
|
||||
|
||||
class FunctionInvocation(TypedDict):
|
||||
"""Representation for a function invocation."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
arguments: List[Arguments]
|
||||
|
||||
|
||||
class FunctionResult(TypedDict):
|
||||
"""Representation for a function result."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
result: Optional[str]
|
||||
error: Optional[str]
|
||||
|
||||
|
||||
class Visitor(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
|
||||
|
||||
class AstPrinter(Visitor):
|
||||
"""Print the AST."""
|
||||
|
||||
|
||||
class XMLEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_lines = []
|
||||
|
||||
for parameter in function_definition["parameters"]:
|
||||
parameters_lines.extend(
|
||||
[
|
||||
"<parameter>",
|
||||
f"<name>{parameter['name']}</name>",
|
||||
f"<type>{parameter['type']}</type>",
|
||||
f"<description>{parameter['description']}</description>",
|
||||
"</parameter>",
|
||||
]
|
||||
)
|
||||
lines = [
|
||||
"<function>",
|
||||
f"<function_name>{function_definition['name']}</function_name>",
|
||||
"<description>",
|
||||
f"{function_definition['description']}",
|
||||
"</description>",
|
||||
"<parameters>",
|
||||
*parameters_lines,
|
||||
"</parameters>",
|
||||
"<return_value>",
|
||||
f"<type>{function_definition['return_value']['type']}</type>",
|
||||
]
|
||||
if function_definition["return_value"].get("description"):
|
||||
lines.append(
|
||||
f"<description>{function_definition['return_value']['description']}"
|
||||
f"</description>"
|
||||
)
|
||||
|
||||
lines.extend(["</return_value>", "</function>"])
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "<functions>\n" + "\n".join(strs) + "\n</functions>"
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
"<argument>\n"
|
||||
f"<name>{argument['name']}</name>\n"
|
||||
f"<value>{argument['value']}</value>\n"
|
||||
"</argument>\n"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = ["<function_invocation>"]
|
||||
|
||||
if invocation.get("id"):
|
||||
lines.append(f"<id>{invocation['id']}</id>")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"<function_name>{invocation['name']}</function_name>\n"
|
||||
"<arguments>\n"
|
||||
f"{''.join(arguments_as_strings)}" # Already includes trailing newline
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = [
|
||||
"<function_result>",
|
||||
]
|
||||
|
||||
if function_result.get("id"):
|
||||
lines.append(f"<id>{function_result['id']}</id>")
|
||||
|
||||
lines.append(f"<function_name>{function_result['name']}</function_name>")
|
||||
|
||||
if function_result["error"]:
|
||||
lines.extend(
|
||||
[
|
||||
f"<error>{function_result['error']}</error>",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
f"<result>{function_result['result']}</result>",
|
||||
)
|
||||
|
||||
lines.append("</function_result>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TypeScriptEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_as_strings = [
|
||||
f"{parameter['name']}: {parameter['type']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
]
|
||||
# Let's use JSdoc style comments
|
||||
# First the function description
|
||||
lines = [
|
||||
f"// {function_definition['description']}",
|
||||
# Then the parameter descriptions
|
||||
*[
|
||||
f"// @param {parameter['name']} {parameter['description']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
],
|
||||
# Then the return value description
|
||||
f"// @returns {function_definition['return_value']['description']}",
|
||||
# Then the function definition
|
||||
f"function {function_definition['name']}("
|
||||
f"{', '.join(parameters_as_strings)}): "
|
||||
f"{function_definition['return_value']['type']};",
|
||||
]
|
||||
|
||||
# finally join
|
||||
function = "\n".join(lines)
|
||||
return function
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "\n\n".join(strs)
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
f"{argument['name']}: {argument['value']}"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"]
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = []
|
||||
if function_result["error"]:
|
||||
lines.append(f"ERROR: {function_result['error']}")
|
||||
else:
|
||||
lines.append(f"> {function_result['result']}")
|
||||
if function_result.get("id"):
|
||||
lines.append(f"// ID: {function_result['id']}")
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Factory for creating agents for the tool usage task."""
|
||||
from typing import Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from langchain_benchmarks import RateLimiter, model_registry
|
||||
from langchain_benchmarks.schema import ToolUsageTask
|
||||
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.agent import create_agent
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.parser import (
|
||||
GenericAgentParser,
|
||||
)
|
||||
|
||||
|
||||
class CustomAgentFactory:
|
||||
"""A factory for creating tool using agents.
|
||||
|
||||
A factory for agents that do not leverage any special JSON mode for
|
||||
function usage; instead all function invocation behavior is implemented solely
|
||||
through prompt engineering and parsing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: ToolUsageTask,
|
||||
model: str,
|
||||
*,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
) -> None:
|
||||
"""Create an agent factory for the given tool usage task.
|
||||
|
||||
Args:
|
||||
task: The task to create an agent factory for
|
||||
model: model name (check model_registry)
|
||||
rate_limiter: The rate limiter to use if provided
|
||||
"""
|
||||
if model not in model_registry:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
self.task = task
|
||||
self.model = model
|
||||
self.rate_limiter = rate_limiter
|
||||
|
||||
def __call__(self) -> Runnable:
|
||||
if isinstance(self.model, str):
|
||||
registered_model = model_registry.get_model(self.model)
|
||||
if registered_model is None:
|
||||
raise ValueError(f"Unknown model: {self.model}")
|
||||
model = registered_model.get_model(model_params={"temperature": 0})
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
def _add_task_instructions(
|
||||
input: dict, config: Optional[RunnableConfig] = None, **kwargs
|
||||
) -> dict:
|
||||
"""Add task instructions to the question."""
|
||||
if not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
f"Expected input to be a dict with key `question`. "
|
||||
f"Found {type(input)}."
|
||||
)
|
||||
input = input.copy()
|
||||
input["question"] = (
|
||||
f"{self.task.instructions}\nWrite down your answer, "
|
||||
f"but do not explain it. Input: `{input['question']}`"
|
||||
)
|
||||
return input
|
||||
|
||||
env = self.task.create_environment()
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
env.tools,
|
||||
GenericAgentParser(wrapping_xml_tag="tool", require_closing_xml_tag=False),
|
||||
rate_limiter=self.rate_limiter,
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=env.tools,
|
||||
handle_parsing_errors=True,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
return _add_task_instructions | apply_agent_executor_adapter(
|
||||
executor, state_reader=env.read_state
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
import ast
|
||||
import re
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class _ToolInvocationRequest(BaseModel):
|
||||
"""Light-weight pydantic model for validating the raw tool invocation request.
|
||||
|
||||
The purpose of this model, is to make sure that whatever as parsed from
|
||||
the raw llm output has `tool_name` and potential `arguments` fields, and
|
||||
nothing else.
|
||||
"""
|
||||
|
||||
tool_name: str
|
||||
# OK parameterless tools which do not take arguments
|
||||
arguments: Optional[Dict] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GenericAgentParser(AgentOutputParser):
|
||||
"""A generalized parser that makes it easier to parameterize different parsing."""
|
||||
|
||||
wrapping_xml_tag: str
|
||||
"""The tag that wraps the function invocation request.
|
||||
|
||||
For example, if "tool", then the function invocation request should be wrapped
|
||||
in <tool>...</tool>.
|
||||
"""
|
||||
require_closing_xml_tag: bool = False
|
||||
"""Whether we should require a closing tag for the wrapping_xml_tag.
|
||||
|
||||
For example, if True, then the function invocation request should be wrapped
|
||||
"""
|
||||
|
||||
def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
|
||||
"""Parse the output of the agent."""
|
||||
open_tag = f"<{self.wrapping_xml_tag}>"
|
||||
close_tag = f"</{self.wrapping_xml_tag}>"
|
||||
if open_tag in text:
|
||||
# This is a hack to make sure that </tool> is always present
|
||||
# in the output if <tool>. </tool> may be a stop sequence for the
|
||||
# language model, so depending on implementation
|
||||
# the stop sequence may be cut off.
|
||||
# There might be a better way to do this, but this works and
|
||||
# is simple.
|
||||
if not self.require_closing_xml_tag:
|
||||
text += close_tag
|
||||
|
||||
pattern = rf"{open_tag}(?P<invocation>.*?){close_tag}"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
content = match.group("invocation").strip()
|
||||
return parse_invocation(content, self.wrapping_xml_tag)
|
||||
|
||||
return AgentFinish(
|
||||
log=text,
|
||||
return_values={
|
||||
"output": text,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def parse_invocation(text: str, tag: str) -> AgentAction:
|
||||
"""Parse the content of the function invocation.
|
||||
|
||||
Args:
|
||||
text: The text to parse.
|
||||
tag: The tag that wraps the function invocation request.
|
||||
|
||||
Returns:
|
||||
An AgentAction that corresponds to the function invocation.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the parsing fails.
|
||||
|
||||
This exception is meant to be caught by the agent executor and
|
||||
handled appropriately to provide feedback to the LLM.
|
||||
"""
|
||||
ai_content = f"<{tag}>{text}</{tag}>\n"
|
||||
|
||||
try:
|
||||
result = ast.literal_eval(text)
|
||||
except BaseException as e:
|
||||
# Convert this to something controllable by the user.
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>\n'
|
||||
)
|
||||
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
observation=err_msg,
|
||||
send_to_llm=True,
|
||||
)
|
||||
|
||||
try:
|
||||
request = _ToolInvocationRequest.validate(result)
|
||||
except Exception as e: # Using broad exception since it's not just ValidationError
|
||||
# Can also raise DictError if result is not a dict.
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>\n'
|
||||
)
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
send_to_llm=True,
|
||||
observation=err_msg,
|
||||
)
|
||||
|
||||
return AgentActionMessageLog(
|
||||
message_log=[AIMessage(content=ai_content)],
|
||||
tool=request.tool_name,
|
||||
tool_input=request.arguments,
|
||||
log=f"\nInvoking {request.tool_name}: {request.arguments}\n\t",
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
AGENT_INSTRUCTIONS_XML_FORMAT = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
You may call them like this:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
""" # noqa: E501
|
||||
|
||||
_AGENT_INSTRUCTIONS_BLOB_STYLE = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
|
||||
You may call one tool at a time using a format that includes <tool> and </tool> tag.
|
||||
|
||||
Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation.
|
||||
|
||||
It must match the schema of the function as described in the tool description.
|
||||
"arguments" is a dictionary of the arguments to the function.
|
||||
|
||||
<tool>
|
||||
{{
|
||||
"tool_name": $TOOL_NAME,
|
||||
"arguments": $ARGUMENTS
|
||||
}}
|
||||
</tool>
|
||||
|
||||
If you do not know the answer use more tools. You can only take a single action at a time.\
|
||||
""" # noqa: E501
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Utilities to extract information from langchain tools for use in prompts."""
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
|
||||
from langchain.tools.base import StructuredTool
|
||||
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
|
||||
FunctionDefinition,
|
||||
Parameter,
|
||||
)
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
schema = tool.args_schema.schema()
|
||||
|
||||
properties = schema["properties"]
|
||||
parameters = []
|
||||
# Is this needed or is string OK?
|
||||
type_adapter = {
|
||||
"string": "str", # str or string?
|
||||
"integer": "int",
|
||||
"number": "float",
|
||||
"boolean": "bool",
|
||||
}
|
||||
for key, value in properties.items():
|
||||
parameters.append(
|
||||
{
|
||||
"name": key,
|
||||
"type": type_adapter.get(value["type"], value["type"]),
|
||||
"description": value.get("description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
#
|
||||
def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
# Here we re-inspect the underlying function to get the doc-string
|
||||
# since StructuredTool modifies it, but we want the raw one for maximum
|
||||
# flexibility.
|
||||
description = inspect.getdoc(tool.func)
|
||||
|
||||
parameters = get_parameters_from_tool(tool)
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": dedent(description),
|
||||
"parameters": parameters,
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Code for creating an agent factory for evaluating tool usage tasks."""
|
||||
from typing import Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.format_scratchpad import format_to_openai_functions
|
||||
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools.render import format_tool_to_openai_function
|
||||
|
||||
from langchain_benchmarks import rate_limiting
|
||||
from langchain_benchmarks.schema import ToolUsageTask
|
||||
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
class OpenAIAgentFactory:
|
||||
def __init__(
|
||||
self,
|
||||
task: ToolUsageTask,
|
||||
*,
|
||||
model: str = "gpt-3.5-turbo-16k",
|
||||
rate_limiter: Optional[rate_limiting.RateLimiter] = None,
|
||||
) -> None:
|
||||
"""Create an OpenAI agent factory for the given task.
|
||||
|
||||
Args:
|
||||
task: The task to create an agent factory for.
|
||||
model: The model to use -- this must be an open AI model.
|
||||
rate_limiter: The rate limiter to use
|
||||
"""
|
||||
self.task = task
|
||||
self.model = model
|
||||
self.rate_limiter = rate_limiter
|
||||
|
||||
def create(self) -> Runnable:
|
||||
"""Agent Executor"""
|
||||
# For backwards compatibility
|
||||
return self()
|
||||
|
||||
def __call__(self) -> Runnable:
|
||||
model = ChatOpenAI(
|
||||
model=self.model,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
env = self.task.create_environment()
|
||||
|
||||
model = model.bind(
|
||||
functions=[format_tool_to_openai_function(t) for t in env.tools]
|
||||
)
|
||||
|
||||
if rate_limiting:
|
||||
# Rate limited model
|
||||
model = rate_limiting.with_rate_limit(model, self.rate_limiter)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
self.task.instructions,
|
||||
),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
)
|
||||
|
||||
runnable_agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"agent_scratchpad": lambda x: format_to_openai_functions(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
}
|
||||
| prompt
|
||||
| model
|
||||
| OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
|
||||
runnable = AgentExecutor(
|
||||
agent=runnable_agent,
|
||||
tools=env.tools,
|
||||
handle_parsing_errors=True,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
# Returns `state` in the output if the environment has a state reader
|
||||
# makes sure that `output` is always in the output
|
||||
return apply_agent_executor_adapter(runnable, state_reader=env.read_state)
|
||||
@@ -77,7 +77,7 @@ def compare_outputs(
|
||||
|
||||
# Evaluate state score
|
||||
# This will need to be evolved it's too simple.
|
||||
if "state" in run_outputs:
|
||||
if "state" in run_outputs and "state" in example_outputs:
|
||||
state = run_outputs["state"]
|
||||
example_state = example_outputs["state"]
|
||||
results.append(
|
||||
@@ -112,7 +112,7 @@ class AgentTrajectoryEvaluator(RunEvaluator):
|
||||
def __init__(
|
||||
self,
|
||||
eval_llm: Union[BaseLanguageModel, BaseChatModel, None] = None,
|
||||
output_evaluation: Literal["qa", "none"] = "qa",
|
||||
output_evaluation: Literal["qa", "none", "qa_math"] = "qa",
|
||||
) -> None:
|
||||
"""Initialize the evaluator."""
|
||||
if output_evaluation == "none":
|
||||
|
||||
@@ -152,6 +152,9 @@ The objective of this task is to evaluate the ability to use the provided tools
|
||||
solve simple math questions and ignore any innate knowledge about math.
|
||||
"""
|
||||
),
|
||||
eval_params={
|
||||
"output_evaluation": "qa_math",
|
||||
},
|
||||
)
|
||||
|
||||
# Source dataset used to create the public dataset in LangSmith
|
||||
|
||||
@@ -438,4 +438,5 @@ the question.
|
||||
Success is measured by the ability to answer the question correctly, and efficiently.
|
||||
"""
|
||||
),
|
||||
eval_params={}, # No special evaluation parameters
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ class Paper:
|
||||
content: str
|
||||
|
||||
|
||||
def create_typer(paper: Paper) -> Callable[[], str]:
|
||||
def create_typer(paper: Paper) -> Callable[[str], str]:
|
||||
"""Create a function that types the given letter."""
|
||||
|
||||
def type_letter(letter: str) -> str:
|
||||
@@ -82,6 +82,12 @@ The dataset includes examples of varying difficulty. The difficulty is measured
|
||||
by the length of the string.
|
||||
"""
|
||||
),
|
||||
eval_params={
|
||||
# For this task, the agent's output is irrelevant
|
||||
# what we care about is the final state of the environment
|
||||
# (i.e., what's written on the virtual paper)
|
||||
"output_evaluation": "none",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,12 @@ This is a variation of the typer writer task, where 26 parameterless tools are
|
||||
given instead of a single tool that takes a letter as an argument.
|
||||
"""
|
||||
),
|
||||
eval_params={
|
||||
# For this task, the agent's output is irrelevant
|
||||
# what we care about is the final state of the environment
|
||||
# (i.e., what's written on the virtual paper)
|
||||
"output_evaluation": "none",
|
||||
},
|
||||
)
|
||||
|
||||
STRINGS_TO_TYPE = [
|
||||
|
||||
Generated
+33
-19
@@ -816,6 +816,20 @@ files = [
|
||||
{file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "freezegun"
|
||||
version = "1.3.1"
|
||||
description = "Let your Python tests travel through time"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "freezegun-1.3.1-py3-none-any.whl", hash = "sha256:065e77a12624d05531afa87ade12a0b9bdb53495c4573893252a055b545ce3ea"},
|
||||
{file = "freezegun-1.3.1.tar.gz", hash = "sha256:48984397b3b58ef5dfc645d6a304b0060f612bcecfdaaf45ce8aff0077a6cb6a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
python-dateutil = ">=2.7"
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
version = "1.4.0"
|
||||
@@ -3145,28 +3159,28 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.1.6"
|
||||
version = "0.1.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:88b8cdf6abf98130991cbc9f6438f35f6e8d41a02622cc5ee130a02a0ed28703"},
|
||||
{file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c549ed437680b6105a1299d2cd30e4964211606eeb48a0ff7a93ef70b902248"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cf5f701062e294f2167e66d11b092bba7af6a057668ed618a9253e1e90cfd76"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05991ee20d4ac4bb78385360c684e4b417edd971030ab12a4fbd075ff535050e"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87455a0c1f739b3c069e2f4c43b66479a54dea0276dd5d4d67b091265f6fd1dc"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:683aa5bdda5a48cb8266fcde8eea2a6af4e5700a392c56ea5fb5f0d4bfdc0240"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:137852105586dcbf80c1717facb6781555c4e99f520c9c827bd414fac67ddfb6"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd98138a98d48a1c36c394fd6b84cd943ac92a08278aa8ac8c0fdefcf7138f35"},
|
||||
{file = "ruff-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0cd909d25f227ac5c36d4e7e681577275fb74ba3b11d288aff7ec47e3ae745"},
|
||||
{file = "ruff-0.1.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8fd1c62a47aa88a02707b5dd20c5ff20d035d634aa74826b42a1da77861b5ff"},
|
||||
{file = "ruff-0.1.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd89b45d374935829134a082617954120d7a1470a9f0ec0e7f3ead983edc48cc"},
|
||||
{file = "ruff-0.1.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:491262006e92f825b145cd1e52948073c56560243b55fb3b4ecb142f6f0e9543"},
|
||||
{file = "ruff-0.1.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea284789861b8b5ca9d5443591a92a397ac183d4351882ab52f6296b4fdd5462"},
|
||||
{file = "ruff-0.1.6-py3-none-win32.whl", hash = "sha256:1610e14750826dfc207ccbcdd7331b6bd285607d4181df9c1c6ae26646d6848a"},
|
||||
{file = "ruff-0.1.6-py3-none-win_amd64.whl", hash = "sha256:4558b3e178145491e9bc3b2ee3c4b42f19d19384eaa5c59d10acf6e8f8b57e33"},
|
||||
{file = "ruff-0.1.6-py3-none-win_arm64.whl", hash = "sha256:03910e81df0d8db0e30050725a5802441c2022ea3ae4fe0609b76081731accbc"},
|
||||
{file = "ruff-0.1.6.tar.gz", hash = "sha256:1b09f29b16c6ead5ea6b097ef2764b42372aebe363722f1605ecbcd2b9207184"},
|
||||
{file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7f80496854fdc65b6659c271d2c26e90d4d401e6a4a31908e7e334fab4645aac"},
|
||||
{file = "ruff-0.1.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:1ea109bdb23c2a4413f397ebd8ac32cb498bee234d4191ae1a310af760e5d287"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b0c2de9dd9daf5e07624c24add25c3a490dbf74b0e9bca4145c632457b3b42a"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:69a4bed13bc1d5dabf3902522b5a2aadfebe28226c6269694283c3b0cecb45fd"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de02ca331f2143195a712983a57137c5ec0f10acc4aa81f7c1f86519e52b92a1"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45b38c3f8788a65e6a2cab02e0f7adfa88872696839d9882c13b7e2f35d64c5f"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c64cb67b2025b1ac6d58e5ffca8f7b3f7fd921f35e78198411237e4f0db8e73"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9dcc6bb2f4df59cb5b4b40ff14be7d57012179d69c6565c1da0d1f013d29951b"},
|
||||
{file = "ruff-0.1.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df2bb4bb6bbe921f6b4f5b6fdd8d8468c940731cb9406f274ae8c5ed7a78c478"},
|
||||
{file = "ruff-0.1.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:276a89bcb149b3d8c1b11d91aa81898fe698900ed553a08129b38d9d6570e717"},
|
||||
{file = "ruff-0.1.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:90c958fe950735041f1c80d21b42184f1072cc3975d05e736e8d66fc377119ea"},
|
||||
{file = "ruff-0.1.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b05e3b123f93bb4146a761b7a7d57af8cb7384ccb2502d29d736eaade0db519"},
|
||||
{file = "ruff-0.1.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:290ecab680dce94affebefe0bbca2322a6277e83d4f29234627e0f8f6b4fa9ce"},
|
||||
{file = "ruff-0.1.7-py3-none-win32.whl", hash = "sha256:416dfd0bd45d1a2baa3b1b07b1b9758e7d993c256d3e51dc6e03a5e7901c7d80"},
|
||||
{file = "ruff-0.1.7-py3-none-win_amd64.whl", hash = "sha256:4af95fd1d3b001fc41325064336db36e3d27d2004cdb6d21fd617d45a172dd96"},
|
||||
{file = "ruff-0.1.7-py3-none-win_arm64.whl", hash = "sha256:0683b7bfbb95e6df3c7c04fe9d78f631f8e8ba4868dfc932d43d690698057e2e"},
|
||||
{file = "ruff-0.1.7.tar.gz", hash = "sha256:dffd699d07abf54833e5f6cc50b85a6ff043715da8788c4a79bcd4ab4734d306"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4069,4 +4083,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8.1"
|
||||
content-hash = "d360fa5fb031270329ee36b00f5c2ebb2b73d92df8e97041ce53a14338488b91"
|
||||
content-hash = "8ff4713dcb59c3e0d796659a7feaf21a5bacd9ce9995426f56cce4b5af9e5e1b"
|
||||
|
||||
+3
-5
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-benchmarks"
|
||||
version = "0.0.6"
|
||||
version = "0.0.9"
|
||||
description = "🦜💪 Flex those feathers!"
|
||||
authors = ["LangChain AI"]
|
||||
license = "MIT"
|
||||
@@ -41,6 +41,7 @@ pytest-mock = "^3.11.1"
|
||||
pytest-socket = "^0.6.0"
|
||||
pytest-watch = "^4.2.0"
|
||||
pytest-timeout = "^2.2.0"
|
||||
freezegun = "^1.3.1"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
@@ -55,10 +56,7 @@ extend-include = ["*.ipynb"]
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.isort]
|
||||
# TODO(Team): Temporary to make isort work with examples.
|
||||
# examples assume langserve is available as a 3rd party package
|
||||
# For simplicity we'll define it as first party for now can update later.
|
||||
known-first-party = ["langserve"]
|
||||
known-first-party = ["langchain-benchmarks"]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from langchain_core.agents import AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.parser import (
|
||||
GenericAgentParser,
|
||||
)
|
||||
|
||||
|
||||
def test_parser() -> None:
|
||||
"""Test parser."""
|
||||
parser = GenericAgentParser(require_closing_tag=False, wrapping_xml_tag="tool")
|
||||
|
||||
# If <tool> tag not found then it's an agent finish
|
||||
assert isinstance(parser.invoke("goodbye"), AgentFinish)
|
||||
|
||||
with pytest.raises(OutputParserException):
|
||||
# Invocation content is missing tool name and arguments
|
||||
parser.invoke("<tool>'hello'</tool>")
|
||||
|
||||
with pytest.raises(OutputParserException):
|
||||
parser.invoke("<tool>hello")
|
||||
|
||||
# Full invocation
|
||||
text = (
|
||||
'<tool>{\n "tool_name": "type_letter",\n '
|
||||
'"arguments": {\n '
|
||||
'"letter": "h"\n }\n}</tool>\n'
|
||||
)
|
||||
|
||||
assert parser.invoke(text) == AgentActionMessageLog(
|
||||
tool="type_letter",
|
||||
tool_input={"letter": "h"},
|
||||
log="\nInvoking type_letter: {'letter': 'h'}\n\t",
|
||||
message_log=[AIMessage(content=text)],
|
||||
)
|
||||
|
||||
# Test more cases
|
||||
parsed = parser.invoke('<tool>{"tool_name": "hello"}</tool>')
|
||||
assert parsed.tool == "hello"
|
||||
# Assumes that it's a structured tool by default!
|
||||
assert parsed.tool_input == {}
|
||||
|
||||
with pytest.raises(OutputParserException):
|
||||
# Arguments need to be a dict
|
||||
parser.invoke('<tool>{"tool_name": "hello", "arguments": [1, 2]}</tool>')
|
||||
|
||||
parsed = parser.invoke(
|
||||
'<tool>{"tool_name": "hello", "arguments": {"a": "b"}}</tool>'
|
||||
)
|
||||
assert parsed.tool == "hello"
|
||||
# Assumes that it's a structured tool by default!
|
||||
assert parsed.tool_input == {"a": "b"}
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Test typescript encoding."""
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
|
||||
FunctionDefinition,
|
||||
TypeScriptEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = TypeScriptEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"// A test function\n"
|
||||
"// @param test_parameter A test parameter\n"
|
||||
"// @returns A test return value\n"
|
||||
"function test_function(test_parameter: str): str;"
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Test XML encoding and decoding of function definitions, invocation, and results."""
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
|
||||
FunctionDefinition,
|
||||
FunctionInvocation,
|
||||
FunctionResult,
|
||||
XMLEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition_encoding() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"<function>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<description>\n"
|
||||
"A test function\n"
|
||||
"</description>\n"
|
||||
"<parameters>\n"
|
||||
"<parameter>\n"
|
||||
"<name>test_parameter</name>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test parameter</description>\n"
|
||||
"</parameter>\n"
|
||||
"</parameters>\n"
|
||||
"<return_value>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test return value</description>\n"
|
||||
"</return_value>\n"
|
||||
"</function>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_result_encoding() -> None:
|
||||
"""Test encoding a function result."""
|
||||
encoder = XMLEncoder()
|
||||
function_result = FunctionResult(
|
||||
name="test_function",
|
||||
result="test_result",
|
||||
error=None,
|
||||
)
|
||||
xml = encoder.visit_function_result(function_result)
|
||||
assert xml == (
|
||||
"<function_result>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<result>test_result</result>\n"
|
||||
"</function_result>"
|
||||
)
|
||||
|
||||
function_result = FunctionResult(
|
||||
name="test_function",
|
||||
error="error",
|
||||
)
|
||||
xml = encoder.visit_function_result(function_result)
|
||||
assert xml == (
|
||||
"<function_result>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<error>error</error>\n"
|
||||
"</function_result>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_invocation() -> None:
|
||||
"""Test function invocation."""
|
||||
function_invocation = FunctionInvocation(
|
||||
name="test_function",
|
||||
arguments=[{"name": "test_argument", "value": "test_value"}],
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_invocation(function_invocation)
|
||||
assert xml == (
|
||||
"<function_invocation>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<arguments>\n"
|
||||
"<argument>\n"
|
||||
"<name>test_argument</name>\n"
|
||||
"<value>test_value</value>\n"
|
||||
"</argument>\n"
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
from langchain.tools import tool
|
||||
|
||||
from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
|
||||
convert_tool_to_function_definition,
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def get_hello() -> str:
|
||||
"""Get hello."""
|
||||
return "hello"
|
||||
|
||||
|
||||
@tool
|
||||
def repeat(x: str) -> str:
|
||||
"""Repeat x.
|
||||
|
||||
Args:
|
||||
x: The string to repeat.
|
||||
|
||||
Returns:
|
||||
The repeated string.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def test_parameterless_function() -> None:
|
||||
"""Test foo."""
|
||||
function_definition = convert_tool_to_function_definition(get_hello)
|
||||
assert function_definition == {
|
||||
"name": "get_hello",
|
||||
"description": "Get hello.",
|
||||
"parameters": [],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skip("Need to fix handling of leading whitespace")
|
||||
def test_function_with_parameters() -> None:
|
||||
import textwrap
|
||||
|
||||
doc = textwrap.dedent(repeat.func.__doc__)
|
||||
assert convert_tool_to_function_definition(repeat) == {
|
||||
"name": "repeat",
|
||||
"description": doc,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "x",
|
||||
"type": "str",
|
||||
"description": "", # Need to fix this
|
||||
}
|
||||
],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
|
||||
from langchain_benchmarks.schema import ModelRegistry, RegisteredModel
|
||||
|
||||
# Create some sample RegisteredModel instances for testing
|
||||
SAMPLE_MODELS = [
|
||||
RegisteredModel(
|
||||
"model1", "fireworks", "Description 1", {"param1": "value1"}, "chat"
|
||||
),
|
||||
RegisteredModel("model2", "openai", "Description 2", {"param2": "value2"}, "llm"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_registry() -> ModelRegistry:
|
||||
return ModelRegistry(SAMPLE_MODELS)
|
||||
|
||||
|
||||
def test_init() -> None:
|
||||
# Test the constructor of ModelRegistry
|
||||
registry = ModelRegistry(SAMPLE_MODELS)
|
||||
assert len(registry.registered_models) == 2
|
||||
|
||||
|
||||
def test_get_model(sample_registry: ModelRegistry) -> None:
|
||||
# Test the get_model method
|
||||
model = sample_registry.get_model("model1")
|
||||
assert model.name == "model1"
|
||||
|
||||
|
||||
def test_filter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the filter method
|
||||
filtered_registry = sample_registry.filter(type="chat")
|
||||
assert len(filtered_registry.registered_models) == 1
|
||||
assert filtered_registry.registered_models[0].type == "chat"
|
||||
|
||||
|
||||
def test_repr_html(sample_registry: ModelRegistry) -> None:
|
||||
# Test the _repr_html_ method
|
||||
html_representation = sample_registry._repr_html_()
|
||||
assert "<table>" in html_representation
|
||||
|
||||
|
||||
def test_len(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __len__ method
|
||||
assert len(sample_registry) == 2
|
||||
|
||||
|
||||
def test_iter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __iter__ method
|
||||
models = list(iter(sample_registry))
|
||||
assert len(models) == 2
|
||||
assert isinstance(models[0], RegisteredModel)
|
||||
|
||||
|
||||
def test_getitem(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for integer and string keys
|
||||
model = sample_registry[0]
|
||||
assert model.name == "model1"
|
||||
model = sample_registry["model2"]
|
||||
assert model.name == "model2"
|
||||
|
||||
|
||||
def test_getitem_slice(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for slices
|
||||
sliced_registry = sample_registry[:1]
|
||||
assert len(sliced_registry.registered_models) == 1
|
||||
assert sliced_registry.registered_models[0].name == "model1"
|
||||
@@ -6,5 +6,13 @@ def test_public_api() -> None:
|
||||
# This test will also fail if __all__ is not sorted.
|
||||
# Please keep it sorted!
|
||||
assert __all__ == sorted(
|
||||
["clone_public_dataset", "download_public_dataset", "registry"]
|
||||
[
|
||||
"__version__",
|
||||
"clone_public_dataset",
|
||||
"download_public_dataset",
|
||||
"model_registry",
|
||||
"RateLimiter",
|
||||
"registry",
|
||||
],
|
||||
key=lambda x: x.lower(),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain_benchmarks.rate_limiting import RateLimiter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"delta_time, requests_per_second, max_bucket_size, expected_result",
|
||||
[
|
||||
(
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
True,
|
||||
),
|
||||
(
|
||||
0.5,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
),
|
||||
(
|
||||
0.5,
|
||||
2,
|
||||
1,
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_consume(
|
||||
delta_time: float,
|
||||
requests_per_second: float,
|
||||
max_bucket_size: float,
|
||||
expected_result: bool,
|
||||
) -> None:
|
||||
"""Test the consumption of tokens over time.
|
||||
|
||||
Args:
|
||||
delta_time: The time in seconds to add to the initial time.
|
||||
requests_per_second: The rate at which tokens are added per second.
|
||||
max_bucket_size: The maximum size of the token bucket.
|
||||
expected_result: The expected result of the consume operation.
|
||||
"""
|
||||
rate_limiter = RateLimiter(
|
||||
requests_per_second=requests_per_second, max_bucket_size=max_bucket_size
|
||||
)
|
||||
|
||||
with freeze_time(auto_tick_seconds=delta_time):
|
||||
assert rate_limiter._consume() is False
|
||||
assert rate_limiter._consume() is expected_result
|
||||
|
||||
|
||||
def test_consume_count_tokens() -> None:
|
||||
"""Test to check that the bucket size is used correctly."""
|
||||
rate_limiter = RateLimiter(
|
||||
requests_per_second=60,
|
||||
max_bucket_size=10,
|
||||
)
|
||||
|
||||
with freeze_time(auto_tick_seconds=100):
|
||||
assert rate_limiter._consume() is False
|
||||
assert rate_limiter._consume() is True
|
||||
assert (
|
||||
rate_limiter.available_tokens == 9
|
||||
) # Max bucket size is 10, so 10 - 1 = 9
|
||||
Reference in New Issue
Block a user