Compare commits

...

27 Commits

Author SHA1 Message Date
Thuc Pham e10ca4a216 add handle llm tool call step 2024-09-23 21:17:48 +07:00
Thuc Pham 94c623ecfb check toolcall support 2024-09-23 17:45:20 +07:00
Thuc Pham 0ec268cd7f fix: stream response 2024-09-23 17:13:53 +07:00
Thuc Pham 4e6a04ba62 refactor: remove dup code in handleLLMInput 2024-09-23 17:03:23 +07:00
Thuc Pham 891d9fbe65 patch controller express template 2024-09-23 16:41:39 +07:00
Thuc Pham f053da4728 refactor: rename workflow files 2024-09-23 16:41:28 +07:00
Thuc Pham e01ad418e5 fix: lint 2024-09-23 16:25:55 +07:00
Thuc Pham 97eb4dc51b feat: support ts multi-agent 2024-09-23 14:15:01 +07:00
Marcus Schiesser 0bf11a57b0 Update questions.ts 2024-09-23 11:05:51 +07:00
Thuc Pham f7d366b648 feat: support multiagent for nextjs 2024-09-20 17:13:19 +07:00
Thuc Pham d69cd42fa7 refactor: move workflow to components 2024-09-20 17:00:43 +07:00
Thuc Pham 54d74f8237 fix: move settings.ts to setting folder 2024-09-20 15:56:05 +07:00
Thuc Pham f6597213c8 refactor: share settings file for ts templates 2024-09-20 15:44:25 +07:00
Thuc Pham c4041e2de3 refactor: move workflow folder to src 2024-09-20 15:33:23 +07:00
Thuc Pham aff4f0cde4 fix lint 2024-09-20 15:28:51 +07:00
Thuc Pham de5ba29276 fix: let default max attempt 2 2024-09-20 15:28:19 +07:00
Thuc Pham 33ce5934fa feat: funtional calling agent 2024-09-20 15:23:54 +07:00
Thuc Pham b030a3d885 fix: pipe final streaming result 2024-09-19 19:59:54 +07:00
Thuc Pham b8756189cc fix: streaming final result 2024-09-19 19:32:16 +07:00
Thuc Pham 5daf519572 feat: streaming event 2024-09-19 16:58:41 +07:00
Thuc Pham 2c7a53853a update doc 2024-09-19 15:44:46 +07:00
Thuc Pham 6c05872aae remove unused files 2024-09-19 15:44:40 +07:00
Thuc Pham f43f00a4ee create workflow with example agents 2024-09-19 15:05:20 +07:00
Marcus Schiesser 0ebcb9fff7 Create yellow-jokes-protect.md 2024-09-19 09:26:48 +07:00
Thuc Pham f464b40f58 fix: import from agent 2024-09-18 19:29:15 +07:00
Thuc Pham 622b84b97a feat: add express simple multiagent 2024-09-18 19:26:23 +07:00
Thuc Pham 413593b0d9 feat: update question to ask creating multiagent in express 2024-09-18 19:26:02 +07:00
12 changed files with 624 additions and 185 deletions
+5
View File
@@ -0,0 +1,5 @@
---
"create-llama": patch
---
Add multi agents template for Express
+30 -2
View File
@@ -33,8 +33,7 @@ export const installTSTemplate = async ({
* Copy the template files to the target directory.
*/
console.log("\nInitializing project with template:", template, "\n");
const type = template === "multiagent" ? "streaming" : template; // use nextjs streaming template for multiagent
const templatePath = path.join(templatesDir, "types", type, framework);
const templatePath = path.join(templatesDir, "types", "streaming", framework);
const copySource = ["**"];
await copy(copySource, root, {
@@ -124,6 +123,30 @@ export const installTSTemplate = async ({
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
});
if (template === "multiagent") {
const multiagentPath = path.join(compPath, "multiagent", "typescript");
// copy workflow code for multiagent template
await copy("**", path.join(root, relativeEngineDestPath, "workflow"), {
parents: true,
cwd: path.join(multiagentPath, "workflow"),
});
if (framework === "nextjs") {
// patch route.ts file
await copy("**", path.join(root, relativeEngineDestPath), {
parents: true,
cwd: path.join(multiagentPath, "nextjs"),
});
} else if (framework === "express") {
// patch chat.controller.ts file
await copy("**", path.join(root, relativeEngineDestPath), {
parents: true,
cwd: path.join(multiagentPath, "express"),
});
}
}
// copy loader component (TS only supports llama_parse and file for now)
const loaderFolder = useLlamaParse ? "llama_parse" : "file";
await copy("**", enginePath, {
@@ -145,6 +168,11 @@ export const installTSTemplate = async ({
cwd: path.join(compPath, "engines", "typescript", engine),
});
// copy settings to engine folder
await copy("**", enginePath, {
cwd: path.join(compPath, "settings", "typescript"),
});
/**
* Copy the selected UI files to the target directory and reference it.
*/
+1 -4
View File
@@ -410,10 +410,7 @@ export const askQuestions = async (
return; // early return - no further questions needed for llamapack projects
}
if (program.template === "multiagent") {
// TODO: multi-agents currently only supports FastAPI
program.framework = preferences.framework = "fastapi";
} else if (program.template === "extractor") {
if (program.template === "extractor") {
// Extractor template only supports FastAPI, empty data sources, and llamacloud
// So we just use example file for extractor template, this allows user to choose vector database later
program.dataSources = [EXAMPLE_FILE];
@@ -0,0 +1,34 @@
import { Message, StreamData, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage } from "llamaindex";
import { createStreamTimeout } from "./llamaindex/streaming/events";
import { createWorkflow } from "./workflow/factory";
import { toDataStream } from "./workflow/stream";
export const chat = async (req: Request, res: Response) => {
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
try {
const { messages }: { messages: Message[] } = req.body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return res.status(400).json({
error:
"messages are required in the request body and the last message must be from the user",
});
}
const chatHistory = messages as ChatMessage[];
const agent = await createWorkflow(chatHistory, vercelStreamData);
agent.run(userMessage.content);
const stream = toDataStream(agent.streamEvents(), vercelStreamData);
return streamToResponse(stream, res, {}, vercelStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
detail: (error as Error).message,
});
} finally {
clearTimeout(streamTimeout);
}
};
@@ -0,0 +1,53 @@
import { initObservability } from "@/app/observability";
import { Message, StreamData, StreamingTextResponse } from "ai";
import { ChatMessage } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { initSettings } from "./engine/settings";
import { createStreamTimeout } from "./llamaindex/streaming/events";
import { createWorkflow } from "./workflow/factory";
import { toDataStream } from "./workflow/stream";
initObservability();
initSettings();
export const runtime = "nodejs";
export const dynamic = "force-dynamic";
export async function POST(request: NextRequest) {
// Init Vercel AI StreamData and timeout
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
try {
const body = await request.json();
const { messages, data }: { messages: Message[]; data?: any } = body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return NextResponse.json(
{
error:
"messages are required in the request body and the last message must be from the user",
},
{ status: 400 },
);
}
const chatHistory = messages as ChatMessage[];
const agent = await createWorkflow(chatHistory, vercelStreamData);
agent.run(userMessage.content);
const stream = toDataStream(agent.streamEvents(), vercelStreamData);
return new StreamingTextResponse(stream, {}, vercelStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return NextResponse.json(
{
detail: (error as Error).message,
},
{
status: 500,
},
);
} finally {
clearTimeout(streamTimeout);
}
}
@@ -0,0 +1,49 @@
import { ChatMessage, QueryEngineTool } from "llamaindex";
import { getDataSource } from "../engine";
import { FunctionCallingAgent } from "./single-agent";
const getQueryEngineTool = async () => {
const index = await getDataSource();
if (!index) {
throw new Error("Index not found. Please create an index first.");
}
const topK = process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined;
return new QueryEngineTool({
queryEngine: index.asQueryEngine({
similarityTopK: topK,
}),
metadata: {
name: "query_index",
description: `Use this tool to retrieve information about the text corpus from the index.`,
},
});
};
export const createResearcher = async (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "researcher",
tools: [await getQueryEngineTool()],
systemPrompt:
"You are a researcher agent. You are given a researching task. You must use your tools to complete the research.",
chatHistory,
});
};
export const createWriter = (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "writer",
systemPrompt:
"You are an expert in writing blog posts. You are given a task to write a blog post. Don't make up any information yourself.",
chatHistory,
});
};
export const createReviewer = (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "reviewer",
systemPrompt:
"You are an expert in reviewing blog posts. You are given a task to review a blog post. Review the post for logical inconsistencies, ask critical questions, and provide suggestions for improvement. Furthermore, proofread the post for grammar and spelling errors. Only if the post is good enough for publishing, then you MUST return 'The post is good.'. In all other cases return your review.",
chatHistory,
});
};
@@ -0,0 +1,138 @@
import {
Context,
StartEvent,
StopEvent,
Workflow,
WorkflowEvent,
} from "@llamaindex/core/workflow";
import { StreamData } from "ai";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { createResearcher, createReviewer, createWriter } from "./agents";
import { AgentInput, AgentRunEvent, AgentRunResult } from "./type";
const TIMEOUT = 360 * 1000;
const MAX_ATTEMPTS = 2;
class ResearchEvent extends WorkflowEvent<{ input: string }> {}
class WriteEvent extends WorkflowEvent<{
input: string;
isGood: boolean;
}> {}
class ReviewEvent extends WorkflowEvent<{ input: string }> {}
export const createWorkflow = async (
chatHistory: ChatMessage[],
stream: StreamData,
) => {
const appendStream = (agent: string, text: string) => {
stream.appendMessageAnnotation({
type: "agent",
data: { agent, text },
});
};
const runAgent = async (agent: Workflow, input: AgentInput) => {
const run = agent.run(new StartEvent({ input }));
for await (const event of agent.streamEvents()) {
if (event.data instanceof AgentRunEvent) {
const { name, msg } = event.data.data;
// TODO: better using context.writeEventToStream here instead of directly append to stream
// But not sure why it's fail to write to stream from the third event
appendStream(name, msg);
}
}
return await run;
};
const start = async (context: Context, ev: StartEvent) => {
context.set("task", ev.data.input);
return new ResearchEvent({
input: `Research for this task: ${ev.data.input}`,
});
};
const research = async (context: Context, ev: ResearchEvent) => {
const researcher = await createResearcher(chatHistory);
const researchRes = await runAgent(researcher, { message: ev.data.input });
const researchResult = researchRes.data.result;
return new WriteEvent({
input: `Write a blog post given this task: ${context.get("task")} using this research content: ${researchResult}`,
isGood: false,
});
};
const write = async (context: Context, ev: WriteEvent) => {
context.set("attempts", context.get("attempts", 0) + 1);
const tooManyAttempts = context.get("attempts") > MAX_ATTEMPTS;
if (tooManyAttempts) {
appendStream(
"writer",
`Too many attempts (${MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.`,
);
}
if (ev.data.isGood || tooManyAttempts) {
const writer = createWriter(chatHistory);
const writeRes = (await runAgent(writer, {
message: ev.data.input,
streaming: true,
})) as unknown as StopEvent<AsyncGenerator<ChatResponseChunk>>;
const result = writeRes.data.result;
context.writeEventToStream({
data: new AgentRunResult(result),
});
return new StopEvent({ result }); // stop the workflow
}
const writer = createWriter(chatHistory);
const writeRes = await runAgent(writer, { message: ev.data.input });
const writeResult = writeRes.data.result;
context.set("result", writeResult); // store the last result
return new ReviewEvent({ input: writeResult });
};
const review = async (context: Context, ev: ReviewEvent) => {
const reviewer = createReviewer(chatHistory);
const reviewRes = await reviewer.run(
new StartEvent<AgentInput>({ input: { message: ev.data.input } }),
);
const reviewResult = reviewRes.data.result;
const oldContent = context.get("result");
const postIsGood = reviewResult.toLowerCase().includes("post is good");
appendStream(
"reviewer",
`The post is ${postIsGood ? "" : "not "}good enough for publishing. Sending back to the writer${
postIsGood ? " for publication." : "."
}`,
);
if (postIsGood) {
return new WriteEvent({
input: `You're blog post is ready for publication. Please respond with just the blog post. Blog post: \`\`\`${oldContent}\`\`\``,
isGood: true,
});
}
return new WriteEvent({
input: `Improve the writing of a given blog post by using a given review.
Blog post:
\`\`\`
${oldContent}
\`\`\`
Review:
\`\`\`
${reviewResult}
\`\`\``,
isGood: false,
});
};
const workflow = new Workflow({ timeout: TIMEOUT, validate: true });
workflow.addStep(StartEvent, start, { outputs: ResearchEvent });
workflow.addStep(ResearchEvent, research, { outputs: WriteEvent });
workflow.addStep(WriteEvent, write, { outputs: [ReviewEvent, StopEvent] });
workflow.addStep(ReviewEvent, review, { outputs: WriteEvent });
return workflow;
};
@@ -0,0 +1,252 @@
import {
Context,
StartEvent,
StopEvent,
Workflow,
WorkflowEvent,
} from "@llamaindex/core/workflow";
import {
BaseToolWithCall,
ChatMemoryBuffer,
ChatMessage,
ChatResponse,
ChatResponseChunk,
LLM,
Settings,
ToolCall,
ToolCallLLM,
} from "llamaindex";
import { AgentInput, AgentRunEvent } from "./type";
class InputEvent extends WorkflowEvent<{
input: ChatMessage[];
}> {}
class ToolCallEvent extends WorkflowEvent<{
toolCalls: ToolCall[];
}> {}
export class FunctionCallingAgent extends Workflow {
name: string;
llm: LLM;
memory: ChatMemoryBuffer;
tools: BaseToolWithCall[];
systemPrompt?: string;
writeEvents: boolean;
role?: string;
toolCalled: boolean = false;
constructor(options: {
name: string;
llm?: LLM;
chatHistory?: ChatMessage[];
tools?: BaseToolWithCall[];
systemPrompt?: string;
writeEvents?: boolean;
role?: string;
verbose?: boolean;
timeout?: number;
}) {
super({
verbose: options?.verbose ?? false,
timeout: options?.timeout ?? 360,
});
this.name = options?.name;
this.llm = options.llm ?? Settings.llm;
this.checkToolCallSupport();
this.memory = new ChatMemoryBuffer({
llm: this.llm,
chatHistory: options.chatHistory,
});
this.tools = options?.tools ?? [];
this.systemPrompt = options.systemPrompt;
this.writeEvents = options?.writeEvents ?? true;
this.role = options?.role;
// add steps
this.addStep(StartEvent<AgentInput>, this.prepareChatHistory, {
outputs: InputEvent,
});
this.addStep(InputEvent, this.handleLLMInput, {
outputs: [ToolCallEvent, StopEvent],
});
this.addStep(ToolCallEvent, this.handleToolCalls, {
outputs: InputEvent,
});
}
private get chatHistory() {
return this.memory.getAllMessages();
}
private get toolsByName() {
return this.tools.reduce((acc: Record<string, BaseToolWithCall>, tool) => {
acc[tool.metadata.name] = tool;
return acc;
}, {});
}
private async prepareChatHistory(
ctx: Context,
ev: StartEvent<AgentInput>,
): Promise<InputEvent> {
this.toolCalled = false;
const { message, streaming } = ev.data.input;
ctx.set("streaming", streaming);
this.writeEvent(`Start to work on: ${message}`, ctx);
if (this.systemPrompt) {
this.memory.put({ role: "system", content: this.systemPrompt });
}
this.memory.put({ role: "user", content: message });
return new InputEvent({ input: this.chatHistory });
}
private async handleLLMInput(
ctx: Context,
ev: InputEvent,
): Promise<StopEvent<string | AsyncGenerator> | ToolCallEvent> {
const isStreaming = ctx.get("streaming");
const llmArgs = { messages: this.chatHistory, tools: this.tools };
if (isStreaming) {
return await this.handleLLMInputStream(ctx, ev);
}
const nonStreamingRes = await this.llm.chat({ ...llmArgs });
const toolCalls = this.getToolCallsFromResponse(nonStreamingRes);
if (toolCalls.length && !this.toolCalled) {
return new ToolCallEvent({ toolCalls });
}
this.writeEvent("Finished task", ctx);
const result = nonStreamingRes.message.content.toString();
return new StopEvent({ result });
}
private async handleLLMInputStream(
context: Context,
ev: InputEvent,
): Promise<StopEvent<AsyncGenerator> | ToolCallEvent> {
const { llm, tools, memory } = this;
const llmArgs = { messages: this.chatHistory, tools };
const responseGenerator = async function* () {
const responseStream = await llm.chat({ ...llmArgs, stream: true });
let fullResponse = null;
let yieldedIndicator = false;
for await (const chunk of responseStream) {
const hasToolCalls = chunk.options && "toolCall" in chunk.options;
if (!hasToolCalls) {
if (!yieldedIndicator) {
yield false;
yieldedIndicator = true;
}
yield chunk;
} else if (!yieldedIndicator) {
yield true;
yieldedIndicator = true;
}
fullResponse = chunk;
}
if (fullResponse) {
memory.put({
role: "system",
content: fullResponse.delta,
});
yield fullResponse;
}
};
const generator = responseGenerator();
const isToolCall = await generator.next();
if (isToolCall.value) {
const fullResponse = await generator.next();
const toolCalls = this.getToolCallsFromResponse(
fullResponse.value as ChatResponseChunk<object>,
);
return new ToolCallEvent({ toolCalls });
}
this.writeEvent("Finished task", context);
return new StopEvent({ result: generator });
}
private async handleToolCalls(
ctx: Context,
ev: ToolCallEvent,
): Promise<InputEvent> {
this.toolCalled = true;
const { toolCalls } = ev.data;
const toolMsgs: ChatMessage[] = [];
for (const toolCall of toolCalls) {
const tool = this.toolsByName[toolCall.name];
const options = {
tool_call_id: toolCall.id,
name: tool.metadata.name,
};
if (!tool) {
toolMsgs.push({
role: "system",
content: `Tool ${toolCall.name} does not exist`,
options,
});
continue;
}
try {
const toolInput = JSON.parse(toolCall.input.toString());
const toolOutput = await tool.call(toolInput);
toolMsgs.push({
role: "system",
content: toolOutput.toString(),
options,
});
} catch (e) {
console.error(e);
toolMsgs.push({
role: "system",
content: `Encountered error in tool call: ${e}`,
options,
});
}
}
for (const msg of toolMsgs) {
this.memory.put(msg);
}
return new InputEvent({ input: this.memory.getAllMessages() });
}
private writeEvent(msg: string, context: Context) {
if (!this.writeEvents) return;
context.writeEventToStream({
data: new AgentRunEvent({ name: this.name, msg }),
});
}
private checkToolCallSupport() {
const { supportToolCall } = this.llm as ToolCallLLM;
if (!supportToolCall) throw new Error("LLM does not support tool calls");
}
// TODO: in LITS, llm should have a method to get tool calls from response
// then we don't need to use toolCalled flag
private getToolCallsFromResponse(
response: ChatResponse<object> | ChatResponseChunk<object>,
): ToolCall[] {
let options;
if ("message" in response) {
options = response.message.options;
} else {
options = response.options;
}
if (options && "toolCall" in options) {
return options.toolCall as ToolCall[];
}
return [];
}
}
@@ -0,0 +1,44 @@
import { WorkflowEvent } from "@llamaindex/core/workflow";
import {
createCallbacksTransformer,
createStreamDataTransformer,
StreamData,
trimStartOfStreamHelper,
type AIStreamCallbacksAndOptions,
} from "ai";
import { AgentRunResult } from "./type";
export function toDataStream(
generator: AsyncGenerator<WorkflowEvent, void>,
data: StreamData,
callbacks?: AIStreamCallbacksAndOptions,
) {
return toReadableStream(generator, data)
.pipeThrough(createCallbacksTransformer(callbacks))
.pipeThrough(createStreamDataTransformer());
}
function toReadableStream(
generator: AsyncGenerator<WorkflowEvent, void>,
data: StreamData,
) {
const trimStartOfStream = trimStartOfStreamHelper();
return new ReadableStream<string>({
start(controller) {
controller.enqueue(""); // Kickstart the stream
},
async pull(controller): Promise<void> {
const { value, done } = await generator.next();
if (done) return;
if (value.data instanceof AgentRunResult) {
const finalResultStream = value.data.response;
for await (const event of finalResultStream) {
const text = trimStartOfStream(event.delta ?? "");
if (text) controller.enqueue(text);
}
controller.close();
data.close();
}
},
});
}
@@ -0,0 +1,18 @@
import { WorkflowEvent } from "@llamaindex/core/workflow";
import { ChatResponseChunk } from "llamaindex";
export type AgentInput = {
message: string;
streaming?: boolean;
};
export class AgentRunEvent extends WorkflowEvent<{
name: string;
msg: string;
}> {}
export class AgentRunResult {
constructor(
public response: AsyncGenerator<ChatResponseChunk, any, unknown>,
) {}
}
@@ -1,179 +0,0 @@
import {
ALL_AVAILABLE_MISTRAL_MODELS,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Groq,
MistralAI,
MistralAIEmbedding,
MistralAIEmbeddingModelType,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "groq":
initGroq();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
case "mistral":
initMistralAI();
break;
case "azure-openai":
initAzureOpenAI();
break;
default:
initOpenAI();
break;
}
Settings.chunkSize = CHUNK_SIZE;
Settings.chunkOverlap = CHUNK_OVERLAP;
};
function initOpenAI() {
Settings.llm = new OpenAI({
model: process.env.MODEL ?? "gpt-4o-mini",
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
dimensions: process.env.EMBEDDING_DIM
? parseInt(process.env.EMBEDDING_DIM)
: undefined,
});
}
function initAzureOpenAI() {
// Map Azure OpenAI model names to OpenAI model names (only for TS)
const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
"gpt-35-turbo": "gpt-3.5-turbo",
"gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
"gpt-4o": "gpt-4o",
"gpt-4": "gpt-4",
"gpt-4-32k": "gpt-4-32k",
"gpt-4-turbo": "gpt-4-turbo",
"gpt-4-turbo-2024-04-09": "gpt-4-turbo",
"gpt-4-vision-preview": "gpt-4-vision-preview",
"gpt-4-1106-preview": "gpt-4-1106-preview",
"gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
};
const azureConfig = {
apiKey: process.env.AZURE_OPENAI_KEY,
endpoint: process.env.AZURE_OPENAI_ENDPOINT,
apiVersion:
process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
};
Settings.llm = new OpenAI({
model:
AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
"gpt-3.5-turbo",
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
azure: {
...azureConfig,
deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
},
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
dimensions: process.env.EMBEDDING_DIM
? parseInt(process.env.EMBEDDING_DIM)
: undefined,
azure: {
...azureConfig,
deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
},
});
}
function initOllama() {
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
function initGroq() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Groq({
model: process.env.MODEL!,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
function initMistralAI() {
Settings.llm = new MistralAI({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
});
Settings.embedModel = new MistralAIEmbedding({
model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
});
}