mirror of
https://github.com/Mintplex-Labs/langchainjs.git
synced 2026-07-01 12:17:38 -04:00
Fix example for chatanthropic, Fix LLMChain cancellation when using memory, Allow any chain to be cancelled (#1686)
* Fix example for chatanthropic * Fix llmchain cancellation when using memory * Implement cancellation for all chains - SequentialChain and SimpleSequentialChain pass the cancellation tokens to the inner chains (and if those are llmchains the cancellation token is passed to the llm) - Add support for using chains with memory inside simpleseqchain
This commit is contained in:
@@ -2,5 +2,5 @@ import { ChatAnthropic } from "langchain/chat_models/anthropic";
|
||||
|
||||
const model = new ChatAnthropic({
|
||||
temperature: 0.9,
|
||||
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.ANTHROPIC_API_KEY
|
||||
anthropicApiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.ANTHROPIC_API_KEY
|
||||
});
|
||||
|
||||
@@ -53,6 +53,15 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
|
||||
}
|
||||
}
|
||||
|
||||
/** @ignore */
|
||||
_selectMemoryInputs(values: ChainValues): ChainValues {
|
||||
const valuesForMemory = { ...values };
|
||||
if ("signal" in valuesForMemory) {
|
||||
delete valuesForMemory.signal;
|
||||
}
|
||||
return valuesForMemory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the core logic of this chain and return the output
|
||||
*/
|
||||
@@ -109,13 +118,15 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
|
||||
* Wraps _call and handles memory.
|
||||
*/
|
||||
async call(
|
||||
values: ChainValues,
|
||||
values: ChainValues & { signal?: AbortSignal },
|
||||
callbacks?: Callbacks,
|
||||
tags?: string[]
|
||||
): Promise<ChainValues> {
|
||||
const fullValues = { ...values } as typeof values;
|
||||
if (!(this.memory == null)) {
|
||||
const newValues = await this.memory.loadMemoryVariables(values);
|
||||
const newValues = await this.memory.loadMemoryVariables(
|
||||
this._selectMemoryInputs(values)
|
||||
);
|
||||
for (const [key, value] of Object.entries(newValues)) {
|
||||
fullValues[key] = value;
|
||||
}
|
||||
@@ -133,13 +144,23 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
|
||||
);
|
||||
let outputValues;
|
||||
try {
|
||||
outputValues = await this._call(fullValues, runManager);
|
||||
outputValues = (await Promise.race([
|
||||
this._call(fullValues, runManager),
|
||||
new Promise((_, reject) => {
|
||||
values.signal?.addEventListener("abort", () => {
|
||||
reject(new Error("AbortError"));
|
||||
});
|
||||
}),
|
||||
])) as ChainValues;
|
||||
} catch (e) {
|
||||
await runManager?.handleChainError(e);
|
||||
throw e;
|
||||
}
|
||||
if (!(this.memory == null)) {
|
||||
await this.memory.saveContext(values, outputValues);
|
||||
await this.memory.saveContext(
|
||||
this._selectMemoryInputs(values),
|
||||
outputValues
|
||||
);
|
||||
}
|
||||
await runManager?.handleChainEnd(outputValues);
|
||||
// add the runManager's currentRunId to the outputValues
|
||||
|
||||
@@ -68,6 +68,17 @@ export class LLMChain<T extends string | object = string>
|
||||
}
|
||||
}
|
||||
|
||||
/** @ignore */
|
||||
_selectMemoryInputs(values: ChainValues): ChainValues {
|
||||
const valuesForMemory = { ...values };
|
||||
for (const key of this.llm.callKeys) {
|
||||
if (key in values) {
|
||||
delete valuesForMemory[key];
|
||||
}
|
||||
}
|
||||
return valuesForMemory;
|
||||
}
|
||||
|
||||
/** @ignore */
|
||||
async _getFinalOutput(
|
||||
generations: Generation[],
|
||||
|
||||
@@ -253,7 +253,11 @@ export class SimpleSequentialChain
|
||||
/** @ignore */
|
||||
_validateChains() {
|
||||
for (const chain of this.chains) {
|
||||
if (chain.inputKeys.length !== 1) {
|
||||
if (
|
||||
chain.inputKeys.filter(
|
||||
(k) => !chain.memory?.memoryKeys.includes(k) ?? true
|
||||
).length !== 1
|
||||
) {
|
||||
throw new Error(
|
||||
`Chains used in SimpleSequentialChain should all have one input, got ${
|
||||
chain.inputKeys.length
|
||||
@@ -279,7 +283,12 @@ export class SimpleSequentialChain
|
||||
let i = 0;
|
||||
for (const chain of this.chains) {
|
||||
i += 1;
|
||||
input = await chain.run(input, runManager?.getChild(`step_${i}`));
|
||||
input = (
|
||||
await chain.call(
|
||||
{ [chain.inputKeys[0]]: input, signal: values.signal },
|
||||
runManager?.getChild(`step_${i}`)
|
||||
)
|
||||
)[chain.outputKeys[0]];
|
||||
if (this.trimOutputs) {
|
||||
input = input.trim();
|
||||
}
|
||||
|
||||
@@ -62,6 +62,25 @@ test("Test run method", async () => {
|
||||
console.log({ res });
|
||||
});
|
||||
|
||||
test("Test memory + cancellation", async () => {
|
||||
const model = new OpenAI({ modelName: "text-ada-001" });
|
||||
const prompt = new PromptTemplate({
|
||||
template: "{history} Print {foo}",
|
||||
inputVariables: ["foo", "history"],
|
||||
});
|
||||
const chain = new LLMChain({
|
||||
prompt,
|
||||
llm: model,
|
||||
memory: new BufferMemory(),
|
||||
});
|
||||
await expect(() =>
|
||||
chain.call({
|
||||
foo: "my favorite color",
|
||||
signal: AbortSignal.timeout(20),
|
||||
})
|
||||
).rejects.toThrow("Cancel: canceled");
|
||||
});
|
||||
|
||||
test("Test apply", async () => {
|
||||
const model = new OpenAI({ modelName: "text-ada-001" });
|
||||
const prompt = new PromptTemplate({
|
||||
|
||||
@@ -4,6 +4,7 @@ import { PromptTemplate } from "../../prompts/index.js";
|
||||
import { LLMChain } from "../llm_chain.js";
|
||||
import { SimpleSequentialChain } from "../sequential_chain.js";
|
||||
import { ChatOpenAI } from "../../chat_models/openai.js";
|
||||
import { BufferMemory } from "../../memory/buffer_memory.js";
|
||||
|
||||
test("Test SimpleSequentialChain example usage", async () => {
|
||||
// This is an LLMChain to write a synopsis given a title of a play.
|
||||
@@ -44,6 +45,52 @@ test("Test SimpleSequentialChain example usage", async () => {
|
||||
);
|
||||
});
|
||||
|
||||
test("Test SimpleSequentialChain example usage", async () => {
|
||||
// This is an LLMChain to write a synopsis given a title of a play.
|
||||
const llm = new ChatOpenAI({ temperature: 0 });
|
||||
const template = `You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
|
||||
|
||||
{history}
|
||||
Title: {title}
|
||||
Playwright: This is a synopsis for the above play:`;
|
||||
const promptTemplate = new PromptTemplate({
|
||||
template,
|
||||
inputVariables: ["title", "history"],
|
||||
});
|
||||
const synopsisChain = new LLMChain({
|
||||
llm,
|
||||
prompt: promptTemplate,
|
||||
memory: new BufferMemory(),
|
||||
});
|
||||
|
||||
// This is an LLMChain to write a review of a play given a synopsis.
|
||||
const reviewLLM = new ChatOpenAI({ temperature: 0 });
|
||||
const reviewTemplate = `You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.
|
||||
|
||||
Play Synopsis:
|
||||
{synopsis}
|
||||
Review from a New York Times play critic of the above play:`;
|
||||
const reviewPromptTemplate = new PromptTemplate({
|
||||
template: reviewTemplate,
|
||||
inputVariables: ["synopsis"],
|
||||
});
|
||||
const reviewChain = new LLMChain({
|
||||
llm: reviewLLM,
|
||||
prompt: reviewPromptTemplate,
|
||||
});
|
||||
|
||||
const overallChain = new SimpleSequentialChain({
|
||||
chains: [synopsisChain, reviewChain],
|
||||
verbose: true,
|
||||
});
|
||||
await expect(() =>
|
||||
overallChain.call({
|
||||
input: "Tragedy at sunset on the beach",
|
||||
signal: AbortSignal.timeout(1000),
|
||||
})
|
||||
).rejects.toThrow("AbortError");
|
||||
});
|
||||
|
||||
test("Test SimpleSequentialChain serialize/deserialize", async () => {
|
||||
const llm1 = new ChatOpenAI();
|
||||
const template1 = `Echo back "{foo}"`;
|
||||
|
||||
@@ -72,7 +72,10 @@ export interface AnthropicInput {
|
||||
streaming?: boolean;
|
||||
|
||||
/** Anthropic API key */
|
||||
apiKey?: string;
|
||||
anthropicApiKey?: string;
|
||||
|
||||
/** Anthropic API URL */
|
||||
anthropicApiUrl?: string;
|
||||
|
||||
/** Model name to use */
|
||||
modelName: string;
|
||||
@@ -109,7 +112,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
|
||||
|
||||
get lc_secrets(): { [key: string]: string } | undefined {
|
||||
return {
|
||||
apiKey: "ANTHROPIC_API_KEY",
|
||||
anthropicApiKey: "ANTHROPIC_API_KEY",
|
||||
};
|
||||
}
|
||||
|
||||
@@ -147,13 +150,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
|
||||
// Used for streaming requests
|
||||
private streamingClient: AnthropicApi;
|
||||
|
||||
constructor(
|
||||
fields?: Partial<AnthropicInput> &
|
||||
BaseChatModelParams & {
|
||||
anthropicApiKey?: string;
|
||||
anthropicApiUrl?: string;
|
||||
}
|
||||
) {
|
||||
constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
|
||||
super(fields ?? {});
|
||||
|
||||
this.apiKey =
|
||||
|
||||
Reference in New Issue
Block a user