mirror of
https://github.com/langchain-ai/langchain-benchmarks.git
synced 2026-07-01 22:34:02 -04:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7545285c3 | |||
| 295fb12f8c | |||
| dd1f33ce02 | |||
| 0ddb74036b | |||
| cf10528b99 | |||
| dbc254ebc0 | |||
| 4c3de31e6f | |||
| f8e5d32bc9 | |||
| 40250950db | |||
| 9adc5b96f5 | |||
| 85c8d29342 | |||
| a42c70f315 | |||
| 9c6ec01219 | |||
| d2defc95e3 | |||
| 8b0e52fdc4 | |||
| 22b90df0ad | |||
| 8ba4340730 | |||
| 2f0b4e9bed | |||
| 910bd60832 | |||
| 5b4672a33b |
@@ -0,0 +1,52 @@
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
|
||||
from langchain.tools.base import StructuredTool
|
||||
|
||||
from agents.encoder import FunctionDefinition, Parameter
|
||||
|
||||
|
||||
# This is temporary until we have a better way to represent parameters
|
||||
def get_parameters_from_tool(tool: StructuredTool) -> List[Parameter]:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
schema = tool.args_schema.schema()
|
||||
|
||||
properties = schema["properties"]
|
||||
parameters = []
|
||||
# Is this needed or is string OK?
|
||||
type_adapter = {
|
||||
"string": "str", # str or string?
|
||||
"integer": "int",
|
||||
"number": "float",
|
||||
"boolean": "bool",
|
||||
}
|
||||
for key, value in properties.items():
|
||||
parameters.append(
|
||||
{
|
||||
"name": key,
|
||||
"type": type_adapter.get(value["type"], value["type"]),
|
||||
"description": value.get("description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
#
|
||||
def convert_tool_to_function_definition(tool: StructuredTool) -> FunctionDefinition:
|
||||
"""Convert a langchain tool to a tool user tool."""
|
||||
# Here we re-inspect the underlying function to get the doc-string
|
||||
# since StructuredTool modifies it, but we want the raw one for maximum
|
||||
# flexibility.
|
||||
description = inspect.getdoc(tool.func)
|
||||
|
||||
parameters = get_parameters_from_tool(tool)
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": dedent(description),
|
||||
"parameters": parameters,
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
from typing import List, Literal, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from agents.adapters import convert_tool_to_function_definition
|
||||
from agents.encoder import AstPrinter, TypeScriptEncoder
|
||||
from agents.prompts import AGENT_INSTRUCTIONS_BLOB_STYLE
|
||||
|
||||
|
||||
def format_observation(tool_name: str, observation: str) -> BaseMessage:
|
||||
"""Format the observation."""
|
||||
result = (
|
||||
"<tool_output>\n"
|
||||
f"<tool_name>{tool_name}</tool_name>\n"
|
||||
f"<output>{observation}</output>\n"
|
||||
"</tool_output>"
|
||||
)
|
||||
|
||||
return HumanMessage(content=result)
|
||||
|
||||
|
||||
def format_steps_for_chat(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> List[BaseMessage]:
|
||||
"""Format the steps."""
|
||||
messages = []
|
||||
for action, observation in intermediate_steps:
|
||||
if not isinstance(action, AgentAction):
|
||||
if action.tool != "_Exception":
|
||||
raise AssertionError(f"Unexpected step: {action}. type: {type(action)}")
|
||||
|
||||
messages.append(HumanMessage(content=observation))
|
||||
messages.extend(action.messages)
|
||||
messages.append(format_observation(action.tool, observation))
|
||||
return messages
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
class AgentInput(TypedDict):
|
||||
"""The input to the agent."""
|
||||
|
||||
input: str
|
||||
"""The input to the agent."""
|
||||
intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
"""The intermediate steps taken by the agent."""
|
||||
examples: NotRequired[List[BaseMessage]]
|
||||
"""A list of messages that can be used to form example traces."""
|
||||
|
||||
|
||||
def create_agent(
|
||||
model: Union[BaseChatModel, BaseLanguageModel],
|
||||
tools: Sequence[StructuredTool],
|
||||
parser: AgentOutputParser,
|
||||
*,
|
||||
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
|
||||
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
|
||||
"""Create an agent for a chat model."""
|
||||
if isinstance(ast_printer, str):
|
||||
if ast_printer == "xml":
|
||||
ast_printer = AstPrinter()
|
||||
elif ast_printer == "typescript":
|
||||
ast_printer = TypeScriptEncoder()
|
||||
else:
|
||||
raise ValueError(f"Unknown ast printer: {ast_printer}")
|
||||
elif isinstance(ast_printer, AstPrinter):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
|
||||
)
|
||||
|
||||
function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
|
||||
tool_description = ast_printer.visit_function_definitions(function_definitions)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", AGENT_INSTRUCTIONS_BLOB_STYLE),
|
||||
MessagesPlaceholder("examples"), # Can use to add example traces
|
||||
("human", "{input}"),
|
||||
MessagesPlaceholder("history"),
|
||||
]
|
||||
).partial(tool_description=tool_description)
|
||||
|
||||
agent = (
|
||||
{
|
||||
"input": lambda x: x["input"],
|
||||
"history": lambda x: format_steps_for_chat(x["intermediate_steps"]),
|
||||
"examples": lambda x: x.get("examples", []),
|
||||
}
|
||||
| template
|
||||
| model.bind(stop=["</tool>"])
|
||||
| parser
|
||||
)
|
||||
return agent
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Prototyping code for rendering function definitions, invocations, and results.
|
||||
|
||||
Types are simplified for now to `str`.
|
||||
|
||||
We should actually support something like pydantic or jsonschema for the types, so
|
||||
we can expand them recursively for nested types.
|
||||
"""
|
||||
import abc
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
|
||||
class Parameter(TypedDict):
|
||||
"""Representation for a parameter."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
class Arguments(TypedDict):
|
||||
"""Arguments are passed to a function during function invocation."""
|
||||
|
||||
name: Optional[str]
|
||||
value: Any
|
||||
|
||||
|
||||
class ReturnValue(TypedDict):
|
||||
"""Representation for a return value of a function call."""
|
||||
|
||||
type: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class FunctionDefinition(TypedDict):
|
||||
"""Representation for a function."""
|
||||
|
||||
name: str
|
||||
description: str # Function description
|
||||
parameters: List[Parameter]
|
||||
return_value: ReturnValue
|
||||
|
||||
|
||||
class FunctionInvocation(TypedDict):
|
||||
"""Representation for a function invocation."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
arguments: List[Arguments]
|
||||
|
||||
|
||||
class FunctionResult(TypedDict):
|
||||
"""Representation for a function result."""
|
||||
|
||||
id: NotRequired[str]
|
||||
name: str
|
||||
result: Optional[str]
|
||||
error: Optional[str]
|
||||
|
||||
|
||||
class Visitor(abc.ABC):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AstPrinter(Visitor):
|
||||
"""Print the AST."""
|
||||
|
||||
|
||||
class XMLEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_as_strings = [
|
||||
"<parameter>\n"
|
||||
f"<name>{parameter['name']}</name>\n"
|
||||
f"<type>{parameter['type']}</type>\n"
|
||||
f"<description>{parameter['description']}</description>\n"
|
||||
"</parameter>\n"
|
||||
for parameter in function_definition["parameters"]
|
||||
]
|
||||
function = (
|
||||
"<function>\n"
|
||||
f"<function_name>{function_definition['name']}</function_name>\n"
|
||||
"<description>\n"
|
||||
f"{function_definition['description']}\n"
|
||||
"</description>\n"
|
||||
"<parameters>\n"
|
||||
f"{''.join(parameters_as_strings)}" # Already includes trailing newline
|
||||
"</parameters>\n"
|
||||
"<return_value>\n"
|
||||
f"<type>{function_definition['return_value']['type']}</type>\n"
|
||||
f"<description>{function_definition['return_value']['description']}</description>\n"
|
||||
"</return_value>\n"
|
||||
"</function>"
|
||||
)
|
||||
return function
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "<functions>\n" + "\n".join(strs) + "\n</functions>"
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
"<argument>\n"
|
||||
f"<name>{argument['name']}</name>\n"
|
||||
f"<value>{argument['value']}</value>\n"
|
||||
"</argument>\n"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = ["<function_invocation>"]
|
||||
|
||||
if invocation.get("id"):
|
||||
lines.append(f"<id>{invocation['id']}</id>")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"<function_name>{invocation['name']}</function_name>\n"
|
||||
"<arguments>\n"
|
||||
f"{''.join(arguments_as_strings)}" # Already includes trailing newline
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = [
|
||||
"<function_result>",
|
||||
]
|
||||
|
||||
if function_result.get("id"):
|
||||
lines.append(f"<id>{function_result['id']}</id>")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"<function_name>{function_result['name']}</function_name>",
|
||||
f"<result>{function_result['result']}</result>",
|
||||
f"<error>{function_result['error']}</error>",
|
||||
"</function_result>",
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TypeScriptEncoder(AstPrinter):
|
||||
def visit_function_definition(self, function_definition: FunctionDefinition) -> str:
|
||||
"""Render a function."""
|
||||
parameters_as_strings = [
|
||||
f"{parameter['name']}: {parameter['type']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
]
|
||||
# Let's use JSdoc style comments
|
||||
# First the function description
|
||||
lines = [
|
||||
f"// {function_definition['description']}",
|
||||
# Then the parameter descriptions
|
||||
*[
|
||||
f"// @param {parameter['name']} {parameter['description']}"
|
||||
for parameter in function_definition["parameters"]
|
||||
],
|
||||
# Then the return value description
|
||||
f"// @returns {function_definition['return_value']['description']}",
|
||||
# Then the function definition
|
||||
f"function {function_definition['name']}("
|
||||
f"{', '.join(parameters_as_strings)}): "
|
||||
f"{function_definition['return_value']['type']};",
|
||||
]
|
||||
|
||||
# finally join
|
||||
function = "\n".join(lines)
|
||||
return function
|
||||
|
||||
def visit_function_definitions(
|
||||
self, function_definitions: List[FunctionDefinition]
|
||||
) -> str:
|
||||
"""Render a function."""
|
||||
strs = [
|
||||
self.visit_function_definition(function_definition)
|
||||
for function_definition in function_definitions
|
||||
]
|
||||
return "\n\n".join(strs)
|
||||
|
||||
def visit_function_invocation(self, invocation: FunctionInvocation) -> str:
|
||||
"""Render a function invocation."""
|
||||
arguments_as_strings = [
|
||||
f"{argument['name']}: {argument['value']}"
|
||||
for argument in invocation["arguments"]
|
||||
]
|
||||
lines = [f"{invocation['name']}(" f"{', '.join(arguments_as_strings)});"]
|
||||
return "\n".join(lines)
|
||||
|
||||
def visit_function_result(self, function_result: FunctionResult) -> str:
|
||||
"""Render a function result."""
|
||||
lines = []
|
||||
if function_result["error"]:
|
||||
lines.append(f"ERROR: {function_result['error']}")
|
||||
else:
|
||||
lines.append(f"> {function_result['result']}")
|
||||
if function_result.get("id"):
|
||||
lines.append(f"// ID: {function_result['id']}")
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,25 @@
|
||||
# EXAMPLE_TRACE = [
|
||||
# HumanMessage(content="type the letter 'o'"),
|
||||
# AIMessage(
|
||||
# content="""
|
||||
# <tool>
|
||||
# {
|
||||
# "tool_name": "type_letter",
|
||||
# "arguments": {
|
||||
# "letter": "o"
|
||||
# }
|
||||
# }
|
||||
# </tool>\
|
||||
# """
|
||||
# ),
|
||||
# HumanMessage(
|
||||
# content="""\
|
||||
# <tool_outputs>
|
||||
# <tool_name>type_letter</tool_name>
|
||||
# <output>o</output>
|
||||
# </tool_outputs>\
|
||||
# """
|
||||
# ),
|
||||
# ]
|
||||
#
|
||||
#
|
||||
@@ -0,0 +1,75 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chat_models import ChatAnthropic, ChatFireworks
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from agents.agent import create_agent
|
||||
from agents.parser import ParameterizedAgentParser
|
||||
from langchain_benchmarks.model_registration import FIREWORK_NAME_TO_MODEL
|
||||
from langchain_benchmarks.schema import ToolUsageTask
|
||||
from langchain_benchmarks.tool_usage import apply_agent_executor_adapter
|
||||
|
||||
|
||||
class CustomAgentFactory:
|
||||
def __init__(self, task: ToolUsageTask, model: str) -> None:
|
||||
"""Create an OpenAI agent factory for the given task.
|
||||
|
||||
Args:
|
||||
task: The task to create an agent factory for.
|
||||
"""
|
||||
if model not in self.list_models():
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def list_models() -> List[str]:
|
||||
"""List all models."""
|
||||
return sorted(
|
||||
[
|
||||
"claude-2.1",
|
||||
"claude-2",
|
||||
*FIREWORK_NAME_TO_MODEL.keys(),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self) -> Runnable:
|
||||
env = self.task.create_environment()
|
||||
if self.model in {"claude-2.1", "claude-2"}:
|
||||
model = ChatAnthropic(model=self.model, temperature=0)
|
||||
elif self.model in FIREWORK_NAME_TO_MODEL:
|
||||
model = ChatFireworks(
|
||||
model=FIREWORK_NAME_TO_MODEL[self.model], temperature=0
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {self.model}")
|
||||
|
||||
def _add_task_instructions(
|
||||
input: dict, config: Optional[RunnableConfig] = None, **kwargs
|
||||
) -> dict:
|
||||
"""Add task instructions to the question."""
|
||||
input = input.copy()
|
||||
input["question"] = (
|
||||
f"{self.task.instructions}\nWrite down your answer, "
|
||||
f"but do not explain it. Input: `{input['question']}`"
|
||||
)
|
||||
return input
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
env.tools,
|
||||
ParameterizedAgentParser(
|
||||
wrapping_xml_tag="tool", require_closing_xml_tag=False
|
||||
),
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=env.tools,
|
||||
handle_parsing_errors=True,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
|
||||
return _add_task_instructions | apply_agent_executor_adapter(
|
||||
executor, state_reader=env.read_state
|
||||
)
|
||||
@@ -0,0 +1,120 @@
|
||||
import ast
|
||||
import re
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class _ToolInvocationRequest(BaseModel):
|
||||
"""Light-weight pydantic model for validating the raw tool invocation request.
|
||||
|
||||
The purpose of this model, is to make sure that whatever as parsed from
|
||||
the raw llm output has `tool_name` and potential `arguments` fields, and
|
||||
nothing else.
|
||||
"""
|
||||
|
||||
tool_name: str
|
||||
# OK parameterless tools which do not take arguments
|
||||
named_arguments: Any = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ParameterizedAgentParser(AgentOutputParser):
|
||||
"""A generalized parser that makes it easier to parameterize different parsing."""
|
||||
|
||||
wrapping_xml_tag: str
|
||||
"""The tag that wraps the function invocation request.
|
||||
|
||||
For example, if "tool", then the function invocation request should be wrapped
|
||||
in <tool>...</tool>.
|
||||
"""
|
||||
require_closing_xml_tag: bool = False
|
||||
"""Whether we should require a closing tag for the wrapping_xml_tag.
|
||||
|
||||
For example, if True, then the function invocation request should be wrapped
|
||||
"""
|
||||
|
||||
def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
|
||||
"""Parse the output of the agent."""
|
||||
open_tag = f"<{self.wrapping_xml_tag}>"
|
||||
close_tag = f"</{self.wrapping_xml_tag}>"
|
||||
if open_tag in text:
|
||||
# This is a hack to make sure that </tool> is always present
|
||||
# in the output if <tool>. </tool> may be a stop sequence for the
|
||||
# language model, so depending on implementation
|
||||
# the stop sequence may be cut off.
|
||||
# There might be a better way to do this, but this works and
|
||||
# is simple.
|
||||
if not self.require_closing_xml_tag:
|
||||
text += close_tag
|
||||
|
||||
pattern = rf"{open_tag}(?P<invocation>.*?){close_tag}"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
content = match.group("invocation").strip()
|
||||
return parse_invocation(content, self.wrapping_xml_tag)
|
||||
|
||||
return AgentFinish(
|
||||
log=text,
|
||||
return_values={
|
||||
"output": text,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def parse_invocation(text: str, tag: str) -> AgentAction:
|
||||
"""Parse the content of the function invocation.
|
||||
|
||||
Args:
|
||||
text: The text to parse.
|
||||
tag: The tag that wraps the function invocation request.
|
||||
|
||||
Returns:
|
||||
An AgentAction that corresponds to the function invocation.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the parsing fails.
|
||||
|
||||
This exception is meant to be caught by the agent executor and
|
||||
handled appropriately to provide feedback to the LLM.
|
||||
"""
|
||||
ai_content = f"<{tag}>{text}</{tag}>"
|
||||
|
||||
try:
|
||||
result = ast.literal_eval(text)
|
||||
except Exception as e:
|
||||
# Convert this to something controllable by the user.
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>'
|
||||
)
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
observation=err_msg,
|
||||
send_to_llm=True,
|
||||
)
|
||||
|
||||
try:
|
||||
request = _ToolInvocationRequest(**result)
|
||||
except ValidationError as e:
|
||||
err_msg = (
|
||||
f"ERROR: Please use the format "
|
||||
f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}</{tag}>'
|
||||
)
|
||||
raise OutputParserException(
|
||||
error=e,
|
||||
llm_output=ai_content,
|
||||
send_to_llm=True,
|
||||
observation=err_msg,
|
||||
)
|
||||
|
||||
return AgentActionMessageLog(
|
||||
message_log=[AIMessage(content=ai_content)],
|
||||
tool=request.tool_name,
|
||||
tool_input=request.named_arguments,
|
||||
log=f"\nInvoking {request.tool_name}: {request.named_arguments}\n\t",
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
AGENT_INSTRUCTIONS_XML_FORMAT = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
You may call them like this:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
""" # noqa: E501
|
||||
|
||||
AGENT_INSTRUCTIONS_BLOB_STYLE = """\
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
{tool_description}
|
||||
|
||||
You may call one tool at a time using a format that includes <tool> and </tool> tag.
|
||||
|
||||
Inside the tag the content is a python dictionary that uses python literals (e.g., numbers, strings, lists, dictionaries, etc.) to specify the tool invocation.
|
||||
|
||||
It must match the schema of the function as described in the tool description.
|
||||
"arguments" is a dictionary of the arguments to the function.
|
||||
|
||||
<tool>
|
||||
{{
|
||||
"tool_name": $TOOL_NAME,
|
||||
"arguments": $ARGUMENTS
|
||||
}}
|
||||
</tool>
|
||||
|
||||
If you do not know the answer use more tools. You can only take a single action at a time.\
|
||||
""" # noqa: E501
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Test XML encoding and decoding of function definitions, invocation, and results."""
|
||||
from agents.encoder import (
|
||||
FunctionDefinition,
|
||||
TypeScriptEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = TypeScriptEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"// A test function\n"
|
||||
"// @param test_parameter A test parameter\n"
|
||||
"// @returns A test return value\n"
|
||||
"function test_function(test_parameter: str): str;"
|
||||
)
|
||||
|
||||
|
||||
# Not important to test other ones right now since we can't parse / interpret
|
||||
# typescript anyway.
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Test XML encoding and decoding of function definitions, invocation, and results."""
|
||||
from agents.encoder import (
|
||||
FunctionDefinition,
|
||||
FunctionInvocation,
|
||||
FunctionResult,
|
||||
XMLEncoder,
|
||||
)
|
||||
|
||||
|
||||
def test_function_definition_encoding() -> None:
|
||||
"""Test encoding a function definition."""
|
||||
function_definition = FunctionDefinition(
|
||||
name="test_function",
|
||||
description="A test function",
|
||||
parameters=[
|
||||
{"name": "test_parameter", "type": "str", "description": "A test parameter"}
|
||||
],
|
||||
return_value={"type": "str", "description": "A test return value"},
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_definition(function_definition)
|
||||
assert xml == (
|
||||
"<function>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<description>\n"
|
||||
"A test function\n"
|
||||
"</description>\n"
|
||||
"<parameters>\n"
|
||||
"<parameter>\n"
|
||||
"<name>test_parameter</name>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test parameter</description>\n"
|
||||
"</parameter>\n"
|
||||
"</parameters>\n"
|
||||
"<return_value>\n"
|
||||
"<type>str</type>\n"
|
||||
"<description>A test return value</description>\n"
|
||||
"</return_value>\n"
|
||||
"</function>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_result_encoding() -> None:
|
||||
"""Test encoding a function result."""
|
||||
function_result = FunctionResult(
|
||||
name="test_function",
|
||||
result="test_result",
|
||||
error="test_error",
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_result(function_result)
|
||||
assert xml == (
|
||||
"<function_result>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<result>test_result</result>\n"
|
||||
"<error>test_error</error>\n"
|
||||
"</function_result>"
|
||||
)
|
||||
|
||||
|
||||
def test_function_invocation() -> None:
|
||||
"""Test function invocation."""
|
||||
function_invocation = FunctionInvocation(
|
||||
name="test_function",
|
||||
arguments=[{"name": "test_argument", "value": "test_value"}],
|
||||
)
|
||||
encoder = XMLEncoder()
|
||||
xml = encoder.visit_function_invocation(function_invocation)
|
||||
assert xml == (
|
||||
"<function_invocation>\n"
|
||||
"<function_name>test_function</function_name>\n"
|
||||
"<arguments>\n"
|
||||
"<argument>\n"
|
||||
"<name>test_argument</name>\n"
|
||||
"<value>test_value</value>\n"
|
||||
"</argument>\n"
|
||||
"</arguments>\n"
|
||||
"</function_invocation>"
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
from langchain.tools import tool
|
||||
|
||||
from agents.adapters import convert_tool_to_function_definition
|
||||
from agents.encoder import XMLEncoder
|
||||
|
||||
|
||||
@tool
|
||||
def get_hello() -> str:
|
||||
"""Get hello."""
|
||||
return "hello"
|
||||
|
||||
|
||||
@tool
|
||||
def repeat(x: str) -> str:
|
||||
"""Repeat x.
|
||||
|
||||
Args:
|
||||
x: The string to repeat.
|
||||
|
||||
Returns:
|
||||
The repeated string.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def test_parameterless_function() -> None:
|
||||
"""Test foo."""
|
||||
function_definition = convert_tool_to_function_definition(get_hello)
|
||||
assert function_definition == {
|
||||
"name": "get_hello",
|
||||
"description": "Get hello.",
|
||||
"parameters": [],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skip("Need to fix handling of leading whitespace")
|
||||
def test_function_with_parameters() -> None:
|
||||
import textwrap
|
||||
|
||||
doc = textwrap.dedent(repeat.func.__doc__)
|
||||
assert convert_tool_to_function_definition(repeat) == {
|
||||
"name": "repeat",
|
||||
"description": doc,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "x",
|
||||
"type": "str",
|
||||
"description": "", # Need to fix this
|
||||
}
|
||||
],
|
||||
"return_value": {
|
||||
"type": "Any",
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
from langchain_adapters.alternative import AgentOutputParser
|
||||
from langchain_core.agents import AgentFinish
|
||||
|
||||
|
||||
def test_parser() -> None:
|
||||
"""Test parser."""
|
||||
parser = AgentOutputParser(require_closing_tag=False, tag="tool")
|
||||
assert isinstance(parser.invoke("goodbye"), AgentFinish)
|
||||
assert parser.invoke("<tool>hello</tool>") == "hello"
|
||||
assert parser.invoke("<tool>hello") == "hello"
|
||||
# assert isinstance(parser.invoke("<tag>hello</tag>"), AgentAction)
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Throttle using a token bucket."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class Throttle:
|
||||
def __init__(self, rate: int) -> None:
|
||||
"""Initialize the throttle."""
|
||||
self.rate = rate
|
||||
self.tokens = 0
|
||||
self._consume_lock = threading.Lock()
|
||||
self.last = None
|
||||
|
||||
def consume(self, amount: int = 0) -> int:
|
||||
"""Consume the given amount of tokens."""
|
||||
with self._consume_lock:
|
||||
now = time.time()
|
||||
|
||||
# initialize on first call to avoid a burst
|
||||
if self.last is None:
|
||||
self.last = now
|
||||
|
||||
elapsed = now - self.last
|
||||
|
||||
if elapsed * self.rate > 1:
|
||||
self.tokens += elapsed * self.rate
|
||||
self.last = now
|
||||
|
||||
self.tokens = min(self.tokens, self.rate)
|
||||
|
||||
if self.tokens >= amount:
|
||||
self.tokens -= amount
|
||||
return amount
|
||||
|
||||
return 0
|
||||
@@ -1,225 +1,226 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "033684fb-65b2-4586-a959-68c614741ca2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Datasets\n",
|
||||
"[](https://colab.research.google.com/github/langchain-ai/langchain-benchmarks/blob/main/docs/source/notebooks/datasets.ipynb)\n",
|
||||
"\n",
|
||||
"Here, we'll see how to work with LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -U langchain-benchmarks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6d272fbf-710e-4a49-a0da-67e010541905",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, download_public_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18ee0f96-e5c4-4ae9-aebf-7d8b88c51662",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's first download the dataset to the local file system"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "58b94f6d-0c91-4361-9b22-f758ffaa150a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fetching examples...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5a2fad8c0c3549ec96a3b38fe8a002b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/21 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Done fetching examples.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"download_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "841db832-b0d3-4fd1-8531-1154ec9b3caa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"we can take a look at the first two examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "664e90fc-af84-4c5f-a3dd-5d9ffe649650",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"0703a989-2693-4039-a1f6-7281fc1b4cb0\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"do bob and alice live in the same city?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"no\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"b258b95a-9524-4da7-b758-c5481109322d\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"Is it likely that Donna is outside with an umbrella at this time?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_current_time_for_location\",\n",
|
||||
" \"get_current_weather_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"yes\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"with open(\"./e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
||||
" print(json.dumps(json.load(f)[:2], indent=2, sort_keys=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c6cf01f-466b-406d-b4c7-2395747780fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also clone the dataset to our local tenant"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4dea4df-2f1c-436b-a71c-49ffb2295ccc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Executing this command will clone the dataset to your own LangSmith tenant. \n",
|
||||
"For this to work you must have a [LangSmith account](https://smith.langchain.com/) set up."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Get from https://smith.langchain.com/settings\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"ls_...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18d0b905-2a6a-4752-a7cb-8653bd9049e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clone_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\",\n",
|
||||
" dataset_name=\"Agent Dataset\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "033684fb-65b2-4586-a959-68c614741ca2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Datasets\n",
|
||||
"\n",
|
||||
"Here, we'll see how to work with LangSmith datasets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d9aa20db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -U langchain-benchmarks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "6d272fbf-710e-4a49-a0da-67e010541905",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import clone_public_dataset, download_public_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18ee0f96-e5c4-4ae9-aebf-7d8b88c51662",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's first download the dataset to the local file system"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "58b94f6d-0c91-4361-9b22-f758ffaa150a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Fetching examples...\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5a2fad8c0c3549ec96a3b38fe8a002b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/21 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Done fetching examples.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"download_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "841db832-b0d3-4fd1-8531-1154ec9b3caa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"we can take a look at the first two examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "664e90fc-af84-4c5f-a3dd-5d9ffe649650",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"0703a989-2693-4039-a1f6-7281fc1b4cb0\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"do bob and alice live in the same city?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.511629\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_city_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"no\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"created_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"dataset_id\": \"9f73165c-d333-4d14-8f59-bd7eede5db08\",\n",
|
||||
" \"id\": \"b258b95a-9524-4da7-b758-c5481109322d\",\n",
|
||||
" \"inputs\": {\n",
|
||||
" \"question\": \"Is it likely that Donna is outside with an umbrella at this time?\"\n",
|
||||
" },\n",
|
||||
" \"modified_at\": \"2023-11-15T15:26:53.491359\",\n",
|
||||
" \"outputs\": {\n",
|
||||
" \"expected_steps\": [\n",
|
||||
" \"find_users_by_name\",\n",
|
||||
" \"get_user_location\",\n",
|
||||
" \"get_current_time_for_location\",\n",
|
||||
" \"get_current_weather_for_location\"\n",
|
||||
" ],\n",
|
||||
" \"order_matters\": false,\n",
|
||||
" \"reference\": \"yes\"\n",
|
||||
" },\n",
|
||||
" \"runs\": []\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"with open(\"./e95d45da-aaa3-44b3-ba2b-7c15ff6e46f5.json\", \"r\", encoding=\"utf-8\") as f:\n",
|
||||
" print(json.dumps(json.load(f)[:2], indent=2, sort_keys=True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c6cf01f-466b-406d-b4c7-2395747780fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also clone the dataset to our local tenant"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4dea4df-2f1c-436b-a71c-49ffb2295ccc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Executing this command will clone the dataset to your own LangSmith tenant. \n",
|
||||
"For this to work you must have a [LangSmith account](https://smith.langchain.com/) set up."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b20ba9a6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Get from https://smith.langchain.com/settings\n",
|
||||
"os.environ[\"LANGCHAIN_API_KEY\"] = \"ls_...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18d0b905-2a6a-4752-a7cb-8653bd9049e3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clone_public_dataset(\n",
|
||||
" \"https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/examples\",\n",
|
||||
" dataset_name=\"Agent Dataset\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d9cb90d8-e6a1-4c89-9cde-0e6c0a28f5c0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Model Registry\n",
|
||||
"\n",
|
||||
"Here, we'll see how to access the model registry.\n",
|
||||
"\n",
|
||||
"If you see a model that you want to use and it's missing, please open a PR to add it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "31831289-51fb-4ee5-98f3-0476cf11b187",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_benchmarks import model_registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "aaed190d-fa4b-4445-9bfb-0e784e2a083b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>gpt-3.5-turbo-1106 </td><td>chat </td><td>openai </td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-instruct</td><td>llm </td><td>openai </td><td>Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-0613 </td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k-0613</td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-0301 </td><td>chat </td><td>openai </td><td>Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024. </td></tr>\n",
|
||||
"<tr><td>text-davinci-003 </td><td>llm </td><td>openai </td><td>Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>text-davinci-002 </td><td>llm </td><td>openai </td><td>Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>code-davinci-002 </td><td>llm </td><td>openai </td><td>Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024. </td></tr>\n",
|
||||
"<tr><td>llama-v2-7b-chat </td><td>chat </td><td>fireworks </td><td>7b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-13b-chat </td><td>chat </td><td>fireworks </td><td>13b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-70b-chat </td><td>chat </td><td>fireworks </td><td>70b parameter LlamaChat model </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-instruct', provider='openai', description='Similar capabilities as text-davinci-003 but compatible with legacy Completions endpoint and not Chat Completions.', params={'model': 'gpt-3.5-turbo-instruct'}, type='llm', path=None), RegisteredModel(name='gpt-3.5-turbo-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k-0613', provider='openai', description='Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024.', params={'model': 'gpt-3.5-turbo-16k-0613'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-0301', provider='openai', description='Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024.', params={'model': 'gpt-3.5-turbo-0301'}, type='chat', path=None), RegisteredModel(name='text-davinci-003', provider='openai', description='Legacy Can do language tasks with better quality and consistency than the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-003'}, type='llm', path=None), RegisteredModel(name='text-davinci-002', provider='openai', description='Legacy Similar capabilities to text-davinci-003 but trained with supervised fine-tuning instead of reinforcement learning. Will be deprecated on Jan 4th 2024.', params={'model': 'text-davinci-002'}, type='llm', path=None), RegisteredModel(name='code-davinci-002', provider='openai', description='Legacy Optimized for code-completion tasks. Will be deprecated on Jan 4th 2024.', params={'model': 'code-davinci-002'}, type='llm', path=None), RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "64bfc631-1f1e-4cf4-8636-b8be7b46fef8",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>name </td><td>gpt-3.5-turbo-1106 </td></tr>\n",
|
||||
"<tr><td>type </td><td>chat </td></tr>\n",
|
||||
"<tr><td>provider </td><td>openai </td></tr>\n",
|
||||
"<tr><td>description</td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>model_path </td><td>langchain.chat_models.openai.ChatOpenAI </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"registered_model = model_registry[0]\n",
|
||||
"registered_model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3604d49e-afbe-48ad-ac10-1e538b1ad376",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = registered_model.get_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bdece532-9843-427a-a10b-4545ed4ec151",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Hello there! How can I assist you today?')"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.invoke('hello!')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "db40d4da-dc70-4e6d-b7e8-61de1e15ed2e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>gpt-3.5-turbo-1106</td><td>chat </td><td>openai </td><td>The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.</td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"<tr><td>gpt-3.5-turbo-16k </td><td>chat </td><td>openai </td><td>Currently points to gpt-3.5-turbo-0613. </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='gpt-3.5-turbo-1106', provider='openai', description='The latest GPT-3.5 Turbo model with improved instruction following, JSON mode, reproducible outputs, parallel function calling, and more. Returns a maximum of 4,096 output tokens. Learn more.', params={'model': 'gpt-3.5-turbo-1106'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo'}, type='chat', path=None), RegisteredModel(name='gpt-3.5-turbo-16k', provider='openai', description='Currently points to gpt-3.5-turbo-0613.', params={'model': 'gpt-3.5-turbo-16k'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry[:3]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "9874846a-52f3-4921-b1ed-0858521bb9a9",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Name </th><th>Type </th><th>Provider </th><th>Description </th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>llama-v2-7b-chat </td><td>chat </td><td>fireworks </td><td>7b parameter LlamaChat model </td></tr>\n",
|
||||
"<tr><td>llama-v2-13b-chat</td><td>chat </td><td>fireworks </td><td>13b parameter LlamaChat model</td></tr>\n",
|
||||
"<tr><td>llama-v2-70b-chat</td><td>chat </td><td>fireworks </td><td>70b parameter LlamaChat model</td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"ModelRegistry(registered_models=[RegisteredModel(name='llama-v2-7b-chat', provider='fireworks', description='7b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-7b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-13b-chat', provider='fireworks', description='13b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-13b-chat'}, type='chat', path=None), RegisteredModel(name='llama-v2-70b-chat', provider='fireworks', description='70b parameter LlamaChat model', params={'model': 'accounts/fireworks/models/llama-v2-70b-chat'}, type='chat', path=None)])"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_registry.filter(provider='fireworks')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "eb531591-f46b-4745-ae67-4dfd6217ec5f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"gpt-3.5-turbo-1106\n",
|
||||
"gpt-3.5-turbo\n",
|
||||
"gpt-3.5-turbo-16k\n",
|
||||
"gpt-3.5-turbo-instruct\n",
|
||||
"gpt-3.5-turbo-0613\n",
|
||||
"gpt-3.5-turbo-16k-0613\n",
|
||||
"gpt-3.5-turbo-0301\n",
|
||||
"text-davinci-003\n",
|
||||
"text-davinci-002\n",
|
||||
"code-davinci-002\n",
|
||||
"llama-v2-7b-chat\n",
|
||||
"llama-v2-13b-chat\n",
|
||||
"llama-v2-70b-chat\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for registered_model in model_registry:\n",
|
||||
" print(registered_model.name)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
from langchain_benchmarks.model_registration import model_registry
|
||||
from langchain_benchmarks.registration import registry
|
||||
from langchain_benchmarks.utils._langsmith import (
|
||||
clone_public_dataset,
|
||||
@@ -5,4 +6,10 @@ from langchain_benchmarks.utils._langsmith import (
|
||||
)
|
||||
|
||||
# Please keep this list sorted!
|
||||
__all__ = ["clone_public_dataset", "download_public_dataset", "registry"]
|
||||
__all__ = [
|
||||
"clone_public_dataset",
|
||||
"download_public_dataset",
|
||||
"registry",
|
||||
"model_registry",
|
||||
|
||||
]
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_benchmarks.schema import RegisteredModel, ModelRegistry
|
||||
|
||||
_OpenAIModels = [
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-1106",
|
||||
type="chat",
|
||||
description=(
|
||||
"The latest GPT-3.5 Turbo model with improved instruction following, "
|
||||
"JSON mode, reproducible outputs, parallel function calling, and more. "
|
||||
"Returns a maximum of 4,096 output tokens. Learn more."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k",
|
||||
type="chat",
|
||||
description="Currently points to gpt-3.5-turbo-0613.",
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-instruct",
|
||||
type="llm",
|
||||
description=(
|
||||
"Similar capabilities as text-davinci-003 but compatible with legacy "
|
||||
"Completions endpoint and not Chat Completions."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-16k-0613",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-16k-turbo from June 13th 2023. "
|
||||
"Will be deprecated on June 13, 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-16k-0613",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="gpt-3.5-turbo-0301",
|
||||
type="chat",
|
||||
description=(
|
||||
"Legacy Snapshot of gpt-3.5-turbo from March 1st 2023. "
|
||||
"Will be deprecated on June 13th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "gpt-3.5-turbo-0301",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-003",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Can do language tasks with better quality and consistency than "
|
||||
"the curie, babbage, or ada models. Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-003",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="text-davinci-002",
|
||||
type="llm",
|
||||
description=(
|
||||
"Legacy Similar capabilities to text-davinci-003 but trained with "
|
||||
"supervised fine-tuning instead of reinforcement learning. "
|
||||
"Will be deprecated on Jan 4th 2024."
|
||||
),
|
||||
params={
|
||||
"model": "text-davinci-002",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="openai",
|
||||
name="code-davinci-002",
|
||||
type="llm",
|
||||
description="Legacy Optimized for code-completion tasks. Will be deprecated "
|
||||
"on Jan 4th 2024.",
|
||||
params={
|
||||
"model": "code-davinci-002",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
_FireworksModels = [
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-7b-chat",
|
||||
type="chat",
|
||||
description="7b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-7b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-13b-chat",
|
||||
type="chat",
|
||||
description="13b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-13b-chat",
|
||||
},
|
||||
),
|
||||
RegisteredModel(
|
||||
provider="fireworks",
|
||||
name="llama-v2-70b-chat",
|
||||
type="chat",
|
||||
description="70b parameter LlamaChat model",
|
||||
params={
|
||||
"model": "accounts/fireworks/models/llama-v2-70b-chat",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
model_registry = ModelRegistry(registered_models=_OpenAIModels + _FireworksModels)
|
||||
@@ -2,8 +2,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import urllib
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, Sequence
|
||||
|
||||
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BaseRetriever
|
||||
@@ -153,6 +157,7 @@ class Registry:
|
||||
raise ValueError(
|
||||
f"Duplicate task name {task.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(task.name)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
@@ -210,3 +215,190 @@ class Registry:
|
||||
if not isinstance(task, BaseTask):
|
||||
raise TypeError("Only tasks can be added to the registry.")
|
||||
self.tasks.append(task)
|
||||
|
||||
|
||||
Provider = Literal["fireworks", "openai"]
|
||||
ModelType = Literal["chat", "llm"]
|
||||
AUTHORIZED_NAMESPACES = {"langchain"}
|
||||
|
||||
|
||||
def _get_model_class_from_path(
|
||||
path: str
|
||||
) -> Union[Type[BaseChatModel], Type[BaseLanguageModel]]:
|
||||
"""Get the class of the model."""
|
||||
module_name, attribute_name = path.rsplit(".", 1)
|
||||
top_namespace = path.split(".")[0]
|
||||
|
||||
if top_namespace not in AUTHORIZED_NAMESPACES:
|
||||
raise ValueError(
|
||||
f"Unauthorized namespace {top_namespace}. "
|
||||
f"Authorized namespaces are: {AUTHORIZED_NAMESPACES}"
|
||||
)
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(module_name)
|
||||
model_class = getattr(module, attribute_name)
|
||||
if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)):
|
||||
raise ValueError(
|
||||
f"Model class {model_class} is not a subclass of BaseLanguageModel"
|
||||
)
|
||||
return model_class
|
||||
|
||||
|
||||
def _get_default_path(provider: str, type_: ModelType) -> str:
|
||||
"""Get the default path for a model."""
|
||||
paths = {
|
||||
("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks",
|
||||
("fireworks", "llm"): "langchain.language_models.fireworks.Fireworks",
|
||||
("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI",
|
||||
("openai", "llm"): "langchain.language_models.openai.OpenAI",
|
||||
}
|
||||
|
||||
if (provider, type_) not in paths:
|
||||
raise ValueError(f"Unknown provider {provider} and type {type_}")
|
||||
|
||||
return paths[(provider, type_)]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RegisteredModel:
|
||||
"""Descriptive information about a model.
|
||||
|
||||
This information can be used to instantiate the underlying model.
|
||||
"""
|
||||
|
||||
name: str
|
||||
provider: Provider
|
||||
description: str
|
||||
params: Dict[str, Any]
|
||||
type: ModelType
|
||||
# Path to the model class.
|
||||
# For example, "langchain.chat_models.anthropic import ChatAnthropicModel"
|
||||
path: Optional[str] = None # If not provided, will use default path
|
||||
|
||||
def get_model(
|
||||
self, *, model_params: Optional[Dict[str, Any]] = None
|
||||
) -> Union[BaseChatModel, BaseLanguageModel]:
|
||||
"""Get the class of the model."""
|
||||
all_params = {**self.params, **(model_params or {})}
|
||||
model_class = _get_model_class_from_path(self.model_path)
|
||||
return model_class(**all_params)
|
||||
|
||||
@property
|
||||
def model_path(self) -> str:
|
||||
"""Get the path of the model."""
|
||||
return self.path or _get_default_path(self.provider, self.type)
|
||||
|
||||
@property
|
||||
def _table(self) -> List[List[str]]:
|
||||
"""Return a table representation of the environment."""
|
||||
return [
|
||||
["name", self.name],
|
||||
["type", self.type],
|
||||
["provider", self.provider],
|
||||
["description", self.description],
|
||||
["model_path", self.model_path],
|
||||
]
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the environment."""
|
||||
return tabulate(
|
||||
self._table,
|
||||
tablefmt="unsafehtml",
|
||||
)
|
||||
|
||||
|
||||
StrFilter = Union[None, str, Sequence[str]]
|
||||
|
||||
|
||||
def _is_in_filter(actual_value: str, filter_value: StrFilter) -> bool:
|
||||
"""Filter for a string attribute."""
|
||||
if filter_value is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter_value, str):
|
||||
return actual_value == filter_value
|
||||
|
||||
return actual_value in filter_value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=False)
|
||||
class ModelRegistry:
|
||||
registered_models: Sequence[RegisteredModel]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that all the tasks have unique names and IDs."""
|
||||
seen_names = set()
|
||||
for model in self.registered_models:
|
||||
if model.name in seen_names:
|
||||
raise ValueError(
|
||||
f"Duplicate model name {model.name}. " f"Task names must be unique."
|
||||
)
|
||||
seen_names.add(model.name)
|
||||
|
||||
def get_model(self, name: str) -> Optional[RegisteredModel]:
|
||||
"""Get model info."""
|
||||
return next(model for model in self.registered_models if model.name == name)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
*,
|
||||
type: StrFilter = None,
|
||||
name: StrFilter = None,
|
||||
provider: StrFilter = None,
|
||||
) -> ModelRegistry:
|
||||
"""Filter the tasks in the registry."""
|
||||
models = self.registered_models
|
||||
selected_models = []
|
||||
|
||||
for model in models:
|
||||
if not _is_in_filter(model.type, type):
|
||||
continue
|
||||
if not _is_in_filter(model.name, name):
|
||||
continue
|
||||
if not _is_in_filter(model.provider, provider):
|
||||
continue
|
||||
selected_models.append(model)
|
||||
return ModelRegistry(registered_models=selected_models)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the registry."""
|
||||
headers = [
|
||||
"Name",
|
||||
"Type",
|
||||
"Provider",
|
||||
"Description",
|
||||
]
|
||||
table = [
|
||||
[
|
||||
model.name,
|
||||
model.type,
|
||||
model.provider,
|
||||
model.description,
|
||||
]
|
||||
for model in self.registered_models
|
||||
]
|
||||
return tabulate(table, headers=headers, tablefmt="unsafehtml")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of tasks in the registry."""
|
||||
return len(self.registered_models)
|
||||
|
||||
def __iter__(self) -> Iterable[RegisteredModel]:
|
||||
"""Iterate over the tasks in the registry."""
|
||||
return iter(self.registered_models)
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[int, str]
|
||||
) -> Union[RegisteredModel, ModelRegistry]:
|
||||
"""Get an environment from the registry."""
|
||||
if isinstance(key, slice):
|
||||
return ModelRegistry(registered_models=self.registered_models[key])
|
||||
elif isinstance(key, (int, str)):
|
||||
# If key is an integer, return the corresponding environment
|
||||
if isinstance(key, str):
|
||||
return self.get_model(key)
|
||||
else:
|
||||
return self.registered_models[key]
|
||||
else:
|
||||
raise TypeError("Key must be an integer or a slice.")
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
|
||||
from langchain_benchmarks.schema import RegisteredModel, ModelRegistry
|
||||
|
||||
# Create some sample RegisteredModel instances for testing
|
||||
SAMPLE_MODELS = [
|
||||
RegisteredModel(
|
||||
"model1", "fireworks", "Description 1", {"param1": "value1"}, "chat"
|
||||
),
|
||||
RegisteredModel("model2", "openai", "Description 2", {"param2": "value2"}, "llm"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_registry() -> ModelRegistry:
|
||||
return ModelRegistry(SAMPLE_MODELS)
|
||||
|
||||
|
||||
def test_init() -> None:
|
||||
# Test the constructor of ModelRegistry
|
||||
registry = ModelRegistry(SAMPLE_MODELS)
|
||||
assert len(registry.registered_models) == 2
|
||||
|
||||
|
||||
def test_get_model(sample_registry: ModelRegistry) -> None:
|
||||
# Test the get_model method
|
||||
model = sample_registry.get_model("model1")
|
||||
assert model.name == "model1"
|
||||
|
||||
|
||||
def test_filter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the filter method
|
||||
filtered_registry = sample_registry.filter(type="chat")
|
||||
assert len(filtered_registry.registered_models) == 1
|
||||
assert filtered_registry.registered_models[0].type == "chat"
|
||||
|
||||
|
||||
def test_repr_html(sample_registry: ModelRegistry) -> None:
|
||||
# Test the _repr_html_ method
|
||||
html_representation = sample_registry._repr_html_()
|
||||
assert "<table>" in html_representation
|
||||
|
||||
|
||||
def test_len(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __len__ method
|
||||
assert len(sample_registry) == 2
|
||||
|
||||
|
||||
def test_iter(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __iter__ method
|
||||
models = list(iter(sample_registry))
|
||||
assert len(models) == 2
|
||||
assert isinstance(models[0], RegisteredModel)
|
||||
|
||||
|
||||
def test_getitem(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for integer and string keys
|
||||
model = sample_registry[0]
|
||||
assert model.name == "model1"
|
||||
model = sample_registry["model2"]
|
||||
assert model.name == "model2"
|
||||
|
||||
|
||||
def test_getitem_slice(sample_registry: ModelRegistry) -> None:
|
||||
# Test the __getitem__ method for slices
|
||||
sliced_registry = sample_registry[:1]
|
||||
assert len(sliced_registry.registered_models) == 1
|
||||
assert sliced_registry.registered_models[0].name == "model1"
|
||||
Reference in New Issue
Block a user