Fix calling chain.run() with memory (#1642)

This commit is contained in:
Nuno Campos
2023-06-14 10:29:11 +01:00
committed by GitHub
parent b696686744
commit 9b17b5e2fc
2 changed files with 21 additions and 2 deletions
+5 -2
View File
@@ -82,13 +82,16 @@ export abstract class BaseChain extends BaseLangChain implements ChainInputs {
input: any,
callbacks?: Callbacks
): Promise<string> {
const isKeylessInput = this.inputKeys.length <= 1;
const inputKeys = this.inputKeys.filter(
(k) => !this.memory?.memoryKeys.includes(k) ?? true
);
const isKeylessInput = inputKeys.length <= 1;
if (!isKeylessInput) {
throw new Error(
`Chain ${this._chainType()} expects multiple inputs, cannot use 'run' `
);
}
const values = this.inputKeys.length ? { [this.inputKeys[0]]: input } : {};
const values = inputKeys.length ? { [inputKeys[0]]: input } : {};
const returnValues = await this.call(values, callbacks);
const keys = Object.keys(returnValues);
@@ -8,6 +8,7 @@ import {
} from "../../prompts/index.js";
import { LLMChain } from "../llm_chain.js";
import { loadChain } from "../load.js";
import { BufferMemory } from "../../memory/buffer_memory.js";
test("Test OpenAI", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
@@ -46,6 +47,21 @@ test("Test run method", async () => {
console.log({ res });
});
test("Test run method", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
const prompt = new PromptTemplate({
template: "{history} Print {foo}",
inputVariables: ["foo", "history"],
});
const chain = new LLMChain({
prompt,
llm: model,
memory: new BufferMemory(),
});
const res = await chain.run("my favorite color");
console.log({ res });
});
test("Test apply", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
const prompt = new PromptTemplate({