chore: DRY-up the loadQAChain code so it's just using the same functions (#1122)

* chore: DRY-up the load code so it's just using the same functions

* Lint

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Justin Rahardjo
2023-05-05 03:51:55 -07:00
committed by GitHub
parent bd27d17152
commit 085fcf4f92
+18 -54
View File
@@ -6,10 +6,8 @@ import {
RefineDocumentsChain,
MapReduceDocumentsChainInput,
} from "../combine_docs_chain.js";
import { QA_PROMPT_SELECTOR, DEFAULT_QA_PROMPT } from "./stuff_prompts.js";
import { QA_PROMPT_SELECTOR } from "./stuff_prompts.js";
import {
COMBINE_PROMPT,
DEFAULT_COMBINE_QA_PROMPT,
COMBINE_PROMPT_SELECTOR,
COMBINE_QA_PROMPT_SELECTOR,
} from "./map_reduce_prompts.js";
@@ -34,52 +32,15 @@ export const loadQAChain = (
llm: BaseLanguageModel,
params: QAChainParams = { type: "stuff" }
) => {
const { type, verbose } = params;
const { type } = params;
if (type === "stuff") {
const { prompt = DEFAULT_QA_PROMPT } = params;
const llmChain = new LLMChain({ prompt, llm, verbose });
const chain = new StuffDocumentsChain({ llmChain, verbose });
return chain;
return loadQAStuffChain(llm, params);
}
if (type === "map_reduce") {
const {
combineMapPrompt = DEFAULT_COMBINE_QA_PROMPT,
combinePrompt = COMBINE_PROMPT,
returnIntermediateSteps,
} = params;
const llmChain = new LLMChain({ prompt: combineMapPrompt, llm, verbose });
const combineLLMChain = new LLMChain({
prompt: combinePrompt,
llm,
verbose,
});
const combineDocumentChain = new StuffDocumentsChain({
llmChain: combineLLMChain,
documentVariableName: "summaries",
verbose,
});
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
verbose,
});
return chain;
return loadQAMapReduceChain(llm, params);
}
if (type === "refine") {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
} = params;
const llmChain = new LLMChain({ prompt: questionPrompt, llm, verbose });
const refineLLMChain = new LLMChain({ prompt: refinePrompt, llm, verbose });
const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
verbose,
});
return chain;
return loadQARefineChain(llm, params);
}
throw new Error(`Invalid _type: ${type}`);
};
@@ -89,15 +50,15 @@ export interface StuffQAChainParams {
verbose?: boolean;
}
export const loadQAStuffChain = (
export function loadQAStuffChain(
llm: BaseLanguageModel,
params: StuffQAChainParams = {}
) => {
) {
const { prompt = QA_PROMPT_SELECTOR.getPrompt(llm), verbose } = params;
const llmChain = new LLMChain({ prompt, llm, verbose });
const chain = new StuffDocumentsChain({ llmChain });
const chain = new StuffDocumentsChain({ llmChain, verbose });
return chain;
};
}
export interface MapReduceQAChainParams {
returnIntermediateSteps?: MapReduceDocumentsChainInput["returnIntermediateSteps"];
@@ -106,10 +67,10 @@ export interface MapReduceQAChainParams {
verbose?: boolean;
}
export const loadQAMapReduceChain = (
export function loadQAMapReduceChain(
llm: BaseLanguageModel,
params: MapReduceQAChainParams = {}
) => {
) {
const {
combineMapPrompt = COMBINE_QA_PROMPT_SELECTOR.getPrompt(llm),
combinePrompt = COMBINE_PROMPT_SELECTOR.getPrompt(llm),
@@ -121,14 +82,16 @@ export const loadQAMapReduceChain = (
const combineDocumentChain = new StuffDocumentsChain({
llmChain: combineLLMChain,
documentVariableName: "summaries",
verbose,
});
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
verbose,
});
return chain;
};
}
export interface RefineQAChainParams {
questionPrompt?: BasePromptTemplate;
@@ -136,10 +99,10 @@ export interface RefineQAChainParams {
verbose?: boolean;
}
export const loadQARefineChain = (
export function loadQARefineChain(
llm: BaseLanguageModel,
params: RefineQAChainParams = {}
) => {
) {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
@@ -151,6 +114,7 @@ export const loadQARefineChain = (
const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
verbose,
});
return chain;
};
}