Nc/llmchain functions (#1699)

* Refactor openai functions chains to use output parsers

* Lint

* Rename

* Add docs
This commit is contained in:
Nuno Campos
2023-06-19 19:11:58 +01:00
committed by GitHub
parent 46f8147fa6
commit 29015fbc2c
25 changed files with 298 additions and 152 deletions
@@ -0,0 +1,24 @@
---
hide_table_of_contents: true
sidebar_position: 4
---
import CodeBlock from "@theme/CodeBlock";
import Extraction from "@examples/chains/openai_functions_extraction.ts";
import Tagging from "@examples/chains/openai_functions_tagging.ts";
# OpenAI Functions Chains
These chains are designed to be used with the [OpenAI Functions](https://platform.openai.com/docs/guides/gpt/function-calling) API.
## Extraction
This chain is designed to extract lists of objects from an input text and schema of desired info.
<CodeBlock language="typescript">{Extraction}</CodeBlock>
## Tagging
This chain is designed to tag an input text according to properties defined in a schema.
<CodeBlock language="typescript">{Tagging}</CodeBlock>
@@ -0,0 +1,37 @@
import { z } from "zod";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { createExtractionChainFromZod } from "langchain/chains";
const chain = createExtractionChainFromZod(
z.object({
"person-name": z.string().optional(),
"person-age": z.number().optional(),
"person-hair_color": z.string().optional(),
"dog-name": z.string().optional(),
"dog-breed": z.string().optional(),
}),
new ChatOpenAI({ modelName: "gpt-3.5-turbo-0613", temperature: 0 })
);
console.log(
await chain.run(`Alex is 5 feet tall. Claudia is 4 feet taller Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.
Alex's dog Frosty is a labrador and likes to play hide and seek.`)
);
/*
[
{
'person-name': 'Alex',
'person-age': 0,
'person-hair_color': 'blonde',
'dog-name': 'Frosty',
'dog-breed': 'labrador'
},
{
'person-name': 'Claudia',
'person-age': 0,
'person-hair_color': 'brunette',
'dog-name': '',
'dog-breed': ''
}
]
*/
@@ -0,0 +1,24 @@
import { createTaggingChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
const chain = createTaggingChain(
{
type: "object",
properties: {
sentiment: { type: "string" },
tone: { type: "string" },
language: { type: "string" },
},
required: ["tone"],
},
new ChatOpenAI({ modelName: "gpt-4-0613", temperature: 0 })
);
console.log(
await chain.run(
`Estoy increiblemente contento de haberte conocido! Creo que seremos muy buenos amigos!`
)
);
/*
{ tone: 'positive', language: 'Spanish' }
*/
+7 -5
View File
@@ -1,3 +1,4 @@
import { awaitAllCallbacks } from "langchain/callbacks";
import path from "path";
import url from "url";
@@ -44,10 +45,11 @@ if (runExample) {
const maybePromise = runExample(args);
if (maybePromise instanceof Promise) {
maybePromise.catch((e) => {
console.error(`Example failed with:`);
console.error(e);
process.exit(1);
});
maybePromise
.catch((e) => {
console.error(`Example failed with:`);
console.error(e);
})
.finally(() => awaitAllCallbacks());
}
}
+3
View File
@@ -394,6 +394,9 @@ experimental/plan_and_execute.d.ts
client.cjs
client.js
client.d.ts
evaluation.cjs
evaluation.js
evaluation.d.ts
index.cjs
index.js
index.d.ts
+8
View File
@@ -406,6 +406,9 @@
"client.cjs",
"client.js",
"client.d.ts",
"evaluation.cjs",
"evaluation.js",
"evaluation.d.ts",
"index.cjs",
"index.js",
"index.d.ts"
@@ -1426,6 +1429,11 @@
"import": "./client.js",
"require": "./client.cjs"
},
"./evaluation": {
"types": "./evaluation.d.ts",
"import": "./evaluation.js",
"require": "./evaluation.cjs"
},
"./package.json": "./package.json"
}
}
-6
View File
@@ -84,12 +84,6 @@ export {
export { MultiPromptChain } from "./router/multi_prompt.js";
export { MultiRetrievalQAChain } from "./router/multi_retrieval_qa.js";
export { TransformChain, TransformChainFields } from "./transform.js";
export {
OpenAIFunctionsChain,
OpenAIFunctionsChainFields,
parseToArguments,
parseToNamedArgument,
} from "./openai_functions/index.js";
export {
createExtractionChain,
createExtractionChainFromZod,
+32 -16
View File
@@ -2,19 +2,27 @@ import { BaseChain, ChainInputs } from "./base.js";
import { BasePromptTemplate } from "../prompts/base.js";
import { BaseLanguageModel } from "../base_language/index.js";
import { ChainValues, Generation, BasePromptValue } from "../schema/index.js";
import { BaseOutputParser } from "../schema/output_parser.js";
import {
BaseLLMOutputParser,
BaseOutputParser,
} from "../schema/output_parser.js";
import { SerializedLLMChain } from "./serde.js";
import { CallbackManager } from "../callbacks/index.js";
import { CallbackManagerForChainRun, Callbacks } from "../callbacks/manager.js";
import { NoOpOutputParser } from "../output_parsers/noop.js";
export interface LLMChainInput<T extends string | object = string>
extends ChainInputs {
export interface LLMChainInput<
T extends string | object = string,
L extends BaseLanguageModel = BaseLanguageModel
> extends ChainInputs {
/** Prompt object to use */
prompt: BasePromptTemplate;
/** LLM Wrapper to use */
llm: BaseLanguageModel;
llm: L;
/** Kwargs to pass to LLM */
llmKwargs?: this["llm"]["CallOptions"];
/** OutputParser to use */
outputParser?: BaseOutputParser<T>;
outputParser?: BaseLLMOutputParser<T>;
/** Key to use for output, defaults to `text` */
outputKey?: string;
}
@@ -32,7 +40,10 @@ export interface LLMChainInput<T extends string | object = string>
* const llm = new LLMChain({ llm: new OpenAI(), prompt });
* ```
*/
export class LLMChain<T extends string | object = string>
export class LLMChain<
T extends string | object = string,
L extends BaseLanguageModel = BaseLanguageModel
>
extends BaseChain
implements LLMChainInput<T>
{
@@ -40,11 +51,13 @@ export class LLMChain<T extends string | object = string>
prompt: BasePromptTemplate;
llm: BaseLanguageModel;
llm: L;
llmKwargs?: this["llm"]["CallOptions"];
outputKey = "text";
outputParser?: BaseOutputParser<T>;
outputParser?: BaseLLMOutputParser<T>;
get inputKeys() {
return this.prompt.inputVariables;
@@ -54,14 +67,16 @@ export class LLMChain<T extends string | object = string>
return [this.outputKey];
}
constructor(fields: LLMChainInput<T>) {
constructor(fields: LLMChainInput<T, L>) {
super(fields);
this.prompt = fields.prompt;
this.llm = fields.llm;
this.llmKwargs = fields.llmKwargs;
this.outputKey = fields.outputKey ?? this.outputKey;
this.outputParser = fields.outputParser ?? this.outputParser;
this.outputParser =
fields.outputParser ?? (new NoOpOutputParser() as BaseOutputParser<T>);
if (this.prompt.outputParser) {
if (this.outputParser) {
if (fields.outputParser) {
throw new Error("Cannot set both outputParser and prompt.outputParser");
}
this.outputParser = this.prompt.outputParser as BaseOutputParser<T>;
@@ -85,16 +100,15 @@ export class LLMChain<T extends string | object = string>
promptValue: BasePromptValue,
runManager?: CallbackManagerForChainRun
): Promise<unknown> {
const completion = generations[0].text;
let finalCompletion: unknown;
if (this.outputParser) {
finalCompletion = await this.outputParser.parseWithPrompt(
completion,
finalCompletion = await this.outputParser.parseResultWithPrompt(
generations,
promptValue,
runManager?.getChild()
);
} else {
finalCompletion = completion;
finalCompletion = generations[0].text;
}
return finalCompletion;
}
@@ -117,7 +131,9 @@ export class LLMChain<T extends string | object = string>
runManager?: CallbackManagerForChainRun
): Promise<ChainValues> {
const valuesForPrompt = { ...values };
const valuesForLLM: this["llm"]["CallOptions"] = {};
const valuesForLLM: this["llm"]["CallOptions"] = {
...this.llmKwargs,
};
for (const key of this.llm.callKeys) {
if (key in values) {
valuesForLLM[key as keyof this["llm"]["CallOptions"]] = values[key];
@@ -4,13 +4,11 @@ import { JsonSchema7ObjectType } from "zod-to-json-schema/src/parsers/object.js"
import { ChatOpenAI } from "../../chat_models/openai.js";
import { PromptTemplate } from "../../prompts/prompt.js";
import { TransformChain } from "../transform.js";
import { SimpleSequentialChain } from "../sequential_chain.js";
import {
FunctionParameters,
OpenAIFunctionsChain,
parseToNamedArgument,
} from "./index.js";
JsonKeyOutputFunctionsParser,
} from "../../output_parsers/openai_functions.js";
import { LLMChain } from "../llm_chain.js";
function getExtractionFunctions(schema: FunctionParameters) {
return [
@@ -47,13 +45,14 @@ export function createExtractionChain(
) {
const functions = getExtractionFunctions(schema);
const prompt = PromptTemplate.fromTemplate(_EXTRACTION_TEMPLATE);
const chain = new OpenAIFunctionsChain({ llm, prompt, functions });
const parsing_chain = new TransformChain({
transform: parseToNamedArgument.bind(null, "info"),
inputVariables: ["input"],
outputVariables: ["output"],
const outputParser = new JsonKeyOutputFunctionsParser({ attrName: "info" });
return new LLMChain({
llm,
prompt,
llmKwargs: { functions },
outputParser,
tags: ["openai_functions", "extraction"],
});
return new SimpleSequentialChain({ chains: [chain, parsing_chain] });
}
export function createExtractionChainFromZod(
@@ -1,98 +0,0 @@
import {
ChatCompletionFunctions,
ChatCompletionRequestMessageFunctionCall,
} from "openai";
import { JsonSchema7ObjectType } from "zod-to-json-schema/src/parsers/object.js";
import { BaseChain, ChainInputs } from "../base.js";
import { BasePromptTemplate } from "../../prompts/base.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { CallbackManagerForChainRun } from "../../callbacks/manager.js";
import { AIChatMessage, ChainValues } from "../../schema/index.js";
import { Optional } from "../../types/type-utils.js";
export type FunctionParameters = Optional<
JsonSchema7ObjectType,
"additionalProperties"
>;
export interface OpenAIFunctionsChainFields extends ChainInputs {
llm: ChatOpenAI;
prompt: BasePromptTemplate;
functions: ChatCompletionFunctions[];
outputKey?: string;
}
export class OpenAIFunctionsChain
extends BaseChain
implements OpenAIFunctionsChainFields
{
llm: ChatOpenAI;
prompt: BasePromptTemplate;
functions: ChatCompletionFunctions[];
outputKey = "output";
_chainType() {
return "openai_functions" as const;
}
get inputKeys() {
return this.prompt.inputVariables;
}
get outputKeys() {
return [this.outputKey];
}
constructor(fields: OpenAIFunctionsChainFields) {
super(fields);
this.llm = fields.llm;
this.prompt = fields.prompt;
this.functions = fields.functions;
this.outputKey = fields.outputKey ?? this.outputKey;
}
async _call(
values: ChainValues,
runManager?: CallbackManagerForChainRun
): Promise<ChainValues> {
const valuesForPrompt = { ...values };
const valuesForLLM: this["llm"]["CallOptions"] = {
functions: this.functions,
};
for (const key of this.llm.callKeys) {
if (key in values) {
valuesForLLM[key as keyof this["llm"]["CallOptions"]] = values[key];
delete valuesForPrompt[key];
}
}
const promptValue = await this.prompt.formatPromptValue(valuesForPrompt);
const message = await this.llm.predictMessages(
promptValue.toChatMessages(),
valuesForLLM,
runManager?.getChild()
);
return { output: message };
}
}
export function parseToArguments({ input }: { input: AIChatMessage }) {
const function_call = input?.additional_kwargs
?.function_call as ChatCompletionRequestMessageFunctionCall;
return {
output: function_call?.arguments
? JSON.parse(function_call?.arguments)
: undefined,
};
}
export function parseToNamedArgument(
key: string,
inputs: { input: AIChatMessage }
) {
const { output } = parseToArguments(inputs);
return { output: output?.[key] };
}
@@ -4,13 +4,11 @@ import { JsonSchema7ObjectType } from "zod-to-json-schema/src/parsers/object.js"
import { ChatOpenAI } from "../../chat_models/openai.js";
import { PromptTemplate } from "../../prompts/prompt.js";
import { TransformChain } from "../transform.js";
import { SimpleSequentialChain } from "../sequential_chain.js";
import {
FunctionParameters,
OpenAIFunctionsChain,
parseToArguments,
} from "./index.js";
JsonOutputFunctionsParser,
} from "../../output_parsers/openai_functions.js";
import { LLMChain } from "../llm_chain.js";
function getTaggingFunctions(schema: FunctionParameters) {
return [
@@ -34,13 +32,14 @@ export function createTaggingChain(
) {
const functions = getTaggingFunctions(schema);
const prompt = PromptTemplate.fromTemplate(TAGGING_TEMPLATE);
const chain = new OpenAIFunctionsChain({ llm, prompt, functions });
const parsing_chain = new TransformChain({
transform: parseToArguments,
inputVariables: ["input"],
outputVariables: ["output"],
const outputParser = new JsonOutputFunctionsParser();
return new LLMChain({
llm,
prompt,
llmKwargs: { functions },
outputParser,
tags: ["openai_functions", "tagging"],
});
return new SimpleSequentialChain({ chains: [chain, parsing_chain] });
}
export function createTaggingChainFromZod(
+3 -1
View File
@@ -294,7 +294,9 @@ test.skip("serialize + deserialize llmchain with output parser", async () => {
});
expect(chain2).toBeInstanceOf(LLMChain);
expect(JSON.stringify(chain2, null, 2)).toBe(str);
expect(await chain2.outputParser?.parse("a, b, c")).toEqual(["a", "b", "c"]);
expect(await chain2.outputParser?.parseResult([{ text: "a, b, c" }])).toEqual(
["a", "b", "c"]
);
});
test("serialize + deserialize llmchain with struct output parser throws", async () => {
+5
View File
@@ -11,3 +11,8 @@ export { OutputFixingParser } from "./fix.js";
export { CombiningOutputParser } from "./combining.js";
export { RouterOutputParser, RouterOutputParserInput } from "./router.js";
export { CustomListOutputParser } from "./list.js";
export {
OutputFunctionsParser,
JsonOutputFunctionsParser,
JsonKeyOutputFunctionsParser,
} from "../output_parsers/openai_functions.js";
+15
View File
@@ -0,0 +1,15 @@
import { BaseOutputParser } from "../schema/output_parser.js";
export class NoOpOutputParser extends BaseOutputParser<string> {
lc_namespace = ["langchain", "output_parsers", "default"];
lc_serializable = true;
parse(text: string): Promise<string> {
return Promise.resolve(text);
}
getFormatInstructions(): string {
return "";
}
}
@@ -0,0 +1,80 @@
import { JsonSchema7ObjectType } from "zod-to-json-schema/src/parsers/object.js";
import { ChatGeneration, Generation } from "../schema/index.js";
import { Optional } from "../types/type-utils.js";
import { BaseLLMOutputParser } from "../schema/output_parser.js";
export type FunctionParameters = Optional<
JsonSchema7ObjectType,
"additionalProperties"
>;
export class OutputFunctionsParser extends BaseLLMOutputParser<string> {
lc_namespace = ["langchain", "chains", "openai_functions"];
lc_serializable = true;
async parseResult(
generations: Generation[] | ChatGeneration[]
): Promise<string> {
if ("message" in generations[0]) {
const gen = generations[0] as ChatGeneration;
if (!gen.message.additional_kwargs.function_call) {
throw new Error(
`No function_call in message ${JSON.stringify(generations)}`
);
}
if (!gen.message.additional_kwargs.function_call.arguments) {
throw new Error(
`No arguments in function_call ${JSON.stringify(generations)}`
);
}
return gen.message.additional_kwargs.function_call.arguments;
} else {
throw new Error(
`No message in generations ${JSON.stringify(generations)}`
);
}
}
}
export class JsonOutputFunctionsParser extends BaseLLMOutputParser<object> {
lc_namespace = ["langchain", "chains", "openai_functions"];
lc_serializable = true;
outputParser = new OutputFunctionsParser();
async parseResult(
generations: Generation[] | ChatGeneration[]
): Promise<object> {
const result = await this.outputParser.parseResult(generations);
if (!result) {
throw new Error(
`No result from "OutputFunctionsParser" ${JSON.stringify(generations)}`
);
}
return JSON.parse(result);
}
}
export class JsonKeyOutputFunctionsParser<
T = object
> extends BaseLLMOutputParser<T> {
lc_namespace = ["langchain", "chains", "openai_functions"];
lc_serializable = true;
outputParser = new JsonOutputFunctionsParser();
attrName: string;
constructor(fields: { attrName: string }) {
super(fields);
this.attrName = fields.attrName;
}
async parseResult(generations: Generation[] | ChatGeneration[]): Promise<T> {
const result = await this.outputParser.parseResult(generations);
return result[this.attrName as keyof typeof result] as T;
}
}
+5 -1
View File
@@ -1,3 +1,4 @@
import { ChatCompletionRequestMessageFunctionCall } from "openai";
import { Document } from "../document.js";
import { Serializable } from "../load/serializable.js";
@@ -72,7 +73,10 @@ export abstract class BaseChatMessage {
name?: string;
/** Additional keyword arguments */
additional_kwargs: Record<string, unknown> = {};
additional_kwargs: {
function_call?: ChatCompletionRequestMessageFunctionCall;
[key: string]: unknown;
} = {};
/** The type of the message. */
abstract _getType(): MessageType;
+26 -2
View File
@@ -1,5 +1,5 @@
import { Callbacks } from "../callbacks/manager.js";
import { BasePromptValue } from "./index.js";
import { BasePromptValue, Generation, ChatGeneration } from "./index.js";
import { Serializable } from "../load/serializable.js";
/**
@@ -7,9 +7,33 @@ import { Serializable } from "../load/serializable.js";
*/
export interface FormatInstructionsOptions {}
export abstract class BaseLLMOutputParser<T = unknown> extends Serializable {
abstract parseResult(
generations: Generation[] | ChatGeneration[],
callbacks?: Callbacks
): Promise<T>;
parseResultWithPrompt(
generations: Generation[] | ChatGeneration[],
_prompt: BasePromptValue,
callbacks?: Callbacks
): Promise<T> {
return this.parseResult(generations, callbacks);
}
}
/** Class to parse the output of an LLM call.
*/
export abstract class BaseOutputParser<T = unknown> extends Serializable {
export abstract class BaseOutputParser<
T = unknown
> extends BaseLLMOutputParser<T> {
parseResult(
generations: Generation[] | ChatGeneration[],
callbacks?: Callbacks
): Promise<T> {
return this.parse(generations[0].text, callbacks);
}
/**
* Parse the output of an LLM call.
*
+2 -1
View File
@@ -157,7 +157,8 @@
"src/experimental/babyagi/index.ts",
"src/experimental/generative_agents/index.ts",
"src/experimental/plan_and_execute/index.ts",
"src/client/index.ts"
"src/client/index.ts",
"src/evaluation/index.ts"
],
"sort": [
"kind",
+1
View File
@@ -44,3 +44,4 @@ export * from "langchain/experimental/babyagi";
export * from "langchain/experimental/generative_agents";
export * from "langchain/experimental/plan_and_execute";
export * from "langchain/client";
export * from "langchain/evaluation";
+1
View File
@@ -44,3 +44,4 @@ const experimental_babyagi = require("langchain/experimental/babyagi");
const experimental_generative_agents = require("langchain/experimental/generative_agents");
const experimental_plan_and_execute = require("langchain/experimental/plan_and_execute");
const client = require("langchain/client");
const evaluation = require("langchain/evaluation");
+1
View File
@@ -44,3 +44,4 @@ export * from "langchain/experimental/babyagi";
export * from "langchain/experimental/generative_agents";
export * from "langchain/experimental/plan_and_execute";
export * from "langchain/client";
export * from "langchain/evaluation";
+1
View File
@@ -44,3 +44,4 @@ import * as experimental_babyagi from "langchain/experimental/babyagi";
import * as experimental_generative_agents from "langchain/experimental/generative_agents";
import * as experimental_plan_and_execute from "langchain/experimental/plan_and_execute";
import * as client from "langchain/client";
import * as evaluation from "langchain/evaluation";
+1
View File
@@ -44,3 +44,4 @@ import * as experimental_babyagi from "langchain/experimental/babyagi";
import * as experimental_generative_agents from "langchain/experimental/generative_agents";
import * as experimental_plan_and_execute from "langchain/experimental/plan_and_execute";
import * as client from "langchain/client";
import * as evaluation from "langchain/evaluation";
+1
View File
@@ -44,3 +44,4 @@ export * from "langchain/experimental/babyagi";
export * from "langchain/experimental/generative_agents";
export * from "langchain/experimental/plan_and_execute";
export * from "langchain/client";
export * from "langchain/evaluation";
+1
View File
@@ -44,3 +44,4 @@ export * from "langchain/experimental/babyagi";
export * from "langchain/experimental/generative_agents";
export * from "langchain/experimental/plan_and_execute";
export * from "langchain/client";
export * from "langchain/evaluation";