incorporate tool calling (#131)

This commit is contained in:
ccurme
2024-04-27 18:10:01 -04:00
committed by GitHub
parent c72c87f278
commit cc01bbd5cf
6 changed files with 1202 additions and 826 deletions
+1149 -802
View File
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -14,7 +14,7 @@ fastapi = "^0.109.2"
langserve = "^0.0.45"
uvicorn = "^0.27.1"
pydantic = "^1.10"
langchain-openai = "^0.0.8"
langchain-openai = "^0.1.3"
jsonschema = "^4.21.1"
sse-starlette = "^2.0.0"
alembic = "^1.13.1"
@@ -26,6 +26,8 @@ lxml = "^5.1.0"
faiss-cpu = "^1.7.4"
python-multipart = "^0.0.9"
langchain-fireworks = "^0.1.1"
langchain-anthropic = "^0.1.11"
langchain-groq = "^0.1.3"
[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
+12 -13
View File
@@ -1,12 +1,13 @@
from __future__ import annotations
import json
import uuid
from typing import Any, Dict, List, Optional, Sequence
from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import chain
from langserve import CustomUserType
@@ -97,19 +98,18 @@ def _make_prompt_template(
# TODO: We'll need to refactor this at some point to
# support other encoding strategies. The function calling logic here
# has some hard-coded assumptions (e.g., name of parameters like `data`).
function_call = {
"arguments": json.dumps(
{
"data": example.output,
}
),
_id = uuid.uuid4().hex[:]
tool_call = {
"args": {"data": example.output},
"name": function_name,
"id": _id,
}
few_shot_prompt.extend(
[
HumanMessage(content=example.text),
AIMessage(
content="", additional_kwargs={"function_call": function_call}
AIMessage(content="", tool_calls=[tool_call]),
ToolMessage(
content="You have correctly called this tool.", tool_call_id=_id
),
]
)
@@ -172,10 +172,9 @@ async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResp
schema["title"],
)
model = get_model(extraction_request.model_name)
# N.B. method must be consistent with examples in _make_prompt_template
runnable = (
prompt | model.with_structured_output(schema=schema, method="function_calling")
).with_config({"run_name": "extraction"})
runnable = (prompt | model.with_structured_output(schema=schema)).with_config(
{"run_name": "extraction"}
)
return await runnable.ainvoke({"text": extraction_request.text})
+17
View File
@@ -1,8 +1,10 @@
import os
from typing import Optional
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
@@ -37,6 +39,21 @@ def get_supported_models():
),
"description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
}
if "ANTHROPIC_API_KEY" in os.environ:
models["claude-3-sonnet-20240229"] = {
"chat_model": ChatAnthropic(
model="claude-3-sonnet-20240229", temperature=0
),
"description": "Claude 3 Sonnet",
}
if "GROQ_API_KEY" in os.environ:
models["groq-llama3-8b-8192"] = {
"chat_model": ChatGroq(
model="llama3-8b-8192",
temperature=0,
),
"description": "GROQ Llama 3 8B",
}
return models
@@ -6,16 +6,24 @@ from langchain_core.messages import AIMessage
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
class AnyStr(str):
    """Test matcher that compares equal to any ``str`` instance.

    Intended for assertions where a field (for example a generated message
    ``id``) holds a string whose exact value is not known in advance.
    """

    def __init__(self) -> None:
        # Kept argument-free on purpose: the matcher carries no value of its
        # own, so it is always constructed as ``AnyStr()``.
        super().__init__()

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is any string, regardless of content."""
        return isinstance(other, str)

    def __ne__(self, other: object) -> bool:
        """Return the exact negation of ``__eq__``.

        ``str`` supplies its own ``__ne__`` that compares content, so without
        this override ``AnyStr() != "x"`` would be True even though
        ``AnyStr() == "x"`` is also True.
        """
        return not isinstance(other, str)

    # Overriding ``__eq__`` already makes the class unhashable; state it
    # explicitly, since no hash could agree with "equal to every string".
    __hash__ = None  # type: ignore[assignment]
def test_generic_fake_chat_model_invoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_ainvoke() -> None:
@@ -23,8 +31,8 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
+7 -4
View File
@@ -1,4 +1,5 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.messages import AIMessage
from extraction.utils import update_json_schema
from server.extraction_runnable import ExtractionExample, _make_prompt_template
@@ -82,7 +83,7 @@ def test_make_prompt_template() -> None:
)
prompt = _make_prompt_template(instructions, examples, "name")
messages = prompt.messages
assert 4 == len(messages)
assert 5 == len(messages)
system = messages[0].prompt.template
assert system.startswith(prefix)
assert system.endswith(instructions)
@@ -90,11 +91,13 @@ def test_make_prompt_template() -> None:
example_input = messages[1]
assert example_input.content == "Test text."
example_output = messages[2]
assert "function_call" in example_output.additional_kwargs
assert example_output.additional_kwargs["function_call"]["name"] == "name"
assert isinstance(example_output, AIMessage)
assert example_output.tool_calls
assert len(example_output.tool_calls) == 1
assert example_output.tool_calls[0]["name"] == "name"
prompt = _make_prompt_template(instructions, None, "name")
assert 2 == len(prompt.messages)
prompt = _make_prompt_template(None, examples, "name")
assert 4 == len(prompt.messages)
assert 5 == len(prompt.messages)