mirror of
https://github.com/langchain-ai/langchain-benchmarks.git
synced 2026-07-01 22:34:02 -04:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7545285c3 | |||
| 295fb12f8c | |||
| dd1f33ce02 | |||
| 0ddb74036b | |||
| cf10528b99 | |||
| dbc254ebc0 | |||
| 4c3de31e6f | |||
| f8e5d32bc9 | |||
| 40250950db | |||
| 9adc5b96f5 | |||
| 85c8d29342 | |||
| a42c70f315 | |||
| 9c6ec01219 | |||
| d2defc95e3 | |||
| 8b0e52fdc4 | |||
| 22b90df0ad | |||
| 8ba4340730 | |||
| 2f0b4e9bed | |||
| 910bd60832 | |||
| 5b4672a33b | |||
| 9f827eaca5 | |||
| d9fc08b05c | |||
| 8a5ba6d575 | |||
| 8204930f2b | |||
| 013fe6a153 |
@@ -1,6 +1,4 @@
|
||||
🚧 Under Active Development 🚧
|
||||
|
||||
# 🦜💪 LangChain Benchmarks
|
||||
# 🦜💯 LangChain Benchmarks
|
||||
|
||||
[](https://github.com/langchain-ai/langchain-benchmarks/releases)
|
||||
[](https://github.com/langchain-ai/langchain-benchmarks/actions/workflows/ci.yml)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
|
||||
from langchain.tools.base import StructuredTool
|
||||
|
||||
from agents.encoder import FunctionDefinition, Parameter
|
||||
|
||||
|
||||
# This is temporary until we have a better way to represent parameters
|
||||
def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
schema = tool.args_schema.schema()
|
||||
|
||||
properties = schema["properties"]
|
||||
parameters = []
|
||||
# Is this needed or is string OK?
|
||||
type_adapter = {
|
||||
"string": "str", # str or string?
|
||||
"integer": "int",
|
||||
"number": "float",
|
||||
"boolean": "bool",
|
||||
}
|
||||
for key, value in properties.items():
|
||||
parameters.append(
|
||||
{
|
||||
"name": key,
|
||||
"type": type_adapter.get(value["type"], value["type"]),
|
||||
"description": value.get("description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
#
|
||||
def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
# Here we re-inspect the underlying function to get the doc-string
|
||||
# since StructuredTool modifies it, but we want the raw one for maximum
|
||||
# flexibility.
|
||||
description = inspect.getdoc(tool.func)
|
||||
|
||||
parameters = get_parameters_from_tool(tool)
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": dedent(description),
|
||||
"parameters": parameters,
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
from typing import List, Literal, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from agents.adapters import convert_tool_to_function_definition
|
||||
from agents.encoder import AstPrinter, TypeScriptEncoder
|
||||
from agents.prompts import AGENT_INSTRUCTIONS_BLOB_STYLE
|
||||
|
||||
|
||||
def format_observation(tool_name: str, observation: str) -> BaseMessage:
|
||||
"""Format the observation."""
|
||||
result = (
|
||||
"<tool_output>\n"
|
||||
f"<tool_name>{tool_name}</tool_name>\n"
|
||||
f"<output>{observation}</output>\n"
|
||||
"</tool_output>"
|
||||
)
|
||||
|
||||
return HumanMessage(content=result)
|
||||
|
||||
|
||||
def format_steps_for_chat(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> List[BaseMessage]:
|
||||
"""Format the steps."""
|
||||
messages = []
|
||||
for action, observation in intermediate_steps:
|
||||
if not isinstance(action, AgentAction):
|
||||
if action.tool != "_Exception":
|
||||
raise AssertionError(f"Unexpected step: {action}. type: {type(action)}")
|
||||
|
||||
messages.append(HumanMessage(content=observation))
|
||||
messages.extend(action.messages)
|
||||
messages.append(format_observation(action.tool, observation))
|
||||
return messages
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
class AgentInput(TypedDict):
|
||||
"""The input to the agent."""
|
||||
|
||||
input: str
|
||||
"""The input to the agent."""
|
||||
intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
"""The intermediate steps taken by the agent."""
|
||||
examples: NotRequired[List[BaseMessage]]
|
||||
"""A list of messages that can be used to form example traces."""
|
||||
|
||||
|
||||
def create_agent(
|
||||
model: Union[BaseChatModel, BaseLanguageModel],
|
||||
tools: Sequence[StructuredTool],
|
||||
parser: AgentOutputParser,
|
||||
*,
|
||||
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
|
||||
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
|
||||
"""Create an agent for a chat model."""
|
||||
if isinstance(ast_printer, str):
|
||||
if ast_printer == "xml":
|
||||
ast_printer = AstPrinter()
|
||||
elif ast_printer == "typescript":
|
||||
ast_printer = TypeScriptEncoder()
|
||||
else:
|
||||
raise ValueError(f"Unknown ast printer: {ast_printer}")
|
||||
elif isinstance(ast_printer, AstPrinter):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
|
||||
)
|
||||
|
||||
function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
|
||||
tool_description = ast_printer.visit_function_definitions(function_definitions)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", AGENT_INSTRUCTIONS_BLOB_STYLE),
|
||||
MessagesPlaceholder("examples"), # Can use to add example traces
|
||||
("human", "{input}"),
|
||||
MessagesPlaceholder("history"),
|
||||
]
|
||||
).partial(tool_description=tool_description)
|
||||
|
||||
agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"history": lambda x: format_steps_for_chat(x["intermediate_steps"]),
|
||||
"examples": lambda x: x.get("examples", []),
|
||||
}
|
||||
| template
|
||||
| model.bind(stop=["</tool>"])
|
||||
| parser
|
||||
)
|
||||
return agent
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Prototyping code for rendering function definitions, invocations, and results.
|
||||
|
||||
Types are simplified for now to `str`.
|
||||
|
||||
We should actually support something like pydantic or jsonschema for the types, so
|
||||
we can expand them recursively for nested types.
|
||||
"""
|
||||
import abc
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
|
||||
class Parameter(TypedDict):
|
||||
"""Representation for a parameter."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
class Arguments(TypedDict):
|
||||
"""Arguments are passed to a function during function invocation."""
|
||||
|
||||
name: Optional[str]
|
||||
value: Any
|
||||
|
||||
|
||||
class ReturnValue(TypedDict):
|
||||
"""Representation for a return value of a function call."""
|
||||
|
||||
type: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class FunctionDefinition(TypedDict):
|
||||
"""Representation for a function."""
|
||||
|
||||
name: str
|
||||
description: str # Function description
|
||||
parameters: List[Parameter]
|
||||
return_value: ReturnValue
|
||||
|
||||
|
||||
class FunctionInvocation(TypedDict):
|
||||
"""Representation for a function invocation."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
arguments: List[Arguments]
|
||||
|
||||
|
||||
class FunctionResult(TypedDict):
|
||||
"""Representation for a function result."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
result: Optional[str]
|
||||
error: Optional[str]
|
||||
|
||||
|
||||
class Visitor(abc.ABC):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AstPrinter(Visitor):
|
||||
"""Print the AST."""
|
||||
|
||||
|
||||
class XMLEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_as_strings = [
|
||||
"<parameter>\n"
|
||||
f"<name>{parameter['name']}</name>\n"
|
||||
f"<type>{parameter['type']}</type>\n"
|
||||
f"<description>{parameter['description']}</description>\n"
|
||||
"</parameter>\n"
|
||||
for parameter in function_definition["parameters"]
|
||||
]
|
||||
function = (
|
||||
"<function>\n"
|
||||
f"<function_name>{function_definition['name']}</function_name>\n"
|
||||
"<description>\n"
|
||||
f"{function_definition['description']}\n"
|
||||
"</description>\n"
|
||||
"<parameters>\n"
|
||||
f"{''.join(parameters_as_strings)}" # Already includes trailing newline
|
||||
"</parameters>\n"
|
||||
"<return_value>\n"
|
||||
f"<type>{function_definition['return_value']['type']}</type>\n"
|
||||
f"<description>{function_definition['return_value']['description']}</description>\n"
|
||||
"</return_value>\n"
|
||||
"</function>"
|
||||
)
|
||||
return function
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "<functions>\n" + "\n".join(strs) + "\n</functions>"
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
"<argument>\n"
|
||||
f"<name>{argument['name']}</name>\n"
|
||||
f"<value>{argument['value']}</value>\n"
|
||||
"</argument>\n"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = ["<function_invocation>"]
|
||||
|
||||
if invocation.get("id"):
|
||||
lines.append(f"<id>{invocation['id']}</id>")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"<function_name>{invocation['name']}</function_name>\n"
|
||||
"<arguments>\n"
|
||||
f"{''.join(arguments_as_strings)}" # Already includes trailing newline
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = [
|
||||
"<function_result>",
|
||||
]
|
||||
|
||||
if function_result.get("id"):
|
||||
lines.append(f"<id>{function_result['id']}</id>")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"<function_name>{function_result['name']}</function_name>",
|
||||
f"<result>{function_result['result']}</result>",
|
||||
f"<error>{function_result['error']}</error>",
|
||||
"</function_result>",
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TypeScriptEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_as_strings = [
|
||||
f"{parameter['name']}: {parameter['type']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
]
|
||||
# Let's use JSdoc style comments
|
||||
# First the function description
|
||||
lines = [
|
||||
f"// {function_definition['description']}",
|
||||
# Then the parameter descriptions
|
||||
*[
|
||||
f"// @param {parameter['name']} {parameter['description']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
],
|
||||
# Then the return value description
|
||||
f"// @returns {function_definition['return_value']['description']}",
|
||||
# Then the function definition
|
||||
f"function {function_definition['name']}("
|
||||
f"{', '.join(parameters_as_strings)}): "
|
||||
f"{function_definition['return_value']['type']};",
|
||||
]
|
||||
|
||||
# finally join
|
||||
function = "\n".join(lines)
|
||||
return function
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "\n\n".join(strs)
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
f"{argument['name']}: {argument['value']}"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"]
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = []
|
||||
if function_result["error"]:
|
||||
lines.append(f"ERROR: {function_result['error']}")
|
||||
else:
|
||||
lines.append(f"> {function_result['result']}")
|
||||
if function_result.get("id"):
|
||||
lines.append(f"// ID: {function_result['id']}")
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,25 @@
|
||||
# EXAMPLE_TRACE = [
|
||||
# HumanMessage(content="type the letter 'o'"),
|
||||
# AIMessage(
|
||||
# content="""
|
||||
# <tool>
|
||||
# {
|
||||
# "tool_name": "type_letter",
|
||||
# "arguments": {
|
||||
# "letter": "o"
|
||||
# }
|
||||
# }
|
||||
# </tool>\
|
||||
# """
|
||||
# ),
|
||||
# HumanMessage(
|
||||
# content="""\
|
||||
# <tool_outputs>
|
||||
# <tool_name>type_letter</tool_name>
|
||||
# <output>o</output>
|
||||
# </tool_outputs>\
|
||||
# """
|
||||
# ),
|
||||
# ]
|
||||
#
|
||||
#
|
||||
@@ -0,0 +1,75 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chat_models import ChatAnthropic, ChatFireworks
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from agents.agent import create_agent
|
||||
from agents.parser import ParameterizedAgentParser
|
||||
from langchain_benchmarks.model_registration import FIREWORK_NAME_TO_MODEL
|
||||
from langchain_benchmarks.schema import ToolUsageTask
|
||||
from langchain_benchmarks.tool_usage import apply_agent_executor_adapter
|
||||
|
||||
|
||||
class CustomAgentFactory:
|
||||
def __init__(self, task: ToolUsageTask, model: str) -> None:
|
||||
"""Create an OpenAI agent factory for the given task.
|
||||
|
||||
Args:
|
||||
task: The task to create an agent factory for.
|
||||
"""
|
||||
if model not in self.list_models():
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def list_models() -> List[str]:
|
||||
"""List all models."""
|
||||
return sorted(
|
||||
[
|
||||
"claude-2.1",
|
||||
"claude-2",
|
||||
*FIREWORK_NAME_TO_MODEL.keys(),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self) -> Runnable:
|
||||
env = self.task.create_environment()
|
||||
if self.model in {"claude-2.1", "claude-2"}:
|
||||
model = ChatAnthropic(model=self.model, temperature=0)
|
||||
elif self.model in FIREWORK_NAME_TO_MODEL:
|
||||
model = ChatFireworks(
|
||||
model=FIREWORK_NAME_TO_MODEL[self.model], temperature=0
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {self.model}")
|
||||
|
||||
def _add_task_instructions(
|
||||
input: dict, config: Optional[RunnableConfig] = None, **kwargs
|
||||
) -> dict:
|
||||
"""Add task instructions to the question."""
|
||||
input = input.copy()
|
||||
input["question"] = (
|
||||
f"{self.task.instructions}\nWrite down your answer, "
|
||||
f"but do not explain it. Input: `{input['question']}`"
|
||||
)
|
||||
return input
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
env.tools,
|
||||
ParameterizedAgentParser(
|
||||
wrapping_xml_tag="tool", require_closing_xml_tag=False
|
||||
),
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=env.tools,
|
||||
handle_parsing_errors=True,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
return _add_task_instructions | apply_agent_executor_adapter(
|
||||
executor, state_reader=env.read_state
|
||||
)
|
||||
@@ -0,0 +1,120 @@
|
||||
import ast
|
||||
import re
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class _ToolInvocationRequest(BaseModel):
|
||||
"""Light-weight pydantic model for validating the raw tool invocation request.
|
||||
|
||||
The purpose of this model, is to make sure that whatever as parsed from
|
||||
the raw llm output has `tool_name` and potential `arguments` fields, and
|
||||
nothing else.
|
||||
"""
|
||||
|
||||
tool_name: str
|
||||
# OK parameterless tools which do not take arguments
|
||||
named_arguments: Any = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ParameterizedAgentParser(AgentOutputParser):
|
||||
"""A generalized parser that makes it easier to parameterize different parsing."""
|
||||
|
||||
wrapping_xml_tag: str
|
||||
"""The tag that wraps the function invocation request.
|
||||
|
||||
For example, if "tool", then the function invocation request should be wrapped
|
||||
in <tool>...</tool>.
|
||||
"""
|
||||
require_closing_xml_tag: bool = False
|
||||
"""Whether we should require a closing tag for the wrapping_xml_tag.
|
||||
|
||||
For example, if True, then the function invocation request should be wrapped
|
||||
"""
|
||||
|
||||
def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
|
||||
"""Parse the output of the agent."""
|
||||
open_tag = f"<{self.wrapping_xml_tag}>"
|
||||
close_tag = f"</{self.wrapping_xml_tag}>"
|
||||
if open_tag in text:
|
||||
# This is a hack to make sure that </tool> is always present
|
||||
# in the output if <tool>. </tool> may be a stop sequence for the
|
||||
# language model, so depending on implementation
|
||||
# the stop sequence may be cut off.
|
||||
# There might be a better way to do this, but this works and
|
||||
# is simple.
|
||||
if not self.require_closing_xml_tag:
|
||||
text += close_tag
|
||||
|
||||
pattern = rf"{open_tag}(?P<invocation>.*?){close_tag}"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
content = match.group("invocation").strip()
|
||||
return parse_invocation(content, self.wrapping_xml_tag)
|
||||
|
||||
return AgentFinish(
|
||||
log=text,
|
||||
return_values={
|
||||
"output": text,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def parse_invocation(text: str, tag: str) -> AgentAction:
|
||||
"""Parse the content of the function invocation.
|
||||
|
||||
Args:
|
||||
text: The text to parse.
|
||||
tag: The tag that wraps the function invocation request.
|
||||
|
||||
Returns:
|
||||
An AgentAction that corresponds to the function invocation.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the parsing fails.
|
||||
|
||||
This exception is meant to be caught by the agent executor and
|
||||
handled appropriately to provide feedback to the LLM.
|
||||
"""
|
||||
ai_content = f"<{tag}>{text}</{tag}>"
|
||||
|
||||
try:
|
||||
result = ast.literal_eval(text)
|
||||
except Exception as e:
|
||||
# Convert this to something controllable by the user.
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>'
|
||||
)
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
observation=err_msg,
|
||||
send_to_llm=True,
|
||||
)
|
||||
|
||||
try:
|
||||
request = _ToolInvocationRequest(**result)
|
||||
except ValidationError as e:
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>'
|
||||
)
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
send_to_llm=True,
|
||||
observation=err_msg,
|
||||
)
|
||||
|
||||
return AgentActionMessageLog(
|
||||
message_log=[AIMessage(content=ai_content)],
|
||||
tool=request.tool_name,
|
||||
tool_input=request.named_arguments,
|
||||
log=f"\nInvoking {request.tool_name}: {request.named_arguments}\n\t",
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
AGENT_INSTRUCTIONS_XML_FORMAT = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
You may call them like this:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
""" # noqa: E501
|
||||
|
||||
AGENT_INSTRUCTIONS_BLOB_STYLE = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
|
||||
You may call one tool at a time using a format that includes <tool> and </tool> tag.
|
||||
|
||||
Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation.
|
||||
|
||||
It must match the schema of the function as described in the tool description.
|
||||
"arguments" is a dictionary of the arguments to the function.
|
||||
|
||||
<tool>
|
||||
{{
|
||||
"tool_name": $TOOL_NAME,
|
||||
"arguments": $ARGUMENTS
|
||||
}}
|
||||
</tool>
|
||||
|
||||
If you do not know the answer use more tools. You can only take a single action at a time.\
|
||||
""" # noqa: E501
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Test XML encoding and decoding of function definitions, invocation, and results."""
|
||||
from agents.encoder import (
|
||||
FunctionDefinition,
|
||||
TypeScriptEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = TypeScriptEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"// A test function\n"
|
||||
"// @param test_parameter A test parameter\n"
|
||||
"// @returns A test return value\n"
|
||||
"function test_function(test_parameter: str): str;"
|
||||
)
|
||||
|
||||
|
||||
# Not important to test other ones right now since we can't parse / interpret
|
||||
# typescript anyway.
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Test XML encoding and decoding of function definitions, invocation, and results."""
|
||||
from agents.encoder import (
|
||||
FunctionDefinition,
|
||||
FunctionInvocation,
|
||||
FunctionResult,
|
||||
XMLEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition_encoding() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"<function>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<description>\n"
|
||||
"A test function\n"
|
||||
"</description>\n"
|
||||
"<parameters>\n"
|
||||
"<parameter>\n"
|
||||
"<name>test_parameter</name>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test parameter</description>\n"
|
||||
"</parameter>\n"
|
||||
"</parameters>\n"
|
||||
"<return_value>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test return value</description>\n"
|
||||
"</return_value>\n"
|
||||
"</function>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_result_encoding() -> None:
|
||||
"""Test encoding a function result."""
|
||||
function_result = FunctionResult(
|
||||
name="test_function",
|
||||
result="test_result",
|
||||
error="test_error",
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_result(function_result)
|
||||
assert xml == (
|
||||
"<function_result>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<result>test_result</result>\n"
|
||||
"<error>test_error</error>\n"
|
||||
"</function_result>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_invocation() -> None:
|
||||
"""Test function invocation."""
|
||||
function_invocation = FunctionInvocation(
|
||||
name="test_function",
|
||||
arguments=[{"name": "test_argument", "value": "test_value"}],
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_invocation(function_invocation)
|
||||
assert xml == (
|
||||
"<function_invocation>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<arguments>\n"
|
||||
"<argument>\n"
|
||||
"<name>test_argument</name>\n"
|
||||
"<value>test_value</value>\n"
|
||||
"</argument>\n"
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
from langchain.tools import tool
|
||||
|
||||
from agents.adapters import convert_tool_to_function_definition
|
||||
from agents.encoder import XMLEncoder
|
||||
|
||||
|
||||
@tool
|
||||
def get_hello() -> str:
|
||||
"""Get hello."""
|
||||
return "hello"
|
||||
|
||||
|
||||
@tool
|
||||
def repeat(x: str) -> str:
|
||||
"""Repeat x.
|
||||
|
||||
Args:
|
||||
x: The string to repeat.
|
||||
|
||||
Returns:
|
||||
The repeated string.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def test_parameterless_function() -> None:
|
||||
"""Test foo."""
|
||||
function_definition = convert_tool_to_function_definition(get_hello)
|
||||
assert function_definition == {
|
||||
"name": "get_hello",
|
||||
"description": "Get hello.",
|
||||
"parameters": [],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skip("Need to fix handling of leading whitespace")
|
||||
def test_function_with_parameters() -> None:
|
||||
import textwrap
|
||||
|
||||
doc = textwrap.dedent(repeat.func.__doc__)
|
||||
assert convert_tool_to_function_definition(repeat) == {
|
||||
"name": "repeat",
|
||||
"description": doc,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "x",
|
||||
"type": "str",
|
||||
"description": "", # Need to fix this
|
||||
}
|
||||
],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
from langchain_adapters.alternative import AgentOutputParser
|
||||
from langchain_core.agents import AgentFinish
|
||||
|
||||
|
||||
def test_parser() -> None:
|
||||
"""Test parser."""
|
||||
parser = AgentOutputParser(require_closing_tag=False, tag="tool")
|
||||
assert isinstance(parser.invoke("goodbye"), AgentFinish)
|
||||
assert parser.invoke("<tool>hello</tool>") == "hello"
|
||||
assert parser.invoke("<tool>hello") == "hello"
|
||||
# assert isinstance(parser.invoke("<tag>hello</tag>"), AgentAction)
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Throttle using a token bucket."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class Throttle:
|
||||
def __init__(self, rate: int) -> None:
|
||||
"""Initialize the throttle."""
|
||||
self.rate = rate
|
||||
self.tokens = 0
|
||||
self._consume_lock = threading.Lock()
|
||||
self.last = None
|
||||
|
||||
def consume(self, amount: int = 0) -> int:
|
||||
"""Consume the given amount of tokens."""
|
||||
with self._consume_lock:
|
||||
now = time.time()
|
||||
|
||||
# initialize on first call to avoid a burst
|
||||
if self.last is None:
|
||||
self.last = now
|
||||
|
||||
elapsed = now - self.last
|
||||
|
||||
if elapsed * self.rate > 1:
|
||||
self.tokens += elapsed * self.rate
|
||||
self.last = now
|
||||
|
||||
self.tokens = min(self.tokens, self.rate)
|
||||
|
||||
if self.tokens >= amount:
|
||||
self.tokens -= amount
|
||||
return amount
|
||||
|
||||
return 0
|
||||
@@ -1 +1,3 @@
|
||||
chromadb/
|
||||
index.md
|
||||
Untitled.ipynb
|
||||
|
||||
@@ -1,225 +1,226 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "033684fb-65b2-4586-a959-68c614741ca2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Datasets\n",
|
||||
"[](https://colab.research.google.com/github/langchain-ai/langchain-benchmarks/blob/main/docs/source/notebooks/datasets.ipynb)\n",
|
||||
"\n",
|
||||
"Here, we'll see how to work with LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -U langchain-benchmarks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6d272fbf-710e-4a49-a0da-67e010541905",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, download_public_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18ee0f96-e5c4-4ae9-aebf-7d8b88c51662",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's first download the dataset to the local file system"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "58b94f6d-0c91-4361-9b22-f758ffaa150a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fetching examples...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5a2fad8c0c3549ec96a3b38fe8a002b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/21 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Done fetching examples.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"download_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "841db832-b0d3-4fd1-8531-1154ec9b3caa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"we can take a look at the first two examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "664e90fc-af84-4c5f-a3dd-5d9ffe649650",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"0703a989-2693-4039-a1f6-7281fc1b4cb0\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"do bob and alice live in the same city?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"no\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"b258b95a-9524-4da7-b758-c5481109322d\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"Is it likely that Donna is outside with an umbrella at this time?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_current_time_for_location\",\n",
|
||||
" \"get_current_weather_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"yes\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"with open(\"./e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
||||
" print(json.dumps(json.load(f)[:2], indent=2, sort_keys=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c6cf01f-466b-406d-b4c7-2395747780fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also clone the dataset to our local tenant"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4dea4df-2f1c-436b-a71c-49ffb2295ccc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Executing this command will clone the dataset to your own LangSmith tenant. \n",
|
||||
"For this to work you must have a [LangSmith account](https://smith.langchain.com/) set up."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Get from https://smith.langchain.com/settings\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"ls_...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18d0b905-2a6a-4752-a7cb-8653bd9049e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clone_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\",\n",
|
||||
" dataset_name=\"Agent Dataset\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "033684fb-65b2-4586-a959-68c614741ca2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Datasets\n",
|
||||
"\n",
|
||||
"Here, we'll see how to work with LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d9aa20db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -U langchain-benchmarks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6d272fbf-710e-4a49-a0da-67e010541905",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, download_public_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18ee0f96-e5c4-4ae9-aebf-7d8b88c51662",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's first download the dataset to the local file system"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "58b94f6d-0c91-4361-9b22-f758ffaa150a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fetching examples...\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5a2fad8c0c3549ec96a3b38fe8a002b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/21 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Done fetching examples.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"download_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "841db832-b0d3-4fd1-8531-1154ec9b3caa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"we can take a look at the first two examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "664e90fc-af84-4c5f-a3dd-5d9ffe649650",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"0703a989-2693-4039-a1f6-7281fc1b4cb0\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"do bob and alice live in the same city?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"no\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"b258b95a-9524-4da7-b758-c5481109322d\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"Is it likely that Donna is outside with an umbrella at this time?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_current_time_for_location\",\n",
|
||||
" \"get_current_weather_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"yes\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"with open(\"./e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
||||
" print(json.dumps(json.load(f)[:2], indent=2, sort_keys=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c6cf01f-466b-406d-b4c7-2395747780fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also clone the dataset to our local tenant"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4dea4df-2f1c-436b-a71c-49ffb2295ccc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Executing this command will clone the dataset to your own LangSmith tenant. \n",
|
||||
"For this to work you must have a [LangSmith account](https://smith.langchain.com/) set up."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b20ba9a6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Get from https://smith.langchain.com/settings\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"ls_...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18d0b905-2a6a-4752-a7cb-8653bd9049e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clone_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\",\n",
|
||||
" dataset_name=\"Agent Dataset\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d9cb90d8-e6a1-4c89-9cde-0e6c0a28f5c0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Model Registry\n",
|
||||
"\n",
|
||||
"Here, we'll see how to access the model registry.\n",
|
||||
"\n",
|
||||
"If you see a model that you want to use and it's missing, please open a PR to add it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "31831289-51fb-4ee5-98f3-0476cf11b187",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import model_registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "aaed190d-fa4b-4445-9bfb-0e784e2a083b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>gpt-3.5-turbo-1106 </td><td>chat </td><td>openai </td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-instruct</td><td>llm </td><td>openai </td><td>Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-0613 </td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k-0613</td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-0301 </td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024. </td></tr>\n",
|
||||
"<tr><td>text-davinci-003 </td><td>llm </td><td>openai </td><td>Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>text-davinci-002 </td><td>llm </td><td>openai </td><td>Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>code-davinci-002 </td><td>llm </td><td>openai </td><td>Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>llama-v2-7b-chat </td><td>chat </td><td>fireworks </td><td>7b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-13b-chat </td><td>chat </td><td>fireworks </td><td>13b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-70b-chat </td><td>chat </td><td>fireworks </td><td>70b parameter LlamaChat model </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-instruct', provider='openai', description='Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions.', params={'model': 'gpt-3.5-turbo-instruct'}, type='llm', path=None), RegisteredModel(name='gpt-3.5-turbo-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-16k-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-0301', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024.', params={'model': 'gpt-3.5-turbo-0301'}, type='chat', path=None), RegisteredModel(name='text-davinci-003', provider='openai', description='Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-003'}, type='llm', path=None), RegisteredModel(name='text-davinci-002', provider='openai', description='Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-002'}, type='llm', path=None), RegisteredModel(name='code-davinci-002', provider='openai', description='Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024.', params={'model': 'code-davinci-002'}, type='llm', path=None), RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "64bfc631-1f1e-4cf4-8636-b8be7b46fef8",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>name </td><td>gpt-3.5-turbo-1106 </td></tr>\n",
|
||||
"<tr><td>type </td><td>chat </td></tr>\n",
|
||||
"<tr><td>provider </td><td>openai </td></tr>\n",
|
||||
"<tr><td>description</td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>model_path </td><td>langchain.chat_models.openai.ChatOpenAI </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"registered_model = model_registry[0]\n",
|
||||
"registered_model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3604d49e-afbe-48ad-ac10-1e538b1ad376",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = registered_model.get_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bdece532-9843-427a-a10b-4545ed4ec151",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Hello there! How can I assist you today?')"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.invoke('hello!')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "db40d4da-dc70-4e6d-b7e8-61de1e15ed2e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>gpt-3.5-turbo-1106</td><td>chat </td><td>openai </td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry[:3]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9874846a-52f3-4921-b1ed-0858521bb9a9",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>llama-v2-7b-chat </td><td>chat </td><td>fireworks </td><td>7b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-13b-chat</td><td>chat </td><td>fireworks </td><td>13b parameter LlamaChat model</td></tr>\n",
|
||||
"<tr><td>llama-v2-70b-chat</td><td>chat </td><td>fireworks </td><td>70b parameter LlamaChat model</td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry.filter(provider='fireworks')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "eb531591-f46b-4745-ae67-4dfd6217ec5f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"gpt-3.5-turbo-1106\n",
|
||||
"gpt-3.5-turbo\n",
|
||||
"gpt-3.5-turbo-16k\n",
|
||||
"gpt-3.5-turbo-instruct\n",
|
||||
"gpt-3.5-turbo-0613\n",
|
||||
"gpt-3.5-turbo-16k-0613\n",
|
||||
"gpt-3.5-turbo-0301\n",
|
||||
"text-davinci-003\n",
|
||||
"text-davinci-002\n",
|
||||
"code-davinci-002\n",
|
||||
"llama-v2-7b-chat\n",
|
||||
"llama-v2-13b-chat\n",
|
||||
"llama-v2-70b-chat\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for registered_model in model_registry:\n",
|
||||
" print(registered_model.name)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
+610
@@ -0,0 +1,610 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9fa3470d-9448-4792-9f65-6978fc58cf84",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi-modal eval: Baseline\n",
|
||||
"\n",
|
||||
"`Multi-modal slide decks` is a public dataset that contains a dataset of question-answer pairs from slide decks with visual content.\n",
|
||||
"\n",
|
||||
"The question-answer pairs are derived from the visual content in the decks, testing the ability of RAG to perform visual reasoning.\n",
|
||||
"\n",
|
||||
"As a baseline, we evaluate this dataset using text-based RAG pipeline, below.\n",
|
||||
"\n",
|
||||
"This will not reason about visual content and will simply load the text from the slides. \n",
|
||||
"\n",
|
||||
"## Pre-requisites"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47220461-d4e9-4f1d-9c57-672ca947ca0d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install -U langchain langsmith langchain_benchmarks\n",
|
||||
"# %pip install --quiet chromadb openai pypdf pandas"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "196de967-6de6-40da-aa75-e836923ab5e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
|
||||
"env_vars = [\"LANGCHAIN_API_KEY\", \"OPENAI_API_KEY\"]\n",
|
||||
"for var in env_vars:\n",
|
||||
" if var not in os.environ:\n",
|
||||
" os.environ[var] = getpass.getpass(prompt=f\"Enter your {var}: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "10da8e11-6288-4131-bd60-d5aa86928acc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset\n",
|
||||
"\n",
|
||||
"We can browse the available LangChain benchmark datasets for retrieval."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "2ff97905-14a6-413c-99be-58b7a9c8d4c1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Dataset ID </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>LangChain Docs Q&A </td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d\" target=\"_blank\" rel=\"noopener\">452ccafc-18e1-4314-885b-edd735f17b9d</a></td><td>Questions and answers based on a snapshot of the LangChain python docs.\n",
|
||||
"\n",
|
||||
"The environment provides the documents and the retriever information.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer.\n",
|
||||
"We also measure the faithfulness of the model's response relative to the retrieved documents (if any). </td></tr>\n",
|
||||
"<tr><td>Semi-structured Reports</td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d\" target=\"_blank\" rel=\"noopener\">c47d9617-ab99-4d6e-a6e6-92b8daf85a7d</a></td><td>Questions and answers based on PDFs containing tables and charts.\n",
|
||||
"\n",
|
||||
"The task provides the raw documents as well as factory methods to easily index them\n",
|
||||
"and create a retriever.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer.\n",
|
||||
"We also measure the faithfulness of the model's response relative to the retrieved documents (if any). </td></tr>\n",
|
||||
"<tr><td>Multi-modal slide decks</td><td>RetrievalTask</td><td><a href=\"https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d\" target=\"_blank\" rel=\"noopener\">40afc8e7-9d7e-44ed-8971-2cae1eb59731</a></td><td>This public dataset is a work-in-progress and will be extended over time.\n",
|
||||
" \n",
|
||||
"Questions and answers based on slide decks containing visual tables and charts.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer. </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"Registry(tasks=[RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_cached_docs at 0x104485800>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1360289a0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x136028a40>, 'hyde': <function _chroma_hyde_retriever_factory at 0x136028ae0>}, architecture_factories={'conversational-retrieval-qa': <function default_response_chain at 0x126ba2660>}), RetrievalTask(name='Semi-structured Reports', dataset_id='https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d', description=\"Questions and answers based on PDFs containing tables and charts.\\n\\nThe task provides the raw documents as well as factory methods to easily index them\\nand create a retriever.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_docs at 0x136029620>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1360296c0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x136029760>, 'hyde': <function _chroma_hyde_retriever_factory at 0x136029800>}, architecture_factories={}), RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})])"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, registry\n",
|
||||
"\n",
|
||||
"registry = registry.filter(Type=\"RetrievalTask\")\n",
|
||||
"registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2fb7dc3d-28f1-4c28-b0d0-3784d04b81ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`Multi-modal slide decks` is the relevant dataset for our task."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "219a4141-4a5f-48e4-ae05-5a824e2193fd",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>Name </td><td>Multi-modal slide decks </td></tr>\n",
|
||||
"<tr><td>Type </td><td>RetrievalTask </td></tr>\n",
|
||||
"<tr><td>Dataset ID </td><td><a href=\"https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d\" target=\"_blank\" rel=\"noopener\">40afc8e7-9d7e-44ed-8971-2cae1eb59731</a></td></tr>\n",
|
||||
"<tr><td>Description </td><td>This public dataset is a work-in-progress and will be extended over time.\n",
|
||||
" \n",
|
||||
"Questions and answers based on slide decks containing visual tables and charts.\n",
|
||||
"\n",
|
||||
"Each example is composed of a question and reference answer.\n",
|
||||
"\n",
|
||||
"Success is measured based on the accuracy of the answer relative to the reference answer. </td></tr>\n",
|
||||
"<tr><td>Retriever Factories </td><td> </td></tr>\n",
|
||||
"<tr><td>Architecture Factories</td><td> </td></tr>\n",
|
||||
"<tr><td>get_docs </td><td>{} </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"task = registry[\"Multi-modal slide decks\"]\n",
|
||||
"task"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d6569b5-e79a-41b7-9745-c2f8a1dd704e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Clone the dataset so that it's available in our LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d2caa086-9549-4c74-bba9-ba80d5a7b218",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset Multi-modal slide decks already exists. Skipping.\n",
|
||||
"You can access the dataset at https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"clone_public_dataset(task.dataset_id, dataset_name=task.name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf350917-a1e5-46f4-81cd-c1678ab9220f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Fetch the associated PDFs from remote cache for the dataset so that we can perform ingestion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "99ce6afb-2317-4bc1-9faf-4f828095ad91",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks.rag.tasks.multi_modal_slide_decks import get_file_names\n",
|
||||
"\n",
|
||||
"file_names = list(get_file_names()) # PosixPath"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "848a4cdb-6c08-4c01-81ce-16ab83a7fdff",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load\n",
|
||||
"\n",
|
||||
"Load and split the files for indexing."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6ce85810-98a7-406e-b44e-ce860ac35986",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"There are 98 text elements in DDOG_Q3_earnings_deck.pdf\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.document_loaders import PyPDFLoader\n",
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_and_split(file):\n",
|
||||
" \"\"\"\n",
|
||||
" Load and split PDF files\n",
|
||||
" :param file: PosixPath path for pdf\n",
|
||||
" :return: A list of text chunks\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" loader = PyPDFLoader(str(file))\n",
|
||||
" pdf_pages = loader.load()\n",
|
||||
"\n",
|
||||
" text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
|
||||
" chunk_size=100, chunk_overlap=50\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Get chunks\n",
|
||||
" docs = text_splitter.split_documents(pdf_pages)\n",
|
||||
" texts = [d.page_content for d in docs]\n",
|
||||
" print(f\"There are {len(texts)} text elements in {file.name}\")\n",
|
||||
" return texts\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"texts = []\n",
|
||||
"for fi in file_names:\n",
|
||||
" texts.extend(load_and_split(fi))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb01925d-b7d1-47a1-bd90-805178d3c4a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Index\n",
|
||||
"\n",
|
||||
"Embed (OpenAIEmbeddings) and store splits in a vectorstore (Chroma)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ceb31f71-45fb-4b12-bc1c-31981de334bb",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||
"from langchain.vectorstores import Chroma\n",
|
||||
"\n",
|
||||
"vectorstore_baseline = Chroma.from_texts(\n",
|
||||
" texts=texts, collection_name=\"baseline-multi-modal\", embedding=OpenAIEmbeddings()\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"retriever_baseline = vectorstore_baseline.as_retriever()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6dcbb01-f480-456d-b972-c732eb26c393",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## RAG\n",
|
||||
"\n",
|
||||
"Create a pipeline for retrieval of relevant chunks based on semantic similarity to the input question.\n",
|
||||
"\n",
|
||||
"Pass the images to GPT-4 for answer synthesis."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "ea233664-e527-42f1-a820-0c2271e16c20",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"from langchain.schema.output_parser import StrOutputParser\n",
|
||||
"from langchain.schema.runnable import RunnablePassthrough\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rag_chain(retriever):\n",
|
||||
" \"\"\"\n",
|
||||
" RAG pipeline for the indexed presentations\n",
|
||||
" :param retriever: PosixPath path for pdf\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Prompt template\n",
|
||||
" template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
|
||||
" {context}\n",
|
||||
" Question: {question}\n",
|
||||
" \"\"\"\n",
|
||||
" prompt = ChatPromptTemplate.from_template(template)\n",
|
||||
"\n",
|
||||
" # LLM\n",
|
||||
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||
"\n",
|
||||
" # RAG pipeline\n",
|
||||
" chain = (\n",
|
||||
" {\n",
|
||||
" \"context\": retriever | (lambda x: \"\\n\\n\".join([i.page_content for i in x])),\n",
|
||||
" \"question\": RunnablePassthrough(),\n",
|
||||
" }\n",
|
||||
" | prompt\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
" )\n",
|
||||
" return chain\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create RAG chain\n",
|
||||
"chain = rag_chain(retriever_baseline)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "95df1446-143d-4f4c-a15b-2a379266d8bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Eval\n",
|
||||
"\n",
|
||||
"Run evaluation on our dataset:\n",
|
||||
"\n",
|
||||
"* `task.name` is the dataset of QA pairs that we cloned\n",
|
||||
"* `eval_config` specifies the [LangSmith evaluator](https://docs.smith.langchain.com/evaluation/evaluator-implementations#correctness-qa-evaluation) for our dataset, which will use GPT-4 as a grader\n",
|
||||
"* The grader will evaluate the chain-generated answer to each question relative to ground truth"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "479ce09d-642e-4b3b-9e4e-e9c2b7f0e9ca",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"View the evaluation results for project '866f-baseline' at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306/compare?selectedSessions=30199d47-50d7-4c5c-a55a-e74157e05951\n",
|
||||
"\n",
|
||||
"View all tests for Dataset Multi-modal slide decks at:\n",
|
||||
"https://smith.langchain.com/o/ebbaf2eb-769b-4505-aca2-d11de10372a4/datasets/08a29acb-5ad6-42ce-a482-574c9e2e5306\n",
|
||||
"[------------------------------------------------->] 10/10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<h3>Experiment Results:</h3>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>output</th>\n",
|
||||
" <th>feedback.COT Contextual Accuracy</th>\n",
|
||||
" <th>error</th>\n",
|
||||
" <th>execution_time</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>count</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>10.000000</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>10.000000</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>unique</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>top</th>\n",
|
||||
" <td>Datadog has 20 total customers.</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>freq</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>mean</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.200000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.674478</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>std</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.421637</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.864273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>min</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>3.307960</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>25%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.113816</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>50%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.700962</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>75%</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>5.018359</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>max</th>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>1.000000</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>6.188082</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" output feedback.COT Contextual Accuracy \\\n",
|
||||
"count 10 10.000000 \n",
|
||||
"unique 10 NaN \n",
|
||||
"top Datadog has 20 total customers. NaN \n",
|
||||
"freq 1 NaN \n",
|
||||
"mean NaN 0.200000 \n",
|
||||
"std NaN 0.421637 \n",
|
||||
"min NaN 0.000000 \n",
|
||||
"25% NaN 0.000000 \n",
|
||||
"50% NaN 0.000000 \n",
|
||||
"75% NaN 0.000000 \n",
|
||||
"max NaN 1.000000 \n",
|
||||
"\n",
|
||||
" error execution_time \n",
|
||||
"count 0 10.000000 \n",
|
||||
"unique 0 NaN \n",
|
||||
"top NaN NaN \n",
|
||||
"freq NaN NaN \n",
|
||||
"mean NaN 4.674478 \n",
|
||||
"std NaN 0.864273 \n",
|
||||
"min NaN 3.307960 \n",
|
||||
"25% NaN 4.113816 \n",
|
||||
"50% NaN 4.700962 \n",
|
||||
"75% NaN 5.018359 \n",
|
||||
"max NaN 6.188082 "
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"\n",
|
||||
"from langchain.smith import RunEvalConfig\n",
|
||||
"from langsmith.client import Client\n",
|
||||
"\n",
|
||||
"# Evaluator configuration\n",
|
||||
"client = Client()\n",
|
||||
"eval_config = RunEvalConfig(\n",
|
||||
" evaluators=[\"cot_qa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Experiments\n",
|
||||
"chain_map = {\n",
|
||||
" \"baseline\": chain,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Run evaluation\n",
|
||||
"run_id = uuid.uuid4().hex[:4]\n",
|
||||
"test_runs = {}\n",
|
||||
"for project_name, chain in chain_map.items():\n",
|
||||
" test_runs[project_name] = client.run_on_dataset(\n",
|
||||
" dataset_name=task.name,\n",
|
||||
" llm_or_chain_factory=lambda: (lambda x: x[\"Question\"]) | chain,\n",
|
||||
" evaluation=eval_config,\n",
|
||||
" verbose=True,\n",
|
||||
" project_name=f\"{run_id}-{project_name}\",\n",
|
||||
" project_metadata={\"chain\": project_name},\n",
|
||||
" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -34,5 +34,7 @@
|
||||
./notebooks/retrieval/intro
|
||||
./notebooks/retrieval/langchain_docs_qa
|
||||
./notebooks/retrieval/semi_structured
|
||||
./notebooks/retrieval/multi_modal_benchmarking/multi_modal_eval_baseline
|
||||
./notebooks/retrieval/multi_modal_benchmarking/multi_modal_eval
|
||||
./notebooks/retrieval/comparing_techniques
|
||||
```
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from langchain_benchmarks.model_registration import model_registry
|
||||
from langchain_benchmarks.registration import registry
|
||||
from langchain_benchmarks.utils._langsmith import (
|
||||
clone_public_dataset,
|
||||
@@ -5,4 +6,10 @@ from langchain_benchmarks.utils._langsmith import (
|
||||
)
|
||||
|
||||
# Please keep this list sorted!
|
||||
__all__ = ["clone_public_dataset", "download_public_dataset", "registry"]
|
||||
__all__ = [
|
||||
"clone_public_dataset",
|
||||
"download_public_dataset",
|
||||
"registry",
|
||||
"model_registry",
|
||||
|
||||
]
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_benchmarks.schema import RegisteredModel, ModelRegistry
|
||||
|
||||
_OpenAIModels = [
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-1106",
|
||||
type="chat",
|
||||
description=(
|
||||
"The latest GPT-3.5 Turbo model with improved instruction following, "
|
||||
"JSON mode, reproducible outputs, parallel function calling, and more. "
|
||||
"Returns a maximum of 4,096 output tokens. Learn more."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-instruct",
|
||||
type="llm",
|
||||
description=(
|
||||
"Similar capabilities as text-davinci-003 but compatible with legacy "
|
||||
"Completions endpoint and not Chat Completions."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0301",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. "
|
||||
"Will be deprecated on June 13th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0301",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-003",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Can do language tasks with better quality and consistency than "
|
||||
"the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-003",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-002",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Similar capabilities to text-davinci-003 but trained with "
|
||||
"supervised fine-tuning instead of reinforcement learning. "
|
||||
"Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-002",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="code-davinci-002",
|
||||
type="llm",
|
||||
description="Legacy Optimized for code-completion tasks. Will be deprecated "
|
||||
"on Jan 4th 2024.",
|
||||
params={
|
||||
"model": "code-davinci-002",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
_FireworksModels = [
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-7b-chat",
|
||||
type="chat",
|
||||
description="7b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-7b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-13b-chat",
|
||||
type="chat",
|
||||
description="13b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-13b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-70b-chat",
|
||||
type="chat",
|
||||
description="70b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-70b-chat",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
model_registry = ModelRegistry(registered_models=_OpenAIModels + _FireworksModels)
|
||||
@@ -0,0 +1 @@
|
||||
pdfs/
|
||||
@@ -1,7 +1,14 @@
|
||||
from langchain_benchmarks.rag.tasks.langchain_docs.task import LANGCHAIN_DOCS_TASK
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.task import (
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
)
|
||||
from langchain_benchmarks.rag.tasks.semi_structured_reports.task import (
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
)
|
||||
|
||||
# Please keep this sorted
|
||||
__all__ = ["LANGCHAIN_DOCS_TASK", "SEMI_STRUCTURED_REPORTS_TASK"]
|
||||
__all__ = [
|
||||
"LANGCHAIN_DOCS_TASK",
|
||||
"SEMI_STRUCTURED_REPORTS_TASK",
|
||||
"MULTI_MODAL_SLIDE_DECKS_TASK",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.indexing.retriever_registry import (
|
||||
get_file_names,
|
||||
)
|
||||
|
||||
__all__ = ["get_file_names"]
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_benchmarks.rag.tasks.multi_modal_slide_decks.indexing.retriever_registry import (
|
||||
get_file_names,
|
||||
)
|
||||
|
||||
__all__ = ["get_file_names"]
|
||||
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from langchain_benchmarks.rag.utils._downloading import (
|
||||
fetch_remote_file,
|
||||
is_folder_populated,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DIRECTORY = Path(os.path.abspath(__file__)).parent
|
||||
# Stores the zipped pdfs for this dataset
|
||||
REMOTE_DOCS_FILE = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/multi_modal_slide_decks.zip"
|
||||
DOCS_DIR = _DIRECTORY / "pdfs"
|
||||
|
||||
|
||||
def fetch_raw_docs(
|
||||
filename: Optional[str] = None, docs_dir: Optional[str] = None
|
||||
) -> None:
|
||||
filename = filename or _DIRECTORY / Path(REMOTE_DOCS_FILE).name
|
||||
docs_dir = docs_dir or DOCS_DIR
|
||||
if not is_folder_populated(docs_dir):
|
||||
fetch_remote_file(REMOTE_DOCS_FILE, filename)
|
||||
with zipfile.ZipFile(filename, "r") as zip_ref:
|
||||
zip_ref.extractall(docs_dir)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def get_file_names() -> Iterable[Path]:
|
||||
fetch_raw_docs()
|
||||
# Traverse the directory and partition the pdfs
|
||||
for path in DOCS_DIR.rglob("*.pdf"):
|
||||
# Ignore __MACOSX
|
||||
if "__MACOSX" in str(path):
|
||||
continue
|
||||
yield path
|
||||
@@ -0,0 +1,23 @@
|
||||
from langchain_benchmarks.schema import RetrievalTask
|
||||
|
||||
# ID of public Multi Modal Slide Decks dataset
|
||||
DATASET_ID = "https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d"
|
||||
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK = RetrievalTask(
|
||||
name="Multi-modal slide decks",
|
||||
dataset_id=DATASET_ID,
|
||||
retriever_factories={},
|
||||
architecture_factories={},
|
||||
get_docs={},
|
||||
description=(
|
||||
"""\
|
||||
This public dataset is a work-in-progress and will be extended over time.
|
||||
|
||||
Questions and answers based on slide decks containing visual tables and charts.
|
||||
|
||||
Each example is composed of a question and reference answer.
|
||||
|
||||
Success is measured based on the accuracy of the answer relative to the reference answer.
|
||||
""" # noqa: E501
|
||||
),
|
||||
)
|
||||
+3
-4
@@ -24,7 +24,6 @@ _DIRECTORY = Path(os.path.abspath(__file__)).parent
|
||||
# Stores the zipped pdfs for this dataset
|
||||
REMOTE_DOCS_FILE = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/semi_structured_earnings.zip"
|
||||
DOCS_DIR = _DIRECTORY / "pdfs"
|
||||
LOCAL_FILE = _DIRECTORY / "chroma_db.zip"
|
||||
|
||||
_DEFAULT_SEARCH_KWARGS = {"k": 6}
|
||||
|
||||
@@ -32,17 +31,17 @@ _DEFAULT_SEARCH_KWARGS = {"k": 6}
|
||||
def fetch_raw_docs(
|
||||
filename: Optional[str] = None, docs_dir: Optional[str] = None
|
||||
) -> None:
|
||||
filename = filename or LOCAL_FILE
|
||||
filename = filename or _DIRECTORY / Path(REMOTE_DOCS_FILE).name
|
||||
docs_dir = docs_dir or DOCS_DIR
|
||||
if not is_folder_populated(docs_dir):
|
||||
fetch_remote_file(REMOTE_DOCS_FILE, filename)
|
||||
with zipfile.ZipFile(filename, "r") as zip_ref:
|
||||
zip_ref.extractall(docs_dir)
|
||||
|
||||
os.remove(LOCAL_FILE)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def get_file_names():
|
||||
def get_file_names() -> Iterable[Path]:
|
||||
fetch_raw_docs()
|
||||
# Traverse the directory and partition the pdfs
|
||||
for path in DOCS_DIR.glob("*.pdf"):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from langchain_benchmarks.extraction.tasks import chat_extraction, email_task
|
||||
from langchain_benchmarks.rag.tasks import (
|
||||
LANGCHAIN_DOCS_TASK,
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
)
|
||||
from langchain_benchmarks.schema import Registry
|
||||
@@ -24,5 +25,6 @@ registry = Registry(
|
||||
chat_extraction.CHAT_EXTRACTION_TASK,
|
||||
LANGCHAIN_DOCS_TASK,
|
||||
SEMI_STRUCTURED_REPORTS_TASK,
|
||||
MULTI_MODAL_SLIDE_DECKS_TASK,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import urllib
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, Sequence
|
||||
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BaseRetriever
|
||||
@@ -109,12 +113,16 @@ class ExtractionTask(BaseTask):
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RetrievalTask(BaseTask):
|
||||
retriever_factories: Dict[str, Callable[[Embeddings], BaseRetriever]] # noqa: F821
|
||||
"""Factories that index the docs using the specified strategy."""
|
||||
architecture_factories: Dict[str, Callable[[Embeddings], BaseRetriever]] # noqa: F821
|
||||
"""Factories methods that help build some off-the-shelf architectures。"""
|
||||
get_docs: Callable[..., Iterable[Document]]
|
||||
get_docs: Optional[Callable[..., Iterable[Document]]] = None
|
||||
"""A function that returns the documents to be indexed."""
|
||||
retriever_factories: Dict[
|
||||
str, Callable[[Embeddings], BaseRetriever]
|
||||
] = dataclasses.field(default_factory=dict) # noqa: F821
|
||||
"""Factories that index the docs using the specified strategy."""
|
||||
architecture_factories: Dict[
|
||||
str, Callable[[Embeddings], BaseRetriever]
|
||||
] = dataclasses.field(default_factory=dict) # noqa: F821
|
||||
"""Factories methods that help build some off-the-shelf architectures。"""
|
||||
|
||||
@property
|
||||
def _table(self) -> List[List[str]]:
|
||||
@@ -149,6 +157,7 @@ class Registry:
|
||||
raise ValueError(
|
||||
f"Duplicate task name {task.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(task.name)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
@@ -206,3 +215,190 @@ class Registry:
|
||||
if not isinstance(task, BaseTask):
|
||||
raise TypeError("Only tasks can be added to the registry.")
|
||||
self.tasks.append(task)
|
||||
|
||||
|
||||
Provider = Literal["fireworks", "openai"]
|
||||
ModelType = Literal["chat", "llm"]
|
||||
AUTHORIZED_NAMESPACES = {"langchain"}
|
||||
|
||||
|
||||
def _get_model_class_from_path(
|
||||
path: str
|
||||
) -> Union[Type[BaseChatModel], Type[BaseLanguageModel]]:
|
||||
"""Get the class of the model."""
|
||||
module_name, attribute_name = path.rsplit(".", 1)
|
||||
top_namespace = path.split(".")[0]
|
||||
|
||||
if top_namespace not in AUTHORIZED_NAMESPACES:
|
||||
raise ValueError(
|
||||
f"Unauthorized namespace {top_namespace}. "
|
||||
f"Authorized namespaces are: {AUTHORIZED_NAMESPACES}"
|
||||
)
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(module_name)
|
||||
model_class = getattr(module, attribute_name)
|
||||
if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)):
|
||||
raise ValueError(
|
||||
f"Model class {model_class} is not a subclass of BaseLanguageModel"
|
||||
)
|
||||
return model_class
|
||||
|
||||
|
||||
def _get_default_path(provider: str, type_: ModelType) -> str:
|
||||
"""Get the default path for a model."""
|
||||
paths = {
|
||||
("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks",
|
||||
("fireworks", "llm"): "langchain.language_models.fireworks.Fireworks",
|
||||
("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI",
|
||||
("openai", "llm"): "langchain.language_models.openai.OpenAI",
|
||||
}
|
||||
|
||||
if (provider, type_) not in paths:
|
||||
raise ValueError(f"Unknown provider {provider} and type {type_}")
|
||||
|
||||
return paths[(provider, type_)]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RegisteredModel:
|
||||
"""Descriptive information about a model.
|
||||
|
||||
This information can be used to instantiate the underlying model.
|
||||
"""
|
||||
|
||||
name: str
|
||||
provider: Provider
|
||||
description: str
|
||||
params: Dict[str, Any]
|
||||
type: ModelType
|
||||
# Path to the model class.
|
||||
# For example, "langchain.chat_models.anthropic import ChatAnthropicModel"
|
||||
path: Optional[str] = None # If not provided, will use default path
|
||||
|
||||
def get_model(
|
||||
self, *, model_params: Optional[Dict[str, Any]] = None
|
||||
) -> Union[BaseChatModel, BaseLanguageModel]:
|
||||
"""Get the class of the model."""
|
||||
all_params = {**self.params, **(model_params or {})}
|
||||
model_class = _get_model_class_from_path(self.model_path)
|
||||
return model_class(**all_params)
|
||||
|
||||
@property
|
||||
def model_path(self) -> str:
|
||||
"""Get the path of the model."""
|
||||
return self.path or _get_default_path(self.provider, self.type)
|
||||
|
||||
@property
|
||||
def _table(self) -> List[List[str]]:
|
||||
"""Return a table representation of the environment."""
|
||||
return [
|
||||
["name", self.name],
|
||||
["type", self.type],
|
||||
["provider", self.provider],
|
||||
["description", self.description],
|
||||
["model_path", self.model_path],
|
||||
]
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the environment."""
|
||||
return tabulate(
|
||||
self._table,
|
||||
tablefmt="unsafehtml",
|
||||
)
|
||||
|
||||
|
||||
StrFilter = Union[None, str, Sequence[str]]
|
||||
|
||||
|
||||
def _is_in_filter(actual_value: str, filter_value: StrFilter) -> bool:
|
||||
"""Filter for a string attribute."""
|
||||
if filter_value is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter_value, str):
|
||||
return actual_value == filter_value
|
||||
|
||||
return actual_value in filter_value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=False)
|
||||
class ModelRegistry:
|
||||
registered_models: Sequence[RegisteredModel]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that all the tasks have unique names and IDs."""
|
||||
seen_names = set()
|
||||
for model in self.registered_models:
|
||||
if model.name in seen_names:
|
||||
raise ValueError(
|
||||
f"Duplicate model name {model.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(model.name)
|
||||
|
||||
def get_model(self, name: str) -> Optional[RegisteredModel]:
|
||||
"""Get model info."""
|
||||
return next(model for model in self.registered_models if model.name == name)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
*,
|
||||
type: StrFilter = None,
|
||||
name: StrFilter = None,
|
||||
provider: StrFilter = None,
|
||||
) -> ModelRegistry:
|
||||
"""Filter the tasks in the registry."""
|
||||
models = self.registered_models
|
||||
selected_models = []
|
||||
|
||||
for model in models:
|
||||
if not _is_in_filter(model.type, type):
|
||||
continue
|
||||
if not _is_in_filter(model.name, name):
|
||||
continue
|
||||
if not _is_in_filter(model.provider, provider):
|
||||
continue
|
||||
selected_models.append(model)
|
||||
return ModelRegistry(registered_models=selected_models)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
headers = [
|
||||
"Name",
|
||||
"Type",
|
||||
"Provider",
|
||||
"Description",
|
||||
]
|
||||
table = [
|
||||
[
|
||||
model.name,
|
||||
model.type,
|
||||
model.provider,
|
||||
model.description,
|
||||
]
|
||||
for model in self.registered_models
|
||||
]
|
||||
return tabulate(table, headers=headers, tablefmt="unsafehtml")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of tasks in the registry."""
|
||||
return len(self.registered_models)
|
||||
|
||||
def __iter__(self) -> Iterable[RegisteredModel]:
|
||||
"""Iterate over the tasks in the registry."""
|
||||
return iter(self.registered_models)
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[int, str]
|
||||
) -> Union[RegisteredModel, ModelRegistry]:
|
||||
"""Get an environment from the registry."""
|
||||
if isinstance(key, slice):
|
||||
return ModelRegistry(registered_models=self.registered_models[key])
|
||||
elif isinstance(key, (int, str)):
|
||||
# If key is an integer, return the corresponding environment
|
||||
if isinstance(key, str):
|
||||
return self.get_model(key)
|
||||
else:
|
||||
return self.registered_models[key]
|
||||
else:
|
||||
raise TypeError("Key must be an integer or a slice.")
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-benchmarks"
|
||||
version = "0.0.6"
|
||||
version = "0.0.7"
|
||||
description = "🦜💪 Flex those feathers!"
|
||||
authors = ["LangChain AI"]
|
||||
license = "MIT"
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
|
||||
from langchain_benchmarks.schema import RegisteredModel, ModelRegistry
|
||||
|
||||
# Create some sample RegisteredModel instances for testing
|
||||
SAMPLE_MODELS = [
|
||||
RegisteredModel(
|
||||
"model1", "fireworks", "Description 1", {"param1": "value1"}, "chat"
|
||||
),
|
||||
RegisteredModel("model2", "openai", "Description 2", {"param2": "value2"}, "llm"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_registry() -> ModelRegistry:
|
||||
return ModelRegistry(SAMPLE_MODELS)
|
||||
|
||||
|
||||
def test_init() -> None:
|
||||
# Test the constructor of ModelRegistry
|
||||
registry = ModelRegistry(SAMPLE_MODELS)
|
||||
assert len(registry.registered_models) == 2
|
||||
|
||||
|
||||
def test_get_model(sample_registry: ModelRegistry) -> None:
|
||||
# Test the get_model method
|
||||
model = sample_registry.get_model("model1")
|
||||
assert model.name == "model1"
|
||||
|
||||
|
||||
def test_filter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the filter method
|
||||
filtered_registry = sample_registry.filter(type="chat")
|
||||
assert len(filtered_registry.registered_models) == 1
|
||||
assert filtered_registry.registered_models[0].type == "chat"
|
||||
|
||||
|
||||
def test_repr_html(sample_registry: ModelRegistry) -> None:
|
||||
# Test the _repr_html_ method
|
||||
html_representation = sample_registry._repr_html_()
|
||||
assert "<table>" in html_representation
|
||||
|
||||
|
||||
def test_len(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __len__ method
|
||||
assert len(sample_registry) == 2
|
||||
|
||||
|
||||
def test_iter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __iter__ method
|
||||
models = list(iter(sample_registry))
|
||||
assert len(models) == 2
|
||||
assert isinstance(models[0], RegisteredModel)
|
||||
|
||||
|
||||
def test_getitem(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for integer and string keys
|
||||
model = sample_registry[0]
|
||||
assert model.name == "model1"
|
||||
model = sample_registry["model2"]
|
||||
assert model.name == "model2"
|
||||
|
||||
|
||||
def test_getitem_slice(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for slices
|
||||
sliced_registry = sample_registry[:1]
|
||||
assert len(sliced_registry.registered_models) == 1
|
||||
assert sliced_registry.registered_models[0].name == "model1"
|
||||
Reference in New Issue
Block a user