Fix example for chatanthropic, Fix LLMChain cancellation when using memory, Allow any chain to be cancelled (#1686)

* Fix example for chatanthropic

* Fix llmchain cancellation when using memory

* Implement cancellation for all chains

- SequentialChain and SimpleSequentialChain pass the cancellation tokens to the inner chains (and if those are llmchains the cancellation token is passed to the llm)
- Add support for using chains with memory inside simpleseqchain
This commit is contained in:
Nuno Campos
2023-06-17 21:31:08 +01:00
committed by GitHub
parent 4020d0a8c6
commit 8ddf206998
7 changed files with 120 additions and 16 deletions
@@ -2,5 +2,5 @@ import { ChatAnthropic } from "langchain/chat_models/anthropic";
const model = new ChatAnthropic({
temperature: 0.9,
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.ANTHROPIC_API_KEY
anthropicApiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.ANTHROPIC_API_KEY
});
+25 -4
View File
@@ -53,6 +53,15 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
}
}
/** @ignore */
_selectMemoryInputs(values: ChainValues): ChainValues {
const valuesForMemory = { ...values };
if ("signal" in valuesForMemory) {
delete valuesForMemory.signal;
}
return valuesForMemory;
}
/**
* Run the core logic of this chain and return the output
*/
@@ -109,13 +118,15 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
* Wraps _call and handles memory.
*/
async call(
values: ChainValues,
values: ChainValues & { signal?: AbortSignal },
callbacks?: Callbacks,
tags?: string[]
): Promise<ChainValues> {
const fullValues = { ...values } as typeof values;
if (!(this.memory == null)) {
const newValues = await this.memory.loadMemoryVariables(values);
const newValues = await this.memory.loadMemoryVariables(
this._selectMemoryInputs(values)
);
for (const [key, value] of Object.entries(newValues)) {
fullValues[key] = value;
}
@@ -133,13 +144,23 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
);
let outputValues;
try {
outputValues = await this._call(fullValues, runManager);
outputValues = (await Promise.race([
this._call(fullValues, runManager),
new Promise((_, reject) => {
values.signal?.addEventListener("abort", () => {
reject(new Error("AbortError"));
});
}),
])) as ChainValues;
} catch (e) {
await runManager?.handleChainError(e);
throw e;
}
if (!(this.memory == null)) {
await this.memory.saveContext(values, outputValues);
await this.memory.saveContext(
this._selectMemoryInputs(values),
outputValues
);
}
await runManager?.handleChainEnd(outputValues);
// add the runManager's currentRunId to the outputValues
+11
View File
@@ -68,6 +68,17 @@ export class LLMChain<T extends string | object = string>
}
}
/** @ignore */
_selectMemoryInputs(values: ChainValues): ChainValues {
const valuesForMemory = { ...values };
for (const key of this.llm.callKeys) {
if (key in values) {
delete valuesForMemory[key];
}
}
return valuesForMemory;
}
/** @ignore */
async _getFinalOutput(
generations: Generation[],
+11 -2
View File
@@ -253,7 +253,11 @@ export class SimpleSequentialChain
/** @ignore */
_validateChains() {
for (const chain of this.chains) {
if (chain.inputKeys.length !== 1) {
if (
chain.inputKeys.filter(
(k) => !chain.memory?.memoryKeys.includes(k) ?? true
).length !== 1
) {
throw new Error(
`Chains used in SimpleSequentialChain should all have one input, got ${
chain.inputKeys.length
@@ -279,7 +283,12 @@ export class SimpleSequentialChain
let i = 0;
for (const chain of this.chains) {
i += 1;
input = await chain.run(input, runManager?.getChild(`step_${i}`));
input = (
await chain.call(
{ [chain.inputKeys[0]]: input, signal: values.signal },
runManager?.getChild(`step_${i}`)
)
)[chain.outputKeys[0]];
if (this.trimOutputs) {
input = input.trim();
}
@@ -62,6 +62,25 @@ test("Test run method", async () => {
console.log({ res });
});
test("Test memory + cancellation", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
const prompt = new PromptTemplate({
template: "{history} Print {foo}",
inputVariables: ["foo", "history"],
});
const chain = new LLMChain({
prompt,
llm: model,
memory: new BufferMemory(),
});
await expect(() =>
chain.call({
foo: "my favorite color",
signal: AbortSignal.timeout(20),
})
).rejects.toThrow("Cancel: canceled");
});
test("Test apply", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
const prompt = new PromptTemplate({
@@ -4,6 +4,7 @@ import { PromptTemplate } from "../../prompts/index.js";
import { LLMChain } from "../llm_chain.js";
import { SimpleSequentialChain } from "../sequential_chain.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { BufferMemory } from "../../memory/buffer_memory.js";
test("Test SimpleSequentialChain example usage", async () => {
// This is an LLMChain to write a synopsis given a title of a play.
@@ -44,6 +45,52 @@ test("Test SimpleSequentialChain example usage", async () => {
);
});
test("Test SimpleSequentialChain example usage", async () => {
// This is an LLMChain to write a synopsis given a title of a play.
const llm = new ChatOpenAI({ temperature: 0 });
const template = `You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
{history}
Title: {title}
Playwright: This is a synopsis for the above play:`;
const promptTemplate = new PromptTemplate({
template,
inputVariables: ["title", "history"],
});
const synopsisChain = new LLMChain({
llm,
prompt: promptTemplate,
memory: new BufferMemory(),
});
// This is an LLMChain to write a review of a play given a synopsis.
const reviewLLM = new ChatOpenAI({ temperature: 0 });
const reviewTemplate = `You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.
Play Synopsis:
{synopsis}
Review from a New York Times play critic of the above play:`;
const reviewPromptTemplate = new PromptTemplate({
template: reviewTemplate,
inputVariables: ["synopsis"],
});
const reviewChain = new LLMChain({
llm: reviewLLM,
prompt: reviewPromptTemplate,
});
const overallChain = new SimpleSequentialChain({
chains: [synopsisChain, reviewChain],
verbose: true,
});
await expect(() =>
overallChain.call({
input: "Tragedy at sunset on the beach",
signal: AbortSignal.timeout(1000),
})
).rejects.toThrow("AbortError");
});
test("Test SimpleSequentialChain serialize/deserialize", async () => {
const llm1 = new ChatOpenAI();
const template1 = `Echo back "{foo}"`;
+6 -9
View File
@@ -72,7 +72,10 @@ export interface AnthropicInput {
streaming?: boolean;
/** Anthropic API key */
apiKey?: string;
anthropicApiKey?: string;
/** Anthropic API URL */
anthropicApiUrl?: string;
/** Model name to use */
modelName: string;
@@ -109,7 +112,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
get lc_secrets(): { [key: string]: string } | undefined {
return {
apiKey: "ANTHROPIC_API_KEY",
anthropicApiKey: "ANTHROPIC_API_KEY",
};
}
@@ -147,13 +150,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
// Used for streaming requests
private streamingClient: AnthropicApi;
constructor(
fields?: Partial<AnthropicInput> &
BaseChatModelParams & {
anthropicApiKey?: string;
anthropicApiUrl?: string;
}
) {
constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
super(fields ?? {});
this.apiKey =