mirror of
https://github.com/langchain-ai/langchain-extract.git
synced 2026-07-01 20:24:03 -04:00
incorporate tool calling (#131)
This commit is contained in:
Generated
+1149
-802
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,7 @@ fastapi = "^0.109.2"
|
||||
langserve = "^0.0.45"
|
||||
uvicorn = "^0.27.1"
|
||||
pydantic = "^1.10"
|
||||
langchain-openai = "^0.0.8"
|
||||
langchain-openai = "^0.1.3"
|
||||
jsonschema = "^4.21.1"
|
||||
sse-starlette = "^2.0.0"
|
||||
alembic = "^1.13.1"
|
||||
@@ -26,6 +26,8 @@ lxml = "^5.1.0"
|
||||
faiss-cpu = "^1.7.4"
|
||||
python-multipart = "^0.0.9"
|
||||
langchain-fireworks = "^0.1.1"
|
||||
langchain-anthropic = "^0.1.11"
|
||||
langchain-groq = "^0.1.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyterlab = "^3.6.1"
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from fastapi import HTTPException
|
||||
from jsonschema import Draft202012Validator, exceptions
|
||||
from langchain.text_splitter import TokenTextSplitter
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import chain
|
||||
from langserve import CustomUserType
|
||||
@@ -97,19 +98,18 @@ def _make_prompt_template(
|
||||
# TODO: We'll need to refactor this at some point to
|
||||
# support other encoding strategies. The function calling logic here
|
||||
# has some hard-coded assumptions (e.g., name of parameters like `data`).
|
||||
function_call = {
|
||||
"arguments": json.dumps(
|
||||
{
|
||||
"data": example.output,
|
||||
}
|
||||
),
|
||||
_id = uuid.uuid4().hex[:]
|
||||
tool_call = {
|
||||
"args": {"data": example.output},
|
||||
"name": function_name,
|
||||
"id": _id,
|
||||
}
|
||||
few_shot_prompt.extend(
|
||||
[
|
||||
HumanMessage(content=example.text),
|
||||
AIMessage(
|
||||
content="", additional_kwargs={"function_call": function_call}
|
||||
AIMessage(content="", tool_calls=[tool_call]),
|
||||
ToolMessage(
|
||||
content="You have correctly called this tool.", tool_call_id=_id
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -172,10 +172,9 @@ async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResp
|
||||
schema["title"],
|
||||
)
|
||||
model = get_model(extraction_request.model_name)
|
||||
# N.B. method must be consistent with examples in _make_prompt_template
|
||||
runnable = (
|
||||
prompt | model.with_structured_output(schema=schema, method="function_calling")
|
||||
).with_config({"run_name": "extraction"})
|
||||
runnable = (prompt | model.with_structured_output(schema=schema)).with_config(
|
||||
{"run_name": "extraction"}
|
||||
)
|
||||
|
||||
return await runnable.ainvoke({"text": extraction_request.text})
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_fireworks import ChatFireworks
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
@@ -37,6 +39,21 @@ def get_supported_models():
|
||||
),
|
||||
"description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
|
||||
}
|
||||
if "ANTHROPIC_API_KEY" in os.environ:
|
||||
models["claude-3-sonnet-20240229"] = {
|
||||
"chat_model": ChatAnthropic(
|
||||
model="claude-3-sonnet-20240229", temperature=0
|
||||
),
|
||||
"description": "Claude 3 Sonnet",
|
||||
}
|
||||
if "GROQ_API_KEY" in os.environ:
|
||||
models["groq-llama3-8b-8192"] = {
|
||||
"chat_model": ChatGroq(
|
||||
model="llama3-8b-8192",
|
||||
temperature=0,
|
||||
),
|
||||
"description": "GROQ Llama 3 8B",
|
||||
}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@@ -6,16 +6,24 @@ from langchain_core.messages import AIMessage
|
||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
||||
|
||||
|
||||
class AnyStr(str):
    """Wildcard string for test assertions: equal to every ``str`` instance.

    Put an instance where the exact string value is not known in advance
    (for example an auto-generated message ``id``) so that an equality
    assertion on the enclosing object still passes.

    Comparison against a non-string value is always unequal. Instances are
    unhashable, because ``__eq__`` is overridden without ``__hash__``.
    """

    def __init__(self) -> None:
        # Takes no arguments: the wrapped text is irrelevant, only the
        # comparison behaviour matters.
        super().__init__()

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is any ``str`` (including subclasses)."""
        return isinstance(other, str)

    def __ne__(self, other: object) -> bool:
        """Return the exact negation of ``__eq__``.

        ``str`` provides its own ``__ne__`` that compares text content, so
        without this override ``AnyStr() != "x"`` would be True even though
        ``AnyStr() == "x"`` is also True.
        """
        return not isinstance(other, str)
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
@@ -23,8 +31,8 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
assert response == AIMessage(content="goodbye", id=AnyStr())
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
assert response == AIMessage(content="hello", id=AnyStr())
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from extraction.utils import update_json_schema
|
||||
from server.extraction_runnable import ExtractionExample, _make_prompt_template
|
||||
@@ -82,7 +83,7 @@ def test_make_prompt_template() -> None:
|
||||
)
|
||||
prompt = _make_prompt_template(instructions, examples, "name")
|
||||
messages = prompt.messages
|
||||
assert 4 == len(messages)
|
||||
assert 5 == len(messages)
|
||||
system = messages[0].prompt.template
|
||||
assert system.startswith(prefix)
|
||||
assert system.endswith(instructions)
|
||||
@@ -90,11 +91,13 @@ def test_make_prompt_template() -> None:
|
||||
example_input = messages[1]
|
||||
assert example_input.content == "Test text."
|
||||
example_output = messages[2]
|
||||
assert "function_call" in example_output.additional_kwargs
|
||||
assert example_output.additional_kwargs["function_call"]["name"] == "name"
|
||||
assert isinstance(example_output, AIMessage)
|
||||
assert example_output.tool_calls
|
||||
assert len(example_output.tool_calls) == 1
|
||||
assert example_output.tool_calls[0]["name"] == "name"
|
||||
|
||||
prompt = _make_prompt_template(instructions, None, "name")
|
||||
assert 2 == len(prompt.messages)
|
||||
|
||||
prompt = _make_prompt_template(None, examples, "name")
|
||||
assert 4 == len(prompt.messages)
|
||||
assert 5 == len(prompt.messages)
|
||||
|
||||
Reference in New Issue
Block a user