mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-05 00:46:20 -04:00
Compare commits
95 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a8073063c5 | |||
| aeb6fef4da | |||
| 64732f05aa | |||
| 588e0d607b | |||
| f2c3389168 | |||
| 5093b37c05 | |||
| f383f0cbe9 | |||
| b3c969dae5 | |||
| 628e16df7c | |||
| aa69014d04 | |||
| 293557cbb4 | |||
| b46d050fc3 | |||
| 02ed277dd0 | |||
| 48b96ff188 | |||
| 9c9decbb88 | |||
| 0748f2e8d7 | |||
| 3079162806 | |||
| 48c19c6e62 | |||
| d75c08e7d8 | |||
| 8f03f8d4bc | |||
| 19c57d945a | |||
| 9112d0801e | |||
| 93b797c162 | |||
| d53b760fd0 | |||
| a880c7c016 | |||
| 7b116ce7f7 | |||
| d1232fb1d5 | |||
| bedf199236 | |||
| c1510bd3fa | |||
| 69b9ce76bf | |||
| 9ced116e1a | |||
| fae9bcd65a | |||
| 2091fea2b4 | |||
| 563b51d76d | |||
| 88c88bf16d | |||
| cd6ebf7295 | |||
| 50b2ddbbf5 | |||
| 5fe2d519d2 | |||
| 09f1db3b5e | |||
| cb3be7d1d4 | |||
| 5474a1f182 | |||
| 1148ddba53 | |||
| 9e945ed355 | |||
| 6342163df2 | |||
| a42fa53a6b | |||
| 099f626586 | |||
| 956538eeb0 | |||
| 555f6b2905 | |||
| d8bc271a21 | |||
| f29561cde2 | |||
| 442abae8ac | |||
| 0ad2207684 | |||
| bfde30deed | |||
| 96fdb83abf | |||
| b7e0072c9c | |||
| 81bc340dda | |||
| ddf3aef7dc | |||
| 1f5a26f3a8 | |||
| 05748bdf10 | |||
| d60b3c5a96 | |||
| c3e9ed3df4 | |||
| 48188ca3f9 | |||
| 1fde1dc585 | |||
| cd50a33d43 | |||
| ed114856d9 | |||
| 69c2e16c82 | |||
| f5da6623cf | |||
| 0950cb90f2 | |||
| bb53425b4b | |||
| bbd5b8ddd6 | |||
| 260d37a3f1 | |||
| 7873bfb030 | |||
| 0c7c41ee3b | |||
| 56537a1473 | |||
| d8dfc29edd | |||
| 84db798353 | |||
| 67a062af14 | |||
| 0bc8e75c64 | |||
| 6bd5e7b77a | |||
| 38bc1d1350 | |||
| cb1001de95 | |||
| 78776ac51e | |||
| 416073db1d | |||
| 84929de8b2 | |||
| 6fe240b854 | |||
| 8bb1024d0f | |||
| 988bfc2a60 | |||
| 056e376ee0 | |||
| 819cccb11a | |||
| 8a5ece10c2 | |||
| 63bb0505d6 | |||
| 2e80ef47ee | |||
| a1feb524e9 | |||
| 06823da849 | |||
| 7bd3ed551f |
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use ingestion pipeline for Python
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Display events (e.g. retrieving nodes) per chat message
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
matrix:
|
||||
node-version: [18, 20]
|
||||
python-version: ["3.11"]
|
||||
os: [macos-latest, windows-latest]
|
||||
os: [macos-latest, windows-latest, ubuntu-22.04]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
|
||||
+107
@@ -1,5 +1,112 @@
|
||||
# create-llama
|
||||
|
||||
## 0.1.15
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 64732f0: Fix the issue of images not showing with the sandbox URL from OpenAI's models
|
||||
- aeb6fef: use llamacloud for chat
|
||||
|
||||
## 0.1.14
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- f2c3389: chore: update to llamaindex 0.4.3
|
||||
- 5093b37: Remove non-working file selectors for Linux
|
||||
|
||||
## 0.1.13
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b3c969d: Add image generator tool
|
||||
|
||||
## 0.1.12
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- aa69014: Fix NextJS for TS 5.2
|
||||
|
||||
## 0.1.11
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 48b96ff: Add DuckDuckGo search tool
|
||||
- 9c9decb: Reuse function tool instances and improve e2b interpreter tool for Python
|
||||
- 02ed277: Add Groq as a model provider
|
||||
- 0748f2e: Remove hard-coded Gemini supported models
|
||||
|
||||
## 0.1.10
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9112d08: Add OpenAPI tool for Typescript
|
||||
- 8f03f8d: Add OLLAMA_REQUEST_TIMEOUT variable to config Ollama timeout (Python)
|
||||
- 8f03f8d: Apply nest_asyncio for llama parse
|
||||
|
||||
## 0.1.9
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a42fa53: Add CSV upload
|
||||
- 563b51d: Fix Vercel streaming (python) to stream data events instantly
|
||||
- d60b3c5: Add E2B code interpreter tool for FastAPI
|
||||
- 956538e: Add OpenAPI action tool for FastAPI
|
||||
|
||||
## 0.1.8
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- cd50a33: Add interpreter tool for TS using e2b.dev
|
||||
|
||||
## 0.1.7
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 260d37a: Add system prompt env variable for TS
|
||||
- bbd5b8d: Fix postgres connection leaking issue
|
||||
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
|
||||
- 69c2e16: Fix streaming for Express
|
||||
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
|
||||
|
||||
## 0.1.6
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 56537a1: Display PDF files in source nodes
|
||||
|
||||
## 0.1.5
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 84db798: feat: support display latex in chat markdown
|
||||
|
||||
## 0.1.4
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 0bc8e75: Use ingestion pipeline for dedicated vector stores (Python only)
|
||||
- cb1001d: Add ChromaDB vector store
|
||||
|
||||
## 0.1.3
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 416073d: Directly import vector stores to work with NextJS
|
||||
|
||||
## 0.1.2
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 056e376: Add support for displaying tool outputs (including weather widget as example)
|
||||
|
||||
## 0.1.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 7bd3ed5: Support Anthropic and Gemini as model providers
|
||||
- 7bd3ed5: Support new agents from LITS 0.3
|
||||
- cfb5257: Display events (e.g. retrieving nodes) per chat message
|
||||
|
||||
## 0.1.0
|
||||
|
||||
### Minor Changes
|
||||
|
||||
@@ -151,5 +151,19 @@ export async function createApp({
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
dataSources.some((dataSource) => dataSource.type === "file") &&
|
||||
process.platform === "linux"
|
||||
) {
|
||||
console.log(
|
||||
yellow(
|
||||
`You can add your own data files to ${terminalLink(
|
||||
"data",
|
||||
`file://${root}/data`,
|
||||
)} folder manually.`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
console.log();
|
||||
}
|
||||
|
||||
+176
-33
@@ -1,5 +1,6 @@
|
||||
import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools";
|
||||
import {
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
@@ -7,7 +8,7 @@ import {
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
|
||||
type EnvVar = {
|
||||
export type EnvVar = {
|
||||
name?: string;
|
||||
description?: string;
|
||||
value?: string;
|
||||
@@ -29,17 +30,20 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
|
||||
);
|
||||
};
|
||||
|
||||
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
if (!vectorDb) {
|
||||
const getVectorDBEnvs = (
|
||||
vectorDb?: TemplateVectorDB,
|
||||
framework?: TemplateFramework,
|
||||
): EnvVar[] => {
|
||||
if (!vectorDb || !framework) {
|
||||
return [];
|
||||
}
|
||||
switch (vectorDb) {
|
||||
case "mongo":
|
||||
return [
|
||||
{
|
||||
name: "MONGO_URI",
|
||||
name: "MONGODB_URI",
|
||||
description:
|
||||
"For generating a connection URI, see https://docs.timescale.com/use-timescale/latest/services/create-a-service\nThe MongoDB connection URI.",
|
||||
"For generating a connection URI, see https://www.mongodb.com/docs/manual/reference/connection-string/ \nThe MongoDB connection URI.",
|
||||
},
|
||||
{
|
||||
name: "MONGODB_DATABASE",
|
||||
@@ -129,6 +133,51 @@ const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
"Optional API key for authenticating requests to Qdrant.",
|
||||
},
|
||||
];
|
||||
case "llamacloud":
|
||||
return [
|
||||
{
|
||||
name: "LLAMA_CLOUD_INDEX_NAME",
|
||||
description:
|
||||
"The name of the LlamaCloud index to use (part of the LlamaCloud project).",
|
||||
value: "test",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_PROJECT_NAME",
|
||||
description: "The name of the LlamaCloud project.",
|
||||
value: "Default",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_BASE_URL",
|
||||
description:
|
||||
"The base URL for the LlamaCloud API. Only change this for non-production environments",
|
||||
value: "https://api.cloud.llamaindex.ai",
|
||||
},
|
||||
];
|
||||
case "chroma":
|
||||
const envs = [
|
||||
{
|
||||
name: "CHROMA_COLLECTION",
|
||||
description: "The name of the collection in your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_HOST",
|
||||
description: "The API endpoint for your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_PORT",
|
||||
description: "The port for your Chroma database",
|
||||
},
|
||||
];
|
||||
// TS Version doesn't support config local storage path
|
||||
if (framework === "fastapi") {
|
||||
envs.push({
|
||||
name: "CHROMA_PATH",
|
||||
description: `The local path to the Chroma database.
|
||||
Specify this if you are using a local Chroma database.
|
||||
Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
|
||||
});
|
||||
}
|
||||
return envs;
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
@@ -156,6 +205,10 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
description: "Dimension of the embedding model to use.",
|
||||
value: modelConfig.dimensions.toString(),
|
||||
},
|
||||
{
|
||||
name: "CONVERSATION_STARTERS",
|
||||
description: "The questions to help users get started (multi-line).",
|
||||
},
|
||||
...(modelConfig.provider === "openai"
|
||||
? [
|
||||
{
|
||||
@@ -173,41 +226,79 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "anthropic"
|
||||
? [
|
||||
{
|
||||
name: "ANTHROPIC_API_KEY",
|
||||
description: "The Anthropic API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "groq"
|
||||
? [
|
||||
{
|
||||
name: "GROQ_API_KEY",
|
||||
description: "The Groq API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "gemini"
|
||||
? [
|
||||
{
|
||||
name: "GOOGLE_API_KEY",
|
||||
description: "The Google API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "ollama"
|
||||
? [
|
||||
{
|
||||
name: "OLLAMA_BASE_URL",
|
||||
description:
|
||||
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
};
|
||||
|
||||
const getFrameworkEnvs = (
|
||||
framework?: TemplateFramework,
|
||||
framework: TemplateFramework,
|
||||
port?: number,
|
||||
): EnvVar[] => {
|
||||
if (framework !== "fastapi") {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
const sPort = port?.toString() || "8000";
|
||||
const result: EnvVar[] = [
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: port?.toString() || "8000",
|
||||
},
|
||||
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: `Custom system prompt.
|
||||
Example:
|
||||
SYSTEM_PROMPT="
|
||||
We have provided context information below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given this information, please answer the question: {query_str}
|
||||
"`,
|
||||
name: "FILESERVER_URL_PREFIX",
|
||||
description:
|
||||
"FILESERVER_URL_PREFIX is the URL prefix of the server storing the images generated by the interpreter.",
|
||||
value:
|
||||
framework === "nextjs"
|
||||
? // FIXME: if we are using nextjs, port should be 3000
|
||||
"http://localhost:3000/api/files"
|
||||
: `http://localhost:${sPort}/api/files`,
|
||||
},
|
||||
];
|
||||
if (framework === "fastapi") {
|
||||
result.push(
|
||||
...[
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: sPort,
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
const getEngineEnvs = (): EnvVar[] => {
|
||||
@@ -218,18 +309,68 @@ const getEngineEnvs = (): EnvVar[] => {
|
||||
"The number of similar embeddings to return when retrieving documents.",
|
||||
value: "3",
|
||||
},
|
||||
{
|
||||
name: "STREAM_TIMEOUT",
|
||||
description:
|
||||
"The time in milliseconds to wait for the stream to return a response.",
|
||||
value: "60000",
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
|
||||
if (!tools?.length) return [];
|
||||
const toolEnvs: EnvVar[] = [];
|
||||
tools.forEach((tool) => {
|
||||
if (tool.envVars?.length) {
|
||||
toolEnvs.push(
|
||||
// Don't include the system prompt env var here
|
||||
// It should be handled separately by merging with the default system prompt
|
||||
...tool.envVars.filter(
|
||||
(env) => env.name !== TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
return toolEnvs;
|
||||
};
|
||||
|
||||
const getSystemPromptEnv = (tools?: Tool[]): EnvVar => {
|
||||
const defaultSystemPrompt =
|
||||
"You are a helpful assistant who helps users with their questions.";
|
||||
|
||||
// build tool system prompt by merging all tool system prompts
|
||||
let toolSystemPrompt = "";
|
||||
tools?.forEach((tool) => {
|
||||
const toolSystemPromptEnv = tool.envVars?.find(
|
||||
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
);
|
||||
if (toolSystemPromptEnv) {
|
||||
toolSystemPrompt += toolSystemPromptEnv.value + "\n";
|
||||
}
|
||||
});
|
||||
|
||||
const systemPrompt = toolSystemPrompt
|
||||
? `\"${toolSystemPrompt}\"`
|
||||
: defaultSystemPrompt;
|
||||
|
||||
return {
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: "The system prompt for the AI model.",
|
||||
value: systemPrompt,
|
||||
};
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
llamaCloudKey?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
modelConfig: ModelConfig;
|
||||
framework?: TemplateFramework;
|
||||
framework: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
port?: number;
|
||||
tools?: Tool[];
|
||||
},
|
||||
) => {
|
||||
// Init env values
|
||||
@@ -245,8 +386,10 @@ export const createBackendEnvFile = async (
|
||||
// Add engine environment variables
|
||||
...getEngineEnvs(),
|
||||
// Add vector database environment variables
|
||||
...getVectorDBEnvs(opts.vectorDb),
|
||||
...getVectorDBEnvs(opts.vectorDb, opts.framework),
|
||||
...getFrameworkEnvs(opts.framework, opts.port),
|
||||
...getToolEnvs(opts.tools),
|
||||
getSystemPromptEnv(opts.tools),
|
||||
];
|
||||
// Render and write env file
|
||||
const content = renderEnvVar(envVars);
|
||||
|
||||
+7
-2
@@ -9,7 +9,6 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
|
||||
import { PackageManager } from "./get-pkg-manager";
|
||||
import { installLlamapackProject } from "./llama-pack";
|
||||
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
|
||||
import { isModelConfigured } from "./providers";
|
||||
import { installPythonTemplate } from "./python";
|
||||
import { downloadAndExtractRepo } from "./repo";
|
||||
import { ConfigFileType, writeToolsConfig } from "./tools";
|
||||
@@ -38,7 +37,7 @@ async function generateContextData(
|
||||
? "poetry run generate"
|
||||
: `${packageManager} run generate`,
|
||||
)}`;
|
||||
const modelConfigured = isModelConfigured(modelConfig);
|
||||
const modelConfigured = modelConfig.isConfigured();
|
||||
const llamaCloudKeyConfigured = useLlamaParse
|
||||
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
@@ -149,6 +148,7 @@ export const installTemplate = async (
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
tools: props.tools,
|
||||
});
|
||||
|
||||
if (props.dataSources.length > 0) {
|
||||
@@ -171,6 +171,11 @@ export const installTemplate = async (
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool-output directory
|
||||
if (props.tools && props.tools.length > 0) {
|
||||
await fsExtra.mkdir(path.join(props.root, "tool-output"));
|
||||
}
|
||||
} else {
|
||||
// this is a frontend for a full-stack app, create .env file with model information
|
||||
await createFrontendEnvFile(props.root, {
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = [
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-2.1",
|
||||
"claude-instant-1.2",
|
||||
];
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
|
||||
// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported)
|
||||
// Use huggingface embedding models for now
|
||||
enum HuggingFaceEmbeddingModelType {
|
||||
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
|
||||
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
|
||||
}
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
|
||||
dimensions: 384,
|
||||
},
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
|
||||
dimensions: 768,
|
||||
},
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type AnthropicQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askAnthropicQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["ANTHROPIC_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions =
|
||||
EMBEDDING_MODELS[
|
||||
embeddingModel as HuggingFaceEmbeddingModelType
|
||||
].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"];
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<string, ModelData> = {
|
||||
"embedding-001": { dimensions: 768 },
|
||||
"text-embedding-004": { dimensions: 768 },
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type GeminiQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askGeminiQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["GOOGLE_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.GOOGLE_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["llama3-8b", "llama3-70b", "mixtral-8x7b"];
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
|
||||
// Use huggingface embedding models for now as Groq doesn't support embedding models
|
||||
enum HuggingFaceEmbeddingModelType {
|
||||
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
|
||||
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
|
||||
}
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
|
||||
dimensions: 384,
|
||||
},
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
|
||||
dimensions: 768,
|
||||
},
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type GroqQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askGroqQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: GroqQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["GROQ_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.GROQ_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions =
|
||||
EMBEDDING_MODELS[
|
||||
embeddingModel as HuggingFaceEmbeddingModelType
|
||||
].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
+16
-10
@@ -2,8 +2,11 @@ import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { questionHandlers } from "../../questions";
|
||||
import { ModelConfig, ModelProvider } from "../types";
|
||||
import { askAnthropicQuestions } from "./anthropic";
|
||||
import { askGeminiQuestions } from "./gemini";
|
||||
import { askGroqQuestions } from "./groq";
|
||||
import { askOllamaQuestions } from "./ollama";
|
||||
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
|
||||
import { askOpenAIQuestions } from "./openai";
|
||||
|
||||
const DEFAULT_MODEL_PROVIDER = "openai";
|
||||
|
||||
@@ -30,7 +33,10 @@ export async function askModelConfig({
|
||||
title: "OpenAI",
|
||||
value: "openai",
|
||||
},
|
||||
{ title: "Groq", value: "groq" },
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
{ title: "Anthropic", value: "anthropic" },
|
||||
{ title: "Gemini", value: "gemini" },
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
@@ -44,6 +50,15 @@ export async function askModelConfig({
|
||||
case "ollama":
|
||||
modelConfig = await askOllamaQuestions({ askModels });
|
||||
break;
|
||||
case "groq":
|
||||
modelConfig = await askGroqQuestions({ askModels });
|
||||
break;
|
||||
case "anthropic":
|
||||
modelConfig = await askAnthropicQuestions({ askModels });
|
||||
break;
|
||||
case "gemini":
|
||||
modelConfig = await askGeminiQuestions({ askModels });
|
||||
break;
|
||||
default:
|
||||
modelConfig = await askOpenAIQuestions({
|
||||
openAiKey,
|
||||
@@ -55,12 +70,3 @@ export async function askModelConfig({
|
||||
provider: modelProvider,
|
||||
};
|
||||
}
|
||||
|
||||
export function isModelConfigured(modelConfig: ModelConfig): boolean {
|
||||
switch (modelConfig.provider) {
|
||||
case "openai":
|
||||
return isOpenAIConfigured(modelConfig);
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,9 @@ export async function askOllamaQuestions({
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
|
||||
isConfigured(): boolean {
|
||||
return true;
|
||||
},
|
||||
};
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
|
||||
+10
-12
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4-turbo";
|
||||
const DEFAULT_MODEL = "gpt-3.5-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askOpenAIQuestions({
|
||||
@@ -20,6 +20,15 @@ export async function askOpenAIQuestions({
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["OPENAI_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
@@ -31,7 +40,6 @@ export async function askOpenAIQuestions({
|
||||
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
|
||||
: "Please provide your OpenAI API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
console.log(value);
|
||||
if (askModels && !value) {
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
return true;
|
||||
@@ -78,16 +86,6 @@ export async function askOpenAIQuestions({
|
||||
return config;
|
||||
}
|
||||
|
||||
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
|
||||
if (params.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["OPENAI_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async function getAvailableModelChoices(
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
/* Function to conditionally load the global-agent/bootstrap module */
|
||||
export async function initializeGlobalAgent() {
|
||||
if (process.env.GLOBAL_AGENT_HTTP_PROXY) {
|
||||
/* Dynamically import global-agent/bootstrap */
|
||||
await import("global-agent/bootstrap");
|
||||
console.log("Proxy enabled via global-agent.");
|
||||
}
|
||||
}
|
||||
+85
-34
@@ -24,7 +24,7 @@ interface Dependency {
|
||||
const getAdditionalDependencies = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
dataSource?: TemplateDataSource,
|
||||
dataSources?: TemplateDataSource[],
|
||||
tools?: Tool[],
|
||||
) => {
|
||||
const dependencies: Dependency[] = [];
|
||||
@@ -43,6 +43,7 @@ const getAdditionalDependencies = (
|
||||
name: "llama-index-vector-stores-postgres",
|
||||
version: "^0.1.1",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "pinecone": {
|
||||
dependencies.push({
|
||||
@@ -69,41 +70,66 @@ const getAdditionalDependencies = (
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "qdrant": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-qdrant",
|
||||
version: "^0.2.8",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "chroma": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-chroma",
|
||||
version: "^0.1.8",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add data source dependencies
|
||||
const dataSourceType = dataSource?.type;
|
||||
switch (dataSourceType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
if (dataSources) {
|
||||
for (const ds of dataSources) {
|
||||
const dsType = ds?.type;
|
||||
switch (dsType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
case "llamacloud":
|
||||
dependencies.push({
|
||||
name: "llama-index-indices-managed-llama-cloud",
|
||||
version: "^0.2.1",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tools dependencies
|
||||
console.log("Adding tools dependencies");
|
||||
tools?.forEach((tool) => {
|
||||
tool.dependencies?.forEach((dep) => {
|
||||
dependencies.push(dep);
|
||||
@@ -124,7 +150,27 @@ const getAdditionalDependencies = (
|
||||
case "openai":
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "0.2.2",
|
||||
version: "0.2.6",
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-anthropic",
|
||||
version: "0.1.10",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-huggingface",
|
||||
version: "0.2.0",
|
||||
});
|
||||
break;
|
||||
case "gemini":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-gemini",
|
||||
version: "0.1.10",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-gemini",
|
||||
version: "0.1.6",
|
||||
});
|
||||
break;
|
||||
}
|
||||
@@ -278,9 +324,14 @@ export const installPythonTemplate = async ({
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
});
|
||||
|
||||
const addOnDependencies = dataSources
|
||||
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
|
||||
.flat();
|
||||
console.log("Adding additional dependencies");
|
||||
|
||||
const addOnDependencies = getAdditionalDependencies(
|
||||
modelConfig,
|
||||
vectorDb,
|
||||
dataSources,
|
||||
tools,
|
||||
);
|
||||
|
||||
if (observability === "opentelemetry") {
|
||||
addOnDependencies.push({
|
||||
|
||||
+161
-5
@@ -2,15 +2,25 @@ import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { red } from "picocolors";
|
||||
import yaml from "yaml";
|
||||
import { EnvVar } from "./env-variables";
|
||||
import { makeDir } from "./make-dir";
|
||||
import { TemplateFramework } from "./types";
|
||||
|
||||
export const TOOL_SYSTEM_PROMPT_ENV_VAR = "TOOL_SYSTEM_PROMPT";
|
||||
|
||||
export enum ToolType {
|
||||
LLAMAHUB = "llamahub",
|
||||
LOCAL = "local",
|
||||
}
|
||||
|
||||
export type Tool = {
|
||||
display: string;
|
||||
name: string;
|
||||
config?: Record<string, any>;
|
||||
dependencies?: ToolDependencies[];
|
||||
supportedFrameworks?: Array<TemplateFramework>;
|
||||
type: ToolType;
|
||||
envVars?: EnvVar[];
|
||||
};
|
||||
|
||||
export type ToolDependencies = {
|
||||
@@ -20,7 +30,7 @@ export type ToolDependencies = {
|
||||
|
||||
export const supportedTools: Tool[] = [
|
||||
{
|
||||
display: "Google Search (configuration required after installation)",
|
||||
display: "Google Search",
|
||||
name: "google.GoogleSearchToolSpec",
|
||||
config: {
|
||||
engine:
|
||||
@@ -35,6 +45,37 @@ export const supportedTools: Tool[] = [
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for google search tool.",
|
||||
value: `You are a Google search agent. You help users to get information from Google search.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
// For python app, we will use a local DuckDuckGo search tool (instead of DuckDuckGo search tool in LlamaHub)
|
||||
// to get the same results as the TS app.
|
||||
display: "DuckDuckGo Search",
|
||||
name: "duckduckgo",
|
||||
dependencies: [
|
||||
{
|
||||
name: "duckduckgo-search",
|
||||
version: "6.1.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "nextjs", "express"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for DuckDuckGo search tool.",
|
||||
value: `You are a DuckDuckGo search agent.
|
||||
You can use the duckduckgo search tool to get information from the web to answer user questions.
|
||||
For better results, you can specify the region parameter to get results from a specific region but it's optional.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Wikipedia",
|
||||
@@ -46,6 +87,106 @@ export const supportedTools: Tool[] = [
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for wiki tool.",
|
||||
value: `You are a Wikipedia agent. You help users to get information from Wikipedia.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Weather",
|
||||
name: "weather",
|
||||
dependencies: [],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for weather tool.",
|
||||
value: `You are a weather forecast agent. You help users to get the weather forecast for a given location.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Code Interpreter",
|
||||
name: "interpreter",
|
||||
dependencies: [
|
||||
{
|
||||
name: "e2b_code_interpreter",
|
||||
version: "0.0.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "E2B_API_KEY",
|
||||
description:
|
||||
"E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for code interpreter tool.",
|
||||
value: `-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "OpenAPI action",
|
||||
name: "openapi_action.OpenAPIActionToolSpec",
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-openapi",
|
||||
version: "0.1.3",
|
||||
},
|
||||
{
|
||||
name: "jsonschema",
|
||||
version: "^4.22.0",
|
||||
},
|
||||
{
|
||||
name: "llama-index-tools-requests",
|
||||
version: "0.1.3",
|
||||
},
|
||||
],
|
||||
config: {
|
||||
openapi_uri: "The URL or file path of the OpenAPI schema",
|
||||
},
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for openapi action tool.",
|
||||
value:
|
||||
"You are an OpenAPI action agent. You help users to make requests to the provided OpenAPI schema.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Image Generator",
|
||||
name: "img_gen",
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "STABILITY_API_KEY",
|
||||
description:
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for image generator tool.",
|
||||
value: `You are an image generator agent. You help users to generate images using the Stability API.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -72,9 +213,15 @@ export const getTools = (toolsName: string[]): Tool[] => {
|
||||
return tools;
|
||||
};
|
||||
|
||||
export const toolRequiresConfig = (tool: Tool): boolean => {
|
||||
const hasConfig = Object.keys(tool.config || {}).length > 0;
|
||||
const hasEmptyEnvVar = tool.envVars?.some((envVar) => !envVar.value) ?? false;
|
||||
return hasConfig || hasEmptyEnvVar;
|
||||
};
|
||||
|
||||
export const toolsRequireConfig = (tools?: Tool[]): boolean => {
|
||||
if (tools) {
|
||||
return tools?.some((tool) => Object.keys(tool.config || {}).length > 0);
|
||||
return tools?.some(toolRequiresConfig);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -89,10 +236,19 @@ export const writeToolsConfig = async (
|
||||
tools: Tool[] = [],
|
||||
type: ConfigFileType = ConfigFileType.YAML,
|
||||
) => {
|
||||
if (tools.length === 0) return; // no tools selected, no config need
|
||||
const configContent: Record<string, any> = {};
|
||||
const configContent: {
|
||||
[key in ToolType]: Record<string, any>;
|
||||
} = {
|
||||
local: {},
|
||||
llamahub: {},
|
||||
};
|
||||
tools.forEach((tool) => {
|
||||
configContent[tool.name] = tool.config ?? {};
|
||||
if (tool.type === ToolType.LLAMAHUB) {
|
||||
configContent.llamahub[tool.name] = tool.config ?? {};
|
||||
}
|
||||
if (tool.type === ToolType.LOCAL) {
|
||||
configContent.local[tool.name] = tool.config ?? {};
|
||||
}
|
||||
});
|
||||
const configPath = path.join(root, "config");
|
||||
await makeDir(configPath);
|
||||
|
||||
+11
-3
@@ -1,13 +1,19 @@
|
||||
import { PackageManager } from "../helpers/get-pkg-manager";
|
||||
import { Tool } from "./tools";
|
||||
|
||||
export type ModelProvider = "openai" | "ollama";
|
||||
export type ModelProvider =
|
||||
| "openai"
|
||||
| "groq"
|
||||
| "ollama"
|
||||
| "anthropic"
|
||||
| "gemini";
|
||||
export type ModelConfig = {
|
||||
provider: ModelProvider;
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
embeddingModel: string;
|
||||
dimensions: number;
|
||||
isConfigured(): boolean;
|
||||
};
|
||||
export type TemplateType = "streaming" | "community" | "llamapack";
|
||||
export type TemplateFramework = "nextjs" | "express" | "fastapi";
|
||||
@@ -19,7 +25,9 @@ export type TemplateVectorDB =
|
||||
| "pinecone"
|
||||
| "milvus"
|
||||
| "astra"
|
||||
| "qdrant";
|
||||
| "qdrant"
|
||||
| "chroma"
|
||||
| "llamacloud";
|
||||
export type TemplatePostInstallAction =
|
||||
| "none"
|
||||
| "VSCode"
|
||||
@@ -29,7 +37,7 @@ export type TemplateDataSource = {
|
||||
type: TemplateDataSourceType;
|
||||
config: TemplateDataSourceConfig;
|
||||
};
|
||||
export type TemplateDataSourceType = "file" | "web" | "db";
|
||||
export type TemplateDataSourceType = "file" | "web" | "db" | "llamacloud";
|
||||
export type TemplateObservability = "none" | "opentelemetry";
|
||||
// Config for both file and folder
|
||||
export type FileSourceConfig = {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import fs from "fs/promises";
|
||||
import os from "os";
|
||||
import path from "path";
|
||||
import { bold, cyan } from "picocolors";
|
||||
import { bold, cyan, yellow } from "picocolors";
|
||||
import { assetRelocator, copy } from "../helpers/copy";
|
||||
import { callPackageManager } from "../helpers/install";
|
||||
import { templatesDir } from "./dir";
|
||||
@@ -105,7 +105,13 @@ export const installTSTemplate = async ({
|
||||
const enginePath = path.join(root, relativeEngineDestPath, "engine");
|
||||
|
||||
// copy vector db component
|
||||
console.log("\nUsing vector DB:", vectorDb, "\n");
|
||||
if (vectorDb === "llamacloud") {
|
||||
console.log(
|
||||
`\nUsing managed index from LlamaCloud. Ensure the ${yellow("LLAMA_CLOUD_* environment variables are set correctly.")}`,
|
||||
);
|
||||
} else {
|
||||
console.log("\nUsing vector DB:", vectorDb ?? "none");
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
|
||||
|
||||
@@ -12,12 +12,16 @@ import { createApp } from "./create-app";
|
||||
import { getDataSources } from "./helpers/datasources";
|
||||
import { getPkgManager } from "./helpers/get-pkg-manager";
|
||||
import { isFolderEmpty } from "./helpers/is-folder-empty";
|
||||
import { initializeGlobalAgent } from "./helpers/proxy";
|
||||
import { runApp } from "./helpers/run-app";
|
||||
import { getTools } from "./helpers/tools";
|
||||
import { validateNpmName } from "./helpers/validate-pkg";
|
||||
import packageJson from "./package.json";
|
||||
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
|
||||
|
||||
// Run the initialization function
|
||||
initializeGlobalAgent();
|
||||
|
||||
let projectPath: string = "";
|
||||
|
||||
const handleSigTerm = () => process.exit(0);
|
||||
|
||||
+2
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "create-llama",
|
||||
"version": "0.1.0",
|
||||
"version": "0.1.15",
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"keywords": [
|
||||
"rag",
|
||||
@@ -52,6 +52,7 @@
|
||||
"cross-spawn": "7.0.3",
|
||||
"fast-glob": "3.3.1",
|
||||
"fs-extra": "11.2.0",
|
||||
"global-agent": "^3.0.0",
|
||||
"got": "10.7.0",
|
||||
"ollama": "^0.5.0",
|
||||
"ora": "^8.0.1",
|
||||
|
||||
Generated
+265
-145
File diff suppressed because it is too large
Load Diff
+105
-55
@@ -14,9 +14,13 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
|
||||
import { EXAMPLE_FILE } from "./helpers/datasources";
|
||||
import { templatesDir } from "./helpers/dir";
|
||||
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
|
||||
import { askModelConfig, isModelConfigured } from "./helpers/providers";
|
||||
import { askModelConfig } from "./helpers/providers";
|
||||
import { getProjectOptions } from "./helpers/repo";
|
||||
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
|
||||
import {
|
||||
supportedTools,
|
||||
toolRequiresConfig,
|
||||
toolsRequireConfig,
|
||||
} from "./helpers/tools";
|
||||
|
||||
export type QuestionArgs = Omit<
|
||||
InstallAppArgs,
|
||||
@@ -97,6 +101,7 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
|
||||
{ title: "Milvus", value: "milvus" },
|
||||
{ title: "Astra", value: "astra" },
|
||||
{ title: "Qdrant", value: "qdrant" },
|
||||
{ title: "ChromaDB", value: "chroma" },
|
||||
];
|
||||
|
||||
const vectordbLang = framework === "fastapi" ? "python" : "typescript";
|
||||
@@ -118,7 +123,13 @@ export const getDataSourceChoices = (
|
||||
framework: TemplateFramework,
|
||||
selectedDataSource: TemplateDataSource[],
|
||||
) => {
|
||||
// If LlamaCloud is already selected, don't show any other options
|
||||
if (selectedDataSource.find((s) => s.type === "llamacloud")) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const choices = [];
|
||||
|
||||
if (selectedDataSource.length > 0) {
|
||||
choices.push({
|
||||
title: "No",
|
||||
@@ -131,24 +142,30 @@ export const getDataSourceChoices = (
|
||||
value: "none",
|
||||
});
|
||||
choices.push({
|
||||
title: "Use an example PDF",
|
||||
title:
|
||||
process.platform !== "linux"
|
||||
? "Use an example PDF"
|
||||
: "Use an example PDF (you can add your own data files later)",
|
||||
value: "exampleFile",
|
||||
});
|
||||
}
|
||||
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
// Linux has many distros so we won't support file/folder picker for now
|
||||
if (process.platform !== "linux") {
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
if (framework === "fastapi") {
|
||||
choices.push({
|
||||
@@ -160,6 +177,13 @@ export const getDataSourceChoices = (
|
||||
value: "db",
|
||||
});
|
||||
}
|
||||
|
||||
if (!selectedDataSource.length) {
|
||||
choices.push({
|
||||
title: "Use managed index from LlamaCloud",
|
||||
value: "llamacloud",
|
||||
});
|
||||
}
|
||||
return choices;
|
||||
};
|
||||
|
||||
@@ -257,7 +281,8 @@ export const askQuestions = async (
|
||||
},
|
||||
];
|
||||
|
||||
const modelConfigured = isModelConfigured(program.modelConfig);
|
||||
const modelConfigured =
|
||||
!program.llamapack && program.modelConfig.isConfigured();
|
||||
// If using LlamaParse, require LlamaCloud API key
|
||||
const llamaCloudKeyConfigured = program.useLlamaParse
|
||||
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
@@ -268,8 +293,7 @@ export const askQuestions = async (
|
||||
!hasVectorDb &&
|
||||
modelConfigured &&
|
||||
llamaCloudKeyConfigured &&
|
||||
!toolsRequireConfig(program.tools) &&
|
||||
!program.llamapack
|
||||
!toolsRequireConfig(program.tools)
|
||||
) {
|
||||
actionChoices.push({
|
||||
title:
|
||||
@@ -398,7 +422,6 @@ export const askQuestions = async (
|
||||
|
||||
if (program.framework === "express" || program.framework === "fastapi") {
|
||||
// if a backend-only framework is selected, ask whether we should create a frontend
|
||||
// (only for streaming backends)
|
||||
if (program.frontend === undefined) {
|
||||
if (ciInfo.isCI) {
|
||||
program.frontend = getPrefOrDefault("frontend");
|
||||
@@ -474,6 +497,11 @@ export const askQuestions = async (
|
||||
// continue asking user for data sources if none are initially provided
|
||||
while (true) {
|
||||
const firstQuestion = program.dataSources.length === 0;
|
||||
const choices = getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
);
|
||||
if (choices.length === 0) break;
|
||||
const { selectedSource } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
@@ -481,10 +509,7 @@ export const askQuestions = async (
|
||||
message: firstQuestion
|
||||
? "Which data source would you like to use?"
|
||||
: "Would you like to add another data source?",
|
||||
choices: getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
),
|
||||
choices,
|
||||
initial: firstQuestion ? 1 : 0,
|
||||
},
|
||||
questionHandlers,
|
||||
@@ -581,51 +606,76 @@ export const askQuestions = async (
|
||||
config: await prompts(dbPrompts, questionHandlers),
|
||||
});
|
||||
}
|
||||
case "llamacloud": {
|
||||
program.dataSources.push({
|
||||
type: "llamacloud",
|
||||
config: {},
|
||||
});
|
||||
program.dataSources.push(EXAMPLE_FILE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Asking for LlamaParse if user selected file or folder data source
|
||||
if (
|
||||
program.dataSources.some((ds) => ds.type === "file") &&
|
||||
program.useLlamaParse === undefined
|
||||
) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
const isUsingLlamaCloud = program.dataSources.some(
|
||||
(ds) => ds.type === "llamacloud",
|
||||
);
|
||||
|
||||
// Ask for LlamaCloud API key
|
||||
if (useLlamaParse && program.llamaCloudKey === undefined) {
|
||||
const { llamaCloudKey } = await prompts(
|
||||
// Asking for LlamaParse if user selected file data source
|
||||
if (isUsingLlamaCloud) {
|
||||
// default to use LlamaParse if using LlamaCloud
|
||||
program.useLlamaParse = preferences.useLlamaParse = true;
|
||||
} else {
|
||||
if (program.dataSources.some((ds) => ds.type === "file")) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "llamaCloudKey",
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Please provide your LlamaIndex Cloud API key (leave blank to skip):",
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamaCloudKey = llamaCloudKey;
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
preferences.useLlamaParse = useLlamaParse;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
// Ask for LlamaCloud API key when using a LlamaCloud index or LlamaParse
|
||||
if (isUsingLlamaCloud || program.useLlamaParse) {
|
||||
if (ciInfo.isCI) {
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
// Ask for LlamaCloud API key
|
||||
const { llamaCloudKey } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "llamaCloudKey",
|
||||
message:
|
||||
"Please provide your LlamaCloud API key (leave blank to skip):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamaCloudKey = preferences.llamaCloudKey =
|
||||
llamaCloudKey || process.env.LLAMA_CLOUD_API_KEY;
|
||||
}
|
||||
}
|
||||
|
||||
if (isUsingLlamaCloud) {
|
||||
// When using a LlamaCloud index, don't ask for vector database and use code in `llamacloud` folder for vector database
|
||||
const vectorDb = "llamacloud";
|
||||
program.vectorDb = vectorDb;
|
||||
preferences.vectorDb = vectorDb;
|
||||
} else if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
if (ciInfo.isCI) {
|
||||
program.vectorDb = getPrefOrDefault("vectorDb");
|
||||
} else {
|
||||
@@ -652,7 +702,7 @@ export const askQuestions = async (
|
||||
t.supportedFrameworks?.includes(program.framework),
|
||||
);
|
||||
const toolChoices = options.map((tool) => ({
|
||||
title: tool.display,
|
||||
title: `${tool.display}${toolRequiresConfig(tool) ? " (needs configuration)" : ""}`,
|
||||
value: tool.name,
|
||||
}));
|
||||
const { toolsName } = await prompts({
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
|
||||
try:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"llama_index.tools.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**kwargs)
|
||||
return tool_spec.to_tool_list()
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Unsupported tool: {tool_name}") from e
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
f"Could not create tool: {tool_name}. With config: {kwargs}"
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for name, config in tool_configs.items():
|
||||
tools += ToolFactory.create_tool(name, **config)
|
||||
return tools
|
||||
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolType:
|
||||
LLAMAHUB = "llamahub"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
TOOL_SOURCE_PACKAGE_MAP = {
|
||||
ToolType.LLAMAHUB: "llama_index.tools",
|
||||
ToolType.LOCAL: "app.engine.tools",
|
||||
}
|
||||
|
||||
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
|
||||
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
|
||||
try:
|
||||
if "ToolSpec" in tool_name:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"{source_package}.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**config)
|
||||
return tool_spec.to_tool_list()
|
||||
else:
|
||||
module = importlib.import_module(f"{source_package}.{tool_name}")
|
||||
tools = module.get_tools(**config)
|
||||
if not all(isinstance(tool, FunctionTool) for tool in tools):
|
||||
raise ValueError(
|
||||
f"The module {module} does not contain valid tools"
|
||||
)
|
||||
return tools
|
||||
except ImportError as e:
|
||||
raise ValueError(f"Failed to import tool {tool_name}: {e}")
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Failed to load tool {tool_name}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
return tools
|
||||
@@ -0,0 +1,36 @@
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
region: str = "wt-wt",
|
||||
max_results: int = 10,
|
||||
):
|
||||
"""
|
||||
Use this function to search for any query in DuckDuckGo.
|
||||
Args:
|
||||
query (str): The query to search in DuckDuckGo.
|
||||
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
|
||||
max_results Optional(int): The maximum number of results to be returned. Default is 10.
|
||||
"""
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"duckduckgo_search package is required to use this function."
|
||||
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
|
||||
)
|
||||
|
||||
params = {
|
||||
"keywords": query,
|
||||
"region": region,
|
||||
"max_results": max_results,
|
||||
}
|
||||
results = []
|
||||
with DDGS() as ddg:
|
||||
results = list(ddg.text(**params))
|
||||
return results
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(duckduckgo_search)]
|
||||
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorToolOutput(BaseModel):
|
||||
is_success: bool = Field(
|
||||
...,
|
||||
description="Whether the image generation was successful.",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="The URL of the generated image.",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
None,
|
||||
description="The error message if the image generation failed.",
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneratorTool:
|
||||
_IMG_OUTPUT_FORMAT = "webp"
|
||||
_IMG_OUTPUT_DIR = "tool-output"
|
||||
_IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if not api_key:
|
||||
api_key = os.getenv("STABILITY_API_KEY")
|
||||
self._api_key = api_key
|
||||
self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if self._api_key is None:
|
||||
raise ValueError(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
|
||||
)
|
||||
if self.fileserver_url_prefix is None:
|
||||
raise ValueError("FILESERVER_URL_PREFIX is required.")
|
||||
|
||||
def _prepare_output_dir(self):
|
||||
"""
|
||||
Create the output directory if it doesn't exist
|
||||
"""
|
||||
if not os.path.exists(self._IMG_OUTPUT_DIR):
|
||||
os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
def _save_image(self, image_data: bytes):
|
||||
self._prepare_output_dir()
|
||||
filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
|
||||
output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{self._IMG_OUTPUT_DIR}/{filename}"
|
||||
logger.info(f"Saved image to {output_path}.\nURL: {url}")
|
||||
return url
|
||||
|
||||
def _call_stability_api(self, prompt: str):
|
||||
headers = {
|
||||
"authorization": f"Bearer {self._api_key}",
|
||||
"accept": "image/*",
|
||||
}
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"output_format": self._IMG_OUTPUT_FORMAT,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self._IMG_GEN_API,
|
||||
headers=headers,
|
||||
files={"none": ""},
|
||||
data=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
|
||||
"""
|
||||
Use this tool to generate an image based on the prompt.
|
||||
Args:
|
||||
prompt (str): The prompt to generate the image from.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Call the Stability API
|
||||
response = self._call_stability_api(prompt)
|
||||
|
||||
# Save the image and get the URL
|
||||
image_url = self._save_image(response.content)
|
||||
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=True,
|
||||
image_url=image_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e, exc_info=True)
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(ImageGeneratorTool(**kwargs).generate_image)]
|
||||
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
|
||||
type: str
|
||||
content: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
|
||||
is_error: bool
|
||||
logs: Logs
|
||||
results: List[InterpreterExtraResult] = []
|
||||
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
|
||||
output_dir = "tool-output"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if api_key is None:
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
self.interpreter = CodeInterpreter(api_key=api_key)
|
||||
|
||||
def __del__(self):
|
||||
self.interpreter.close()
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
return os.path.join(self.output_dir, filename)
|
||||
|
||||
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
|
||||
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
|
||||
buffer = base64.b64decode(base64_data)
|
||||
output_path = self.get_output_path(filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(buffer)
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write to file {output_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info(f"Saved file to {output_path}")
|
||||
|
||||
return {
|
||||
"outputPath": output_path,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
def get_file_url(self, filename: str) -> str:
|
||||
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
|
||||
|
||||
def parse_result(self, result) -> List[InterpreterExtraResult]:
|
||||
"""
|
||||
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
|
||||
We save each result to disk and return saved file metadata (extension, filename, url)
|
||||
"""
|
||||
if not result:
|
||||
return []
|
||||
|
||||
output = []
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
results = [result[format] for format in formats]
|
||||
|
||||
for ext, data in zip(formats, results):
|
||||
match ext:
|
||||
case "png" | "svg" | "jpeg" | "pdf":
|
||||
result = self.save_to_disk(data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
filename=filename,
|
||||
url=self.get_file_url(filename),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
content=data,
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(error, exc_info=True)
|
||||
logger.error("Error when parsing output from E2b interpreter tool", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
|
||||
|
||||
Parameters:
|
||||
code (str): The python code to be executed in a single cell.
|
||||
"""
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = self.interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
logger.error("Error when executing code", exec.error)
|
||||
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
|
||||
return output
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)]
|
||||
@@ -0,0 +1,78 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||
from llama_index.tools.requests import RequestsToolSpec
|
||||
|
||||
|
||||
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
|
||||
"""
|
||||
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
|
||||
|
||||
openapi_uri: str: The file path or URL to the OpenAPI spec.
|
||||
domain_headers: dict: Whitelist domains and the headers to use.
|
||||
"""
|
||||
|
||||
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
|
||||
# Cached parsed specs by URI
|
||||
_specs: Dict[str, Tuple[Dict, List[str]]] = {}
|
||||
|
||||
def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
|
||||
if domain_headers is None:
|
||||
domain_headers = {}
|
||||
if openapi_uri not in self._specs:
|
||||
openapi_spec, servers = self._load_openapi_spec(openapi_uri)
|
||||
self._specs[openapi_uri] = (openapi_spec, servers)
|
||||
else:
|
||||
openapi_spec, servers = self._specs[openapi_uri]
|
||||
|
||||
# Add the servers to the domain headers if they are not already present
|
||||
for server in servers:
|
||||
if server not in domain_headers:
|
||||
domain_headers[server] = {}
|
||||
|
||||
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
|
||||
RequestsToolSpec.__init__(self, domain_headers)
|
||||
|
||||
@staticmethod
|
||||
def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
|
||||
"""
|
||||
Load an OpenAPI spec from a URI.
|
||||
|
||||
Args:
|
||||
uri (str): A file path or URL to the OpenAPI spec.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
import yaml
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if uri.startswith("http"):
|
||||
import requests
|
||||
|
||||
response = requests.get(uri)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: "
|
||||
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
|
||||
)
|
||||
spec = yaml.safe_load(response.text)
|
||||
elif uri.startswith("file"):
|
||||
filepath = urlparse(uri).path
|
||||
with open(filepath, "r") as file:
|
||||
spec = yaml.safe_load(file)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
|
||||
"Only HTTP and file path are supported."
|
||||
)
|
||||
# Add the servers to the whitelist
|
||||
try:
|
||||
servers = [
|
||||
urlparse(server["url"]).netloc for server in spec.get("servers", [])
|
||||
]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
|
||||
"Could not get `servers` from the spec."
|
||||
) from e
|
||||
return spec, servers
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Open Meteo weather map tool spec."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import pytz
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenMeteoWeather:
|
||||
geo_api = "https://geocoding-api.open-meteo.com/v1"
|
||||
weather_api = "https://api.open-meteo.com/v1"
|
||||
|
||||
@classmethod
|
||||
def _get_geo_location(cls, location: str) -> dict:
|
||||
"""Get geo location from location name."""
|
||||
params = {"name": location, "count": 10, "language": "en", "format": "json"}
|
||||
response = requests.get(f"{cls.geo_api}/search", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch geo location: {response.status_code}")
|
||||
else:
|
||||
data = response.json()
|
||||
result = data["results"][0]
|
||||
geo_location = {
|
||||
"id": result["id"],
|
||||
"name": result["name"],
|
||||
"latitude": result["latitude"],
|
||||
"longitude": result["longitude"],
|
||||
}
|
||||
return geo_location
|
||||
|
||||
@classmethod
|
||||
def get_weather_information(cls, location: str) -> dict:
|
||||
"""Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
"""
|
||||
logger.info(
|
||||
f"Calling open-meteo api to get weather information of location: {location}"
|
||||
)
|
||||
geo_location = cls._get_geo_location(location)
|
||||
timezone = pytz.timezone("UTC").zone
|
||||
params = {
|
||||
"latitude": geo_location["latitude"],
|
||||
"longitude": geo_location["longitude"],
|
||||
"current": "temperature_2m,weather_code",
|
||||
"hourly": "temperature_2m,weather_code",
|
||||
"daily": "weather_code",
|
||||
"timezone": timezone,
|
||||
}
|
||||
response = requests.get(f"{cls.weather_api}/forecast", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch weather information: {response.status_code}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
@@ -1,12 +1,11 @@
|
||||
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { getDataSource } from "./index";
|
||||
import { STORAGE_CACHE_DIR } from "./shared";
|
||||
import { createTools } from "./tools";
|
||||
|
||||
export async function createChatEngine() {
|
||||
let tools: BaseTool[] = [];
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
// Add a query engine tool if we have a data source
|
||||
// Delete this code if you don't have a data source
|
||||
@@ -17,21 +16,26 @@ export async function createChatEngine() {
|
||||
queryEngine: index.asQueryEngine(),
|
||||
metadata: {
|
||||
name: "data_query_engine",
|
||||
description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
|
||||
description: `A query engine for documents from your data source.`,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
const configFile = path.join("config", "tools.json");
|
||||
let toolConfig: any;
|
||||
try {
|
||||
// add tools from config file if it exists
|
||||
const config = JSON.parse(
|
||||
await fs.readFile(path.join("config", "tools.json"), "utf8"),
|
||||
);
|
||||
tools = tools.concat(await ToolsFactory.createTools(config));
|
||||
} catch {}
|
||||
toolConfig = JSON.parse(await fs.readFile(configFile, "utf8"));
|
||||
} catch (e) {
|
||||
console.info(`Could not read ${configFile} file. Using no tools.`);
|
||||
}
|
||||
if (toolConfig) {
|
||||
tools.push(...(await createTools(toolConfig)));
|
||||
}
|
||||
|
||||
return new OpenAIAgent({
|
||||
tools,
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import { search } from "duck-duck-scrape";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
export type DuckDuckGoParameter = {
|
||||
query: string;
|
||||
region?: string;
|
||||
};
|
||||
|
||||
export type DuckDuckGoToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = {
|
||||
name: "duckduckgo",
|
||||
description: "Use this function to search for any query in DuckDuckGo.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
query: {
|
||||
type: "string",
|
||||
description: "The query to search in DuckDuckGo.",
|
||||
},
|
||||
region: {
|
||||
type: "string",
|
||||
description:
|
||||
"Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...",
|
||||
nullable: true,
|
||||
},
|
||||
},
|
||||
required: ["query"],
|
||||
},
|
||||
};
|
||||
|
||||
type DuckDuckGoSearchResult = {
|
||||
title: string;
|
||||
description: string;
|
||||
url: string;
|
||||
};
|
||||
|
||||
export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
|
||||
constructor(params: DuckDuckGoToolParams) {
|
||||
this.metadata = params.metadata ?? DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: DuckDuckGoParameter) {
|
||||
const { query, region } = input;
|
||||
const options = region ? { region } : {};
|
||||
const searchResults = await search(query, options);
|
||||
|
||||
return searchResults.results.map((result) => {
|
||||
return {
|
||||
title: result.title,
|
||||
description: result.description,
|
||||
url: result.url,
|
||||
} as DuckDuckGoSearchResult;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { FormData } from "formdata-node";
|
||||
import fs from "fs";
|
||||
import got from "got";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import path from "node:path";
|
||||
import { Readable } from "stream";
|
||||
|
||||
export type ImgGeneratorParameter = {
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolOutput = {
|
||||
isSuccess: boolean;
|
||||
imageUrl?: string;
|
||||
errorMessage?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = {
|
||||
name: "image_generator",
|
||||
description: `Use this function to generate an image based on the prompt.`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
prompt: {
|
||||
type: "string",
|
||||
description: "The prompt to generate the image",
|
||||
},
|
||||
},
|
||||
required: ["prompt"],
|
||||
},
|
||||
};
|
||||
|
||||
export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> {
|
||||
readonly IMG_OUTPUT_FORMAT = "webp";
|
||||
readonly IMG_OUTPUT_DIR = "tool-output";
|
||||
readonly IMG_GEN_API =
|
||||
"https://api.stability.ai/v2beta/stable-image/generate/core";
|
||||
|
||||
metadata: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
|
||||
constructor(params?: ImgGeneratorToolParams) {
|
||||
this.checkRequiredEnvVars();
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: ImgGeneratorParameter): Promise<ImgGeneratorToolOutput> {
|
||||
return await this.generateImage(input.prompt);
|
||||
}
|
||||
|
||||
private generateImage = async (
|
||||
prompt: string,
|
||||
): Promise<ImgGeneratorToolOutput> => {
|
||||
try {
|
||||
const buffer = await this.promptToImgBuffer(prompt);
|
||||
const imageUrl = this.saveImage(buffer);
|
||||
return { isSuccess: true, imageUrl };
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return {
|
||||
isSuccess: false,
|
||||
errorMessage: "Failed to generate image. Please try again.",
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
private promptToImgBuffer = async (prompt: string) => {
|
||||
const form = new FormData();
|
||||
form.append("prompt", prompt);
|
||||
form.append("output_format", this.IMG_OUTPUT_FORMAT);
|
||||
const buffer = await got
|
||||
.post(this.IMG_GEN_API, {
|
||||
// Not sure why it shows an type error when passing form to body
|
||||
// Although I follow document: https://github.com/sindresorhus/got/blob/main/documentation/2-options.md#body
|
||||
// Tt still works fine, so I make casting to unknown to avoid the typescript warning
|
||||
// Found a similar issue: https://github.com/sindresorhus/got/discussions/1877
|
||||
body: form as unknown as Buffer | Readable | string,
|
||||
headers: {
|
||||
Authorization: `Bearer ${process.env.STABILITY_API_KEY}`,
|
||||
Accept: "image/*",
|
||||
},
|
||||
})
|
||||
.buffer();
|
||||
return buffer;
|
||||
};
|
||||
|
||||
private saveImage = (buffer: Buffer) => {
|
||||
const filename = `${crypto.randomUUID()}.${this.IMG_OUTPUT_FORMAT}`;
|
||||
const outputPath = path.join(this.IMG_OUTPUT_DIR, filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
const url = `${process.env.FILESERVER_URL_PREFIX}/${this.IMG_OUTPUT_DIR}/${filename}`;
|
||||
console.log(`Saved image to ${outputPath}.\nURL: ${url}`);
|
||||
return url;
|
||||
};
|
||||
|
||||
private checkRequiredEnvVars = () => {
|
||||
if (!process.env.STABILITY_API_KEY) {
|
||||
throw new Error(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
);
|
||||
}
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output after generation",
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
import { BaseToolWithCall } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
|
||||
import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen";
|
||||
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
|
||||
import { OpenAPIActionTool } from "./openapi-action";
|
||||
import { WeatherTool, WeatherToolParams } from "./weather";
|
||||
|
||||
type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
|
||||
|
||||
export async function createTools(toolConfig: {
|
||||
local: Record<string, unknown>;
|
||||
llamahub: any;
|
||||
}): Promise<BaseToolWithCall[]> {
|
||||
// add local tools from the 'tools' folder (if configured)
|
||||
const tools = await createLocalTools(toolConfig.local);
|
||||
// add tools from LlamaIndexTS (if configured)
|
||||
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
|
||||
return tools;
|
||||
}
|
||||
|
||||
const toolFactory: Record<string, ToolCreator> = {
|
||||
weather: async (config: unknown) => {
|
||||
return [new WeatherTool(config as WeatherToolParams)];
|
||||
},
|
||||
interpreter: async (config: unknown) => {
|
||||
return [new InterpreterTool(config as InterpreterToolParams)];
|
||||
},
|
||||
"openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
|
||||
const { openapi_uri, domain_headers } = config as {
|
||||
openapi_uri: string;
|
||||
domain_headers: Record<string, Record<string, string>>;
|
||||
};
|
||||
const openAPIActionTool = new OpenAPIActionTool(
|
||||
openapi_uri,
|
||||
domain_headers,
|
||||
);
|
||||
return await openAPIActionTool.toToolFunctions();
|
||||
},
|
||||
duckduckgo: async (config: unknown) => {
|
||||
return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)];
|
||||
},
|
||||
img_gen: async (config: unknown) => {
|
||||
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
|
||||
},
|
||||
};
|
||||
|
||||
async function createLocalTools(
|
||||
localConfig: Record<string, unknown>,
|
||||
): Promise<BaseToolWithCall[]> {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
for (const [key, toolConfig] of Object.entries(localConfig)) {
|
||||
if (key in toolFactory) {
|
||||
const newTools = await toolFactory[key](toolConfig);
|
||||
tools.push(...newTools);
|
||||
}
|
||||
}
|
||||
|
||||
return tools;
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
import { CodeInterpreter, Logs, Result } from "@e2b/code-interpreter";
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import fs from "fs";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import crypto from "node:crypto";
|
||||
import path from "node:path";
|
||||
|
||||
export type InterpreterParameter = {
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
apiKey?: string;
|
||||
fileServerURLPrefix?: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolOutput = {
|
||||
isError: boolean;
|
||||
logs: Logs;
|
||||
extraResult: InterpreterExtraResult[];
|
||||
};
|
||||
|
||||
type InterpreterExtraType =
|
||||
| "html"
|
||||
| "markdown"
|
||||
| "svg"
|
||||
| "png"
|
||||
| "jpeg"
|
||||
| "pdf"
|
||||
| "latex"
|
||||
| "json"
|
||||
| "javascript";
|
||||
|
||||
export type InterpreterExtraResult = {
|
||||
type: InterpreterExtraType;
|
||||
content?: string;
|
||||
filename?: string;
|
||||
url?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
|
||||
name: "interpreter",
|
||||
description:
|
||||
"Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
code: {
|
||||
type: "string",
|
||||
description: "The python code to execute in a single cell.",
|
||||
},
|
||||
},
|
||||
required: ["code"],
|
||||
},
|
||||
};
|
||||
|
||||
export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
private readonly outputDir = "tool-output";
|
||||
private apiKey?: string;
|
||||
private fileServerURLPrefix?: string;
|
||||
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
codeInterpreter?: CodeInterpreter;
|
||||
|
||||
constructor(params?: InterpreterToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
this.apiKey = params?.apiKey || process.env.E2B_API_KEY;
|
||||
this.fileServerURLPrefix =
|
||||
params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX;
|
||||
|
||||
if (!this.apiKey) {
|
||||
throw new Error(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
);
|
||||
}
|
||||
if (!this.fileServerURLPrefix) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public async initInterpreter() {
|
||||
if (!this.codeInterpreter) {
|
||||
this.codeInterpreter = await CodeInterpreter.create({
|
||||
apiKey: this.apiKey,
|
||||
});
|
||||
}
|
||||
return this.codeInterpreter;
|
||||
}
|
||||
|
||||
public async codeInterpret(code: string): Promise<InterpreterToolOutput> {
|
||||
console.log(
|
||||
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
|
||||
);
|
||||
const interpreter = await this.initInterpreter();
|
||||
const exec = await interpreter.notebook.execCell(code);
|
||||
if (exec.error) console.error("[Code Interpreter error]", exec.error);
|
||||
const extraResult = await this.getExtraResult(exec.results[0]);
|
||||
const result: InterpreterToolOutput = {
|
||||
isError: !!exec.error,
|
||||
logs: exec.logs,
|
||||
extraResult,
|
||||
};
|
||||
return result;
|
||||
}
|
||||
|
||||
async call(input: InterpreterParameter): Promise<InterpreterToolOutput> {
|
||||
const result = await this.codeInterpret(input.code);
|
||||
return result;
|
||||
}
|
||||
|
||||
async close() {
|
||||
await this.codeInterpreter?.close();
|
||||
}
|
||||
|
||||
private async getExtraResult(
|
||||
res?: Result,
|
||||
): Promise<InterpreterExtraResult[]> {
|
||||
if (!res) return [];
|
||||
const output: InterpreterExtraResult[] = [];
|
||||
|
||||
try {
|
||||
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
|
||||
const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
|
||||
|
||||
// save base64 data to file and return the url
|
||||
for (let i = 0; i < formats.length; i++) {
|
||||
const ext = formats[i];
|
||||
const data = results[i];
|
||||
switch (ext) {
|
||||
case "png":
|
||||
case "jpeg":
|
||||
case "svg":
|
||||
case "pdf":
|
||||
const { filename } = this.saveToDisk(data, ext);
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
filename,
|
||||
url: this.getFileUrl(filename),
|
||||
});
|
||||
break;
|
||||
default:
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
content: data,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error when parsing e2b response", error);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Consider saving to cloud storage instead but it may cost more for you
|
||||
// See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file
|
||||
private saveToDisk(
|
||||
base64Data: string,
|
||||
ext: string,
|
||||
): {
|
||||
outputPath: string;
|
||||
filename: string;
|
||||
} {
|
||||
const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename
|
||||
const buffer = Buffer.from(base64Data, "base64");
|
||||
const outputPath = this.getOutputPath(filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
console.log(`Saved file to ${outputPath}`);
|
||||
return {
|
||||
outputPath,
|
||||
filename,
|
||||
};
|
||||
}
|
||||
|
||||
private getOutputPath(filename: string): string {
|
||||
// if outputDir doesn't exist, create it
|
||||
if (!fs.existsSync(this.outputDir)) {
|
||||
fs.mkdirSync(this.outputDir, { recursive: true });
|
||||
}
|
||||
return path.join(this.outputDir, filename);
|
||||
}
|
||||
|
||||
private getFileUrl(filename: string): string {
|
||||
return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
import SwaggerParser from "@apidevtools/swagger-parser";
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import got from "got";
|
||||
import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface DomainHeaders {
|
||||
[key: string]: { [header: string]: string };
|
||||
}
|
||||
|
||||
type Input = {
|
||||
url: string;
|
||||
params: object;
|
||||
};
|
||||
|
||||
type APIInfo = {
|
||||
description: string;
|
||||
title: string;
|
||||
};
|
||||
|
||||
export class OpenAPIActionTool {
|
||||
// cache the loaded specs by URL
|
||||
private static specs: Record<string, any> = {};
|
||||
|
||||
private readonly INVALID_URL_PROMPT =
|
||||
"This url did not include a hostname or scheme. Please determine the complete URL and try again.";
|
||||
|
||||
private createLoadSpecMetaData = (info: APIInfo) => {
|
||||
return {
|
||||
name: "load_openapi_spec",
|
||||
description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
|
||||
};
|
||||
};
|
||||
|
||||
private readonly createMethodCallMetaData = (
|
||||
method: "POST" | "PATCH" | "GET",
|
||||
info: APIInfo,
|
||||
) => {
|
||||
return {
|
||||
name: `${method.toLowerCase()}_request`,
|
||||
description: `Use this to call the ${method} method on the API named ${info.title}`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
url: {
|
||||
type: "string",
|
||||
description: `The url to make the ${method} request against`,
|
||||
},
|
||||
params: {
|
||||
type: "object",
|
||||
description:
|
||||
method === "GET"
|
||||
? "the URL parameters to provide with the get request"
|
||||
: `the key-value pairs to provide with the ${method} request`,
|
||||
},
|
||||
},
|
||||
required: ["url"],
|
||||
},
|
||||
} as ToolMetadata<JSONSchemaType<Input>>;
|
||||
};
|
||||
|
||||
constructor(
|
||||
public openapi_uri: string,
|
||||
public domainHeaders: DomainHeaders = {},
|
||||
) {}
|
||||
|
||||
async loadOpenapiSpec(url: string): Promise<any> {
|
||||
const api = await SwaggerParser.validate(url);
|
||||
return {
|
||||
servers: "servers" in api ? api.servers : "",
|
||||
info: { description: api.info.description, title: api.info.title },
|
||||
endpoints: api.paths,
|
||||
};
|
||||
}
|
||||
|
||||
async getRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const data = await got
|
||||
.get(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
searchParams: input.params as URLSearchParams,
|
||||
})
|
||||
.json();
|
||||
return data as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async postRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.post(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async patchRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.patch(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
public async toToolFunctions() {
|
||||
if (!OpenAPIActionTool.specs[this.openapi_uri]) {
|
||||
console.log(`Loading spec for URL: ${this.openapi_uri}`);
|
||||
const spec = await this.loadOpenapiSpec(this.openapi_uri);
|
||||
OpenAPIActionTool.specs[this.openapi_uri] = spec;
|
||||
}
|
||||
const spec = OpenAPIActionTool.specs[this.openapi_uri];
|
||||
// TODO: read endpoints with parameters from spec and create one tool for each endpoint
|
||||
// For now, we just create a tool for each HTTP method which does not work well for passing parameters
|
||||
return [
|
||||
FunctionTool.from(() => {
|
||||
return spec;
|
||||
}, this.createLoadSpecMetaData(spec.info)),
|
||||
FunctionTool.from(
|
||||
this.getRequest.bind(this),
|
||||
this.createMethodCallMetaData("GET", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.postRequest.bind(this),
|
||||
this.createMethodCallMetaData("POST", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.patchRequest.bind(this),
|
||||
this.createMethodCallMetaData("PATCH", spec.info),
|
||||
),
|
||||
];
|
||||
}
|
||||
|
||||
private validUrl(url: string): boolean {
|
||||
const parsed = new URL(url);
|
||||
return !!parsed.protocol && !!parsed.hostname;
|
||||
}
|
||||
|
||||
private getDomain(url: string): string {
|
||||
const parsed = new URL(url);
|
||||
return parsed.hostname;
|
||||
}
|
||||
|
||||
private getHeadersForUrl(url: string): { [header: string]: string } {
|
||||
const domain = this.getDomain(url);
|
||||
return this.domainHeaders[domain] || {};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface GeoLocation {
|
||||
id: string;
|
||||
name: string;
|
||||
latitude: number;
|
||||
longitude: number;
|
||||
}
|
||||
|
||||
export type WeatherParameter = {
|
||||
location: string;
|
||||
};
|
||||
|
||||
export type WeatherToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
|
||||
name: "get_weather_information",
|
||||
description: `
|
||||
Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The location to get the weather information",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
};
|
||||
|
||||
export class WeatherTool implements BaseTool<WeatherParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
|
||||
private getGeoLocation = async (location: string): Promise<GeoLocation> => {
|
||||
const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
const { id, name, latitude, longitude } = data.results[0];
|
||||
return { id, name, latitude, longitude };
|
||||
};
|
||||
|
||||
private getWeatherByLocation = async (location: string) => {
|
||||
console.log(
|
||||
"Calling open-meteo api to get weather information of location:",
|
||||
location,
|
||||
);
|
||||
const { latitude, longitude } = await this.getGeoLocation(location);
|
||||
const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}¤t=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
constructor(params?: WeatherToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: WeatherParameter) {
|
||||
return await this.getWeatherByLocation(input.location);
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,14 @@ export async function createChatEngine() {
|
||||
`StorageContext is empty - call 'npm run generate' to generate the storage first`,
|
||||
);
|
||||
}
|
||||
const retriever = index.asRetriever();
|
||||
retriever.similarityTopK = process.env.TOP_K
|
||||
? parseInt(process.env.TOP_K)
|
||||
: 3;
|
||||
const retriever = index.asRetriever({
|
||||
similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : 3,
|
||||
});
|
||||
|
||||
return new ContextChatEngine({
|
||||
chatModel: Settings.llm,
|
||||
retriever,
|
||||
// disable as a custom system prompt disables the generated context
|
||||
// systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
|
||||
data_dir: str = "data"
|
||||
@@ -20,15 +23,44 @@ def llama_parse_parser():
|
||||
"LLAMA_CLOUD_API_KEY environment variable is not set. "
|
||||
"Please set it in .env file or in your shell environment then run again!"
|
||||
)
|
||||
parser = LlamaParse(result_type="markdown", verbose=True, language="en")
|
||||
parser = LlamaParse(
|
||||
result_type="markdown",
|
||||
verbose=True,
|
||||
language="en",
|
||||
ignore_errors=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
try:
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir, recursive=True, filename_as_id=True, raise_on_error=True
|
||||
)
|
||||
if config.use_llama_parse:
|
||||
# LlamaParse is async first,
|
||||
# so we need to use nest_asyncio to run it in sync mode
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
except Exception as e:
|
||||
import sys, traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
_, _, exc_traceback = sys.exc_info()
|
||||
function_name = traceback.extract_tb(exc_traceback)[-1].name
|
||||
if function_name == "_add_files":
|
||||
logger.warning(
|
||||
f"Failed to load file documents, error message: {e} . Return as empty document list."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
# Raise the error if it is not the case of empty data dir
|
||||
raise e
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Message } from "./chat-messages";
|
||||
|
||||
export interface ChatInputProps {
|
||||
/** The current value of the input */
|
||||
input?: string;
|
||||
@@ -12,7 +14,8 @@ export interface ChatInputProps {
|
||||
/** Form submission handler to automatically reset input and append a user message */
|
||||
handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
|
||||
isLoading: boolean;
|
||||
multiModal?: boolean;
|
||||
messages: Message[];
|
||||
setInput?: (input: string) => void;
|
||||
}
|
||||
|
||||
export default function ChatInput(props: ChatInputProps) {
|
||||
|
||||
@@ -19,8 +19,12 @@ export default function ChatMessages({
|
||||
isLoading?: boolean;
|
||||
stop?: () => void;
|
||||
reload?: () => void;
|
||||
append?: (
|
||||
message: Message | Omit<Message, "id">,
|
||||
) => Promise<string | null | undefined>;
|
||||
}) {
|
||||
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (scrollableChatContainerRef.current) {
|
||||
@@ -31,14 +35,14 @@ export default function ChatMessages({
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages.length]);
|
||||
}, [messages.length, lastMessage]);
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl">
|
||||
<div
|
||||
className="flex flex-col gap-5 divide-y h-[50vh] overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
<div
|
||||
className="flex-1 w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
<div className="flex flex-col gap-5 divide-y">
|
||||
{messages.map((m: Message) => (
|
||||
<ChatItem key={m.id} {...m} />
|
||||
))}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
export interface ChatConfig {
|
||||
chatAPI?: string;
|
||||
starterQuestions?: string[];
|
||||
}
|
||||
|
||||
export function useClientConfig() {
|
||||
const API_ROUTE = "/api/chat/config";
|
||||
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
|
||||
const [config, setConfig] = useState<ChatConfig>({
|
||||
chatAPI,
|
||||
});
|
||||
|
||||
const configAPI = useMemo(() => {
|
||||
const backendOrigin = chatAPI ? new URL(chatAPI).origin : "";
|
||||
return `${backendOrigin}${API_ROUTE}`;
|
||||
}, [chatAPI]);
|
||||
|
||||
useEffect(() => {
|
||||
fetch(configAPI)
|
||||
.then((response) => response.json())
|
||||
.then((data) => setConfig({ ...data, chatAPI }))
|
||||
.catch((error) => console.error("Error fetching config", error));
|
||||
}, [chatAPI, configAPI]);
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -3,10 +3,18 @@ from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
endpoint = os.getenv("ASTRA_DB_ENDPOINT")
|
||||
token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
|
||||
collection = os.getenv("ASTRA_DB_COLLECTION")
|
||||
if not endpoint or not token or not collection:
|
||||
raise ValueError(
|
||||
"Please config ASTRA_DB_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_COLLECTION"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = AstraDBVectorStore(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
token=token,
|
||||
api_endpoint=endpoint,
|
||||
collection_name=collection,
|
||||
embedding_dimension=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("CHROMA_COLLECTION", "default")
|
||||
chroma_path = os.getenv("CHROMA_PATH")
|
||||
# if CHROMA_PATH is set, use a local ChromaVectorStore from the path
|
||||
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
|
||||
if chroma_path:
|
||||
store = ChromaVectorStore.from_params(
|
||||
persist_dir=chroma_path, collection_name=collection_name
|
||||
)
|
||||
else:
|
||||
if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
|
||||
raise ValueError(
|
||||
"Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
|
||||
)
|
||||
store = ChromaVectorStore.from_params(
|
||||
host=os.getenv("CHROMA_HOST"),
|
||||
port=int(os.getenv("CHROMA_PORT")),
|
||||
collection_name=collection_name,
|
||||
)
|
||||
return store
|
||||
@@ -0,0 +1,45 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
|
||||
if name is None or project_name is None or api_key is None:
|
||||
raise ValueError(
|
||||
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
|
||||
documents = get_documents()
|
||||
|
||||
LlamaCloudIndex.from_documents(
|
||||
documents=documents,
|
||||
name=name,
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -0,0 +1,28 @@
|
||||
import logging
|
||||
import os
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
|
||||
if name is None or project_name is None or api_key is None:
|
||||
raise ValueError(
|
||||
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
|
||||
index = LlamaCloudIndex(
|
||||
name=name,
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
return index
|
||||
@@ -3,11 +3,18 @@ from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
address = os.getenv("MILVUS_ADDRESS")
|
||||
collection = os.getenv("MILVUS_COLLECTION")
|
||||
if not address or not collection:
|
||||
raise ValueError(
|
||||
"Please set MILVUS_ADDRESS and MILVUS_COLLECTION to your environment variables"
|
||||
" or config them in the .env file"
|
||||
)
|
||||
store = MilvusVectorStore(
|
||||
uri=os.environ["MILVUS_ADDRESS"],
|
||||
uri=address,
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
collection_name=collection,
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -3,9 +3,18 @@ from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
db_uri = os.getenv("MONGODB_URI")
|
||||
db_name = os.getenv("MONGODB_DATABASE")
|
||||
collection_name = os.getenv("MONGODB_VECTORS")
|
||||
index_name = os.getenv("MONGODB_VECTOR_INDEX")
|
||||
if not db_uri or not db_name or not collection_name or not index_name:
|
||||
raise ValueError(
|
||||
"Please set MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, and MONGODB_VECTOR_INDEX"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
db_name=db_name,
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.indices import (
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
storage_dir = os.environ.get("STORAGE_DIR", "storage")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
)
|
||||
# store it for later
|
||||
index.storage_context.persist(storage_dir)
|
||||
logger.info(f"Finished creating new index. Stored in {storage_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from cachetools import cached, TTLCache
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import load_index_from_storage
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
@cached(
|
||||
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
|
||||
key=lambda *args, **kwargs: "global_storage_context",
|
||||
)
|
||||
def get_storage_context(persist_dir: str) -> StorageContext:
|
||||
return StorageContext.from_defaults(persist_dir=persist_dir)
|
||||
|
||||
|
||||
def get_index():
|
||||
storage_dir = os.getenv("STORAGE_DIR", "storage")
|
||||
# check if storage already exists
|
||||
if not os.path.exists(storage_dir):
|
||||
return None
|
||||
# load the existing index
|
||||
logger.info(f"Loading index from {storage_dir}...")
|
||||
storage_context = get_storage_context(storage_dir)
|
||||
index = load_index_from_storage(storage_context)
|
||||
logger.info(f"Finished loading index from {storage_dir}")
|
||||
return index
|
||||
@@ -1,13 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core.vector_stores import SimpleVectorStore
|
||||
from app.constants import STORAGE_DIR
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
vector_store = SimpleVectorStore()
|
||||
else:
|
||||
vector_store = SimpleVectorStore.from_persist_dir(STORAGE_DIR)
|
||||
vector_store.stores_text = True
|
||||
return vector_store
|
||||
@@ -2,30 +2,36 @@ import os
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
from urllib.parse import urlparse
|
||||
|
||||
STORAGE_DIR = "storage"
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
|
||||
vector_store: PGVectorStore = None
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
global vector_store
|
||||
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
if vector_store is None:
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
|
||||
return PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
|
||||
)
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
vector_store = PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
|
||||
)
|
||||
|
||||
return vector_store
|
||||
|
||||
@@ -3,9 +3,17 @@ from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
api_key = os.getenv("PINECONE_API_KEY")
|
||||
index_name = os.getenv("PINECONE_INDEX_NAME")
|
||||
environment = os.getenv("PINECONE_ENVIRONMENT")
|
||||
if not api_key or not index_name or not environment:
|
||||
raise ValueError(
|
||||
"Please set PINECONE_API_KEY, PINECONE_INDEX_NAME, and PINECONE_ENVIRONMENT"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
api_key=api_key,
|
||||
index_name=index_name,
|
||||
environment=environment,
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -3,9 +3,17 @@ from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("QDRANT_COLLECTION")
|
||||
url = os.getenv("QDRANT_URL")
|
||||
api_key = os.getenv("QDRANT_API_KEY")
|
||||
if not collection_name or not url:
|
||||
raise ValueError(
|
||||
"Please set QDRANT_COLLECTION, QDRANT_URL"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
AstraDBVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { AstraDBVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
async function loadAndIndex() {
|
||||
// load objects from storage and convert them into LlamaIndex Document objects
|
||||
const documents = await getDocuments();
|
||||
|
||||
// create vector store
|
||||
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
|
||||
|
||||
const vectorStore = new ChromaVectorStore({
|
||||
collectionName: process.env.CHROMA_COLLECTION,
|
||||
chromaClientParams: { path: chromaUri },
|
||||
});
|
||||
|
||||
// create index from all the Documentss and store them in Pinecone
|
||||
console.log("Start creating embeddings...");
|
||||
const storageContext = await storageContextFromDefaults({ vectorStore });
|
||||
await VectorStoreIndex.fromDocuments(documents, { storageContext });
|
||||
console.log(
|
||||
"Successfully created embeddings and save to your ChromaDB index.",
|
||||
);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
checkRequiredEnvVars();
|
||||
initSettings();
|
||||
await loadAndIndex();
|
||||
console.log("Finished generating storage.");
|
||||
})();
|
||||
@@ -0,0 +1,16 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
|
||||
|
||||
const store = new ChromaVectorStore({
|
||||
collectionName: process.env.CHROMA_COLLECTION,
|
||||
chromaClientParams: { path: chromaUri },
|
||||
});
|
||||
|
||||
return await VectorStoreIndex.fromVectorStore(store);
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
const REQUIRED_ENV_VARS = ["CHROMA_COLLECTION", "CHROMA_HOST", "CHROMA_PORT"];
|
||||
|
||||
export function checkRequiredEnvVars() {
|
||||
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
|
||||
return !process.env[envVar];
|
||||
});
|
||||
|
||||
if (missingEnvVars.length > 0) {
|
||||
console.log(
|
||||
`The following environment variables are required but missing: ${missingEnvVars.join(
|
||||
", ",
|
||||
)}`,
|
||||
);
|
||||
throw new Error(
|
||||
`Missing environment variables: ${missingEnvVars.join(", ")}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
import * as dotenv from "dotenv";
|
||||
import { LlamaCloudIndex } from "llamaindex";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
async function loadAndIndex() {
|
||||
const documents = await getDocuments();
|
||||
await LlamaCloudIndex.fromDocuments({
|
||||
documents,
|
||||
name: process.env.LLAMA_CLOUD_INDEX_NAME!,
|
||||
projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
|
||||
apiKey: process.env.LLAMA_CLOUD_API_KEY,
|
||||
baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
|
||||
});
|
||||
console.log(`Successfully created embeddings!`);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
checkRequiredEnvVars();
|
||||
initSettings();
|
||||
await loadAndIndex();
|
||||
console.log("Finished generating storage.");
|
||||
})();
|
||||
@@ -0,0 +1,13 @@
|
||||
import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const index = new LlamaCloudIndex({
|
||||
name: process.env.LLAMA_CLOUD_INDEX_NAME!,
|
||||
projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
|
||||
apiKey: process.env.LLAMA_CLOUD_API_KEY,
|
||||
baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
|
||||
});
|
||||
return index;
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
const REQUIRED_ENV_VARS = [
|
||||
"LLAMA_CLOUD_INDEX_NAME",
|
||||
"LLAMA_CLOUD_PROJECT_NAME",
|
||||
"LLAMA_CLOUD_API_KEY",
|
||||
];
|
||||
|
||||
export function checkRequiredEnvVars() {
|
||||
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
|
||||
return !process.env[envVar];
|
||||
});
|
||||
|
||||
if (missingEnvVars.length > 0) {
|
||||
console.log(
|
||||
`The following environment variables are required but missing: ${missingEnvVars.join(
|
||||
", ",
|
||||
)}`,
|
||||
);
|
||||
throw new Error(
|
||||
`Missing environment variables: ${missingEnvVars.join(", ")}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
MilvusVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { MilvusVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
|
||||
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
MongoDBAtlasVectorSearch,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
|
||||
import { MongoClient } from "mongodb";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
@@ -12,7 +9,7 @@ import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
const mongoUri = process.env.MONGO_URI!;
|
||||
const mongoUri = process.env.MONGODB_URI!;
|
||||
const databaseName = process.env.MONGODB_DATABASE!;
|
||||
const vectorCollectionName = process.env.MONGODB_VECTORS!;
|
||||
const indexName = process.env.MONGODB_VECTOR_INDEX;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { MongoDBAtlasVectorSearch, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
|
||||
import { MongoClient } from "mongodb";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const REQUIRED_ENV_VARS = [
|
||||
"MONGO_URI",
|
||||
"MONGODB_URI",
|
||||
"MONGODB_DATABASE",
|
||||
"MONGODB_VECTORS",
|
||||
"MONGODB_VECTOR_INDEX",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
|
||||
|
||||
import * as dotenv from "dotenv";
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import {
|
||||
SimpleDocumentStore,
|
||||
storageContextFromDefaults,
|
||||
VectorStoreIndex,
|
||||
} from "llamaindex";
|
||||
import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
|
||||
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
|
||||
import { STORAGE_CACHE_DIR } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
PGVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { PGVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
|
||||
import {
|
||||
PGVECTOR_SCHEMA,
|
||||
PGVECTOR_TABLE,
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
PineconeVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { PineconeVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
QdrantVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
|
||||
@@ -18,7 +15,10 @@ async function loadAndIndex() {
|
||||
const documents = await getDocuments();
|
||||
|
||||
// Connect to Qdrant
|
||||
const vectorStore = new QdrantVectorStore(collectionName, getQdrantClient());
|
||||
const vectorStore = new QdrantVectorStore({
|
||||
collectionName,
|
||||
client: getQdrantClient(),
|
||||
});
|
||||
|
||||
const storageContext = await storageContextFromDefaults({ vectorStore });
|
||||
await VectorStoreIndex.fromDocuments(documents, {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import * as dotenv from "dotenv";
|
||||
import { QdrantVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
|
||||
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
@@ -7,7 +8,10 @@ dotenv.config();
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const collectionName = process.env.QDRANT_COLLECTION;
|
||||
const store = new QdrantVectorStore(collectionName, getQdrantClient());
|
||||
const store = new QdrantVectorStore({
|
||||
collectionName,
|
||||
client: getQdrantClient(),
|
||||
});
|
||||
|
||||
return await VectorStoreIndex.fromVectorStore(store);
|
||||
}
|
||||
|
||||
@@ -3,5 +3,8 @@
|
||||
"rules": {
|
||||
"max-params": ["error", 4],
|
||||
"prefer-const": "error"
|
||||
},
|
||||
"parserOptions": {
|
||||
"sourceType": "module"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# local env files
|
||||
.env
|
||||
node_modules/
|
||||
node_modules/
|
||||
|
||||
tool-output/
|
||||
@@ -31,6 +31,8 @@ if (isDevelopment) {
|
||||
console.warn("Production CORS origin not set, defaulting to no CORS.");
|
||||
}
|
||||
|
||||
app.use("/api/files/data", express.static("data"));
|
||||
app.use("/api/files/tool-output", express.static("tool-output"));
|
||||
app.use(express.text());
|
||||
|
||||
app.get("/", (req: Request, res: Response) => {
|
||||
|
||||
@@ -1,20 +1,32 @@
|
||||
{
|
||||
"name": "llama-index-express-streaming",
|
||||
"version": "1.0.0",
|
||||
"main": "dist/index.mjs",
|
||||
"exports": "./index.js",
|
||||
"types": "./index.d.ts",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"scripts": {
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write .",
|
||||
"build": "tsup index.ts --format esm --dts",
|
||||
"start": "node dist/index.mjs",
|
||||
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
|
||||
"start": "node dist/index.js",
|
||||
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon --watch dist/index.js\""
|
||||
},
|
||||
"dependencies": {
|
||||
"ai": "^3.0.21",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.3.1",
|
||||
"duck-duck-scrape": "^2.2.5",
|
||||
"express": "^4.18.2",
|
||||
"llamaindex": "0.2.10"
|
||||
"llamaindex": "0.4.6",
|
||||
"pdf2json": "3.0.5",
|
||||
"ajv": "^8.12.0",
|
||||
"@e2b/code-interpreter": "^0.0.5",
|
||||
"got": "^14.4.1",
|
||||
"@apidevtools/swagger-parser": "^10.1.0",
|
||||
"formdata-node": "^6.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/cors": "^2.8.16",
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { Request, Response } from "express";
|
||||
|
||||
export const chatConfig = async (_req: Request, res: Response) => {
|
||||
let starterQuestions = undefined;
|
||||
if (
|
||||
process.env.CONVERSATION_STARTERS &&
|
||||
process.env.CONVERSATION_STARTERS.trim()
|
||||
) {
|
||||
starterQuestions = process.env.CONVERSATION_STARTERS.trim().split("\n");
|
||||
}
|
||||
return res.status(200).json({
|
||||
starterQuestions,
|
||||
});
|
||||
};
|
||||
@@ -1,32 +1,16 @@
|
||||
import { Message, StreamData, streamToResponse } from "ai";
|
||||
import { Request, Response } from "express";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { ChatMessage, Settings } from "llamaindex";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { appendEventData } from "./stream-helper";
|
||||
|
||||
const convertMessageContent = (
|
||||
textMessage: string,
|
||||
imageUrl: string | undefined,
|
||||
): MessageContent => {
|
||||
if (!imageUrl) return textMessage;
|
||||
return [
|
||||
{
|
||||
type: "text",
|
||||
text: textMessage,
|
||||
},
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: imageUrl,
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
import { LlamaIndexStream, convertMessageContent } from "./llamaindex-stream";
|
||||
import { createCallbackManager, createStreamTimeout } from "./stream-helper";
|
||||
|
||||
export const chat = async (req: Request, res: Response) => {
|
||||
// Init Vercel AI StreamData and timeout
|
||||
const vercelStreamData = new StreamData();
|
||||
const streamTimeout = createStreamTimeout(vercelStreamData);
|
||||
try {
|
||||
const { messages, data }: { messages: Message[]; data: any } = req.body;
|
||||
const { messages }: { messages: Message[] } = req.body;
|
||||
const userMessage = messages.pop();
|
||||
if (!messages || !userMessage || userMessage.role !== "user") {
|
||||
return res.status(400).json({
|
||||
@@ -37,58 +21,47 @@ export const chat = async (req: Request, res: Response) => {
|
||||
|
||||
const chatEngine = await createChatEngine();
|
||||
|
||||
let annotations = userMessage.annotations;
|
||||
if (!annotations) {
|
||||
// the user didn't send any new annotations with the last message
|
||||
// so use the annotations from the last user message that has annotations
|
||||
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
|
||||
annotations = messages
|
||||
.slice()
|
||||
.reverse()
|
||||
.find(
|
||||
(message) => message.role === "user" && message.annotations,
|
||||
)?.annotations;
|
||||
}
|
||||
|
||||
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
|
||||
const userMessageContent = convertMessageContent(
|
||||
userMessage.content,
|
||||
data?.imageUrl,
|
||||
annotations,
|
||||
);
|
||||
|
||||
// Init Vercel AI StreamData
|
||||
const vercelStreamData = new StreamData();
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieving context for query: '${userMessage.content}'`,
|
||||
);
|
||||
|
||||
// Setup callback for streaming data before chatting
|
||||
Settings.callbackManager.on("retrieve", (data) => {
|
||||
const { nodes } = data.detail;
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
// Setup callbacks
|
||||
const callbackManager = createCallbackManager(vercelStreamData);
|
||||
|
||||
// Calling LlamaIndex's ChatEngine to get a streamed response
|
||||
const response = await chatEngine.chat({
|
||||
message: userMessageContent,
|
||||
chatHistory: messages as ChatMessage[],
|
||||
stream: true,
|
||||
const response = await Settings.withCallbackManager(callbackManager, () => {
|
||||
return chatEngine.chat({
|
||||
message: userMessageContent,
|
||||
chatHistory: messages as ChatMessage[],
|
||||
stream: true,
|
||||
});
|
||||
});
|
||||
|
||||
// Return a stream, which can be consumed by the Vercel/AI client
|
||||
const { stream } = LlamaIndexStream(response, vercelStreamData, {
|
||||
parserOptions: {
|
||||
image_url: data?.imageUrl,
|
||||
},
|
||||
});
|
||||
const stream = LlamaIndexStream(response, vercelStreamData);
|
||||
|
||||
// Pipe LlamaIndexStream to response
|
||||
const processedStream = stream.pipeThrough(vercelStreamData.stream);
|
||||
return streamToResponse(processedStream, res, {
|
||||
headers: {
|
||||
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
|
||||
// so that the client uses the correct parsing logic, see
|
||||
// https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server
|
||||
"X-Experimental-Stream-Data": "true",
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
|
||||
},
|
||||
});
|
||||
return streamToResponse(stream, res, {}, vercelStreamData);
|
||||
} catch (error) {
|
||||
console.error("[LlamaIndex]", error);
|
||||
return res.status(500).json({
|
||||
detail: (error as Error).message,
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(streamTimeout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
import {
|
||||
Ollama,
|
||||
OllamaEmbedding,
|
||||
Anthropic,
|
||||
GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_MODEL,
|
||||
Gemini,
|
||||
GeminiEmbedding,
|
||||
Groq,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
|
||||
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
|
||||
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
|
||||
import { Ollama } from "llamaindex/llm/ollama";
|
||||
|
||||
const CHUNK_SIZE = 512;
|
||||
const CHUNK_OVERLAP = 20;
|
||||
@@ -12,10 +20,24 @@ const CHUNK_OVERLAP = 20;
|
||||
export const initSettings = async () => {
|
||||
// HINT: you can delete the initialization code for unused model providers
|
||||
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
|
||||
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
|
||||
}
|
||||
|
||||
switch (process.env.MODEL_PROVIDER) {
|
||||
case "ollama":
|
||||
initOllama();
|
||||
break;
|
||||
case "groq":
|
||||
initGroq();
|
||||
break;
|
||||
case "anthropic":
|
||||
initAnthropic();
|
||||
break;
|
||||
case "gemini":
|
||||
initGemini();
|
||||
break;
|
||||
default:
|
||||
initOpenAI();
|
||||
break;
|
||||
@@ -27,7 +49,9 @@ export const initSettings = async () => {
|
||||
function initOpenAI() {
|
||||
Settings.llm = new OpenAI({
|
||||
model: process.env.MODEL ?? "gpt-3.5-turbo",
|
||||
maxTokens: 512,
|
||||
maxTokens: process.env.LLM_MAX_TOKENS
|
||||
? Number(process.env.LLM_MAX_TOKENS)
|
||||
: undefined,
|
||||
});
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
@@ -38,15 +62,59 @@ function initOpenAI() {
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error(
|
||||
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
|
||||
);
|
||||
}
|
||||
const config = {
|
||||
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
|
||||
};
|
||||
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
}
|
||||
|
||||
function initAnthropic() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
|
||||
};
|
||||
Settings.llm = new Anthropic({
|
||||
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
|
||||
});
|
||||
Settings.embedModel = new HuggingFaceEmbedding({
|
||||
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
|
||||
});
|
||||
}
|
||||
|
||||
function initGroq() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
|
||||
};
|
||||
|
||||
const modelMap: Record<string, string> = {
|
||||
"llama3-8b": "llama3-8b-8192",
|
||||
"llama3-70b": "llama3-70b-8192",
|
||||
"mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
};
|
||||
|
||||
Settings.llm = new Groq({
|
||||
model: modelMap[process.env.MODEL!],
|
||||
});
|
||||
|
||||
Settings.embedModel = new HuggingFaceEmbedding({
|
||||
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
|
||||
});
|
||||
}
|
||||
|
||||
function initGemini() {
|
||||
Settings.llm = new Gemini({
|
||||
model: process.env.MODEL as GEMINI_MODEL,
|
||||
});
|
||||
Settings.embedModel = new GeminiEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
JSONValue,
|
||||
StreamData,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer,
|
||||
@@ -6,44 +7,90 @@ import {
|
||||
type AIStreamCallbacksAndOptions,
|
||||
} from "ai";
|
||||
import {
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
Response,
|
||||
StreamingAgentChatResponse,
|
||||
EngineResponse,
|
||||
MessageContent,
|
||||
MessageContentDetail,
|
||||
} from "llamaindex";
|
||||
import { appendImageData, appendSourceData } from "./stream-helper";
|
||||
|
||||
type ParserOptions = {
|
||||
image_url?: string;
|
||||
import { CsvFile } from "./stream-helper";
|
||||
|
||||
export const convertMessageContent = (
|
||||
content: string,
|
||||
annotations?: JSONValue[],
|
||||
): MessageContent => {
|
||||
if (!annotations) return content;
|
||||
return [
|
||||
{
|
||||
type: "text",
|
||||
text: content,
|
||||
},
|
||||
...convertAnnotations(annotations),
|
||||
];
|
||||
};
|
||||
|
||||
function createParser(
|
||||
res: AsyncIterable<Response>,
|
||||
data: StreamData,
|
||||
opts?: ParserOptions,
|
||||
) {
|
||||
const convertAnnotations = (
|
||||
annotations: JSONValue[],
|
||||
): MessageContentDetail[] => {
|
||||
const content: MessageContentDetail[] = [];
|
||||
annotations.forEach((annotation: JSONValue) => {
|
||||
// first skip invalid annotation
|
||||
if (
|
||||
!(
|
||||
annotation &&
|
||||
typeof annotation === "object" &&
|
||||
"type" in annotation &&
|
||||
"data" in annotation &&
|
||||
annotation.data &&
|
||||
typeof annotation.data === "object"
|
||||
)
|
||||
) {
|
||||
console.log(
|
||||
"Client sent invalid annotation. Missing data and type",
|
||||
annotation,
|
||||
);
|
||||
return;
|
||||
}
|
||||
const { type, data } = annotation;
|
||||
// convert image
|
||||
if (type === "image" && "url" in data && typeof data.url === "string") {
|
||||
content.push({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: data.url,
|
||||
},
|
||||
});
|
||||
}
|
||||
// convert CSV files to text
|
||||
if (type === "csv" && "csvFiles" in data && Array.isArray(data.csvFiles)) {
|
||||
const rawContents = data.csvFiles.map((csv) => {
|
||||
return "```csv\n" + (csv as CsvFile).content + "\n```";
|
||||
});
|
||||
const csvContent =
|
||||
"Use data from following CSV raw contents:\n" +
|
||||
rawContents.join("\n\n");
|
||||
content.push({
|
||||
type: "text",
|
||||
text: csvContent,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return content;
|
||||
};
|
||||
|
||||
function createParser(res: AsyncIterable<EngineResponse>, data: StreamData) {
|
||||
const it = res[Symbol.asyncIterator]();
|
||||
const trimStartOfStream = trimStartOfStreamHelper();
|
||||
|
||||
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
|
||||
return new ReadableStream<string>({
|
||||
start() {
|
||||
appendImageData(data, opts?.image_url);
|
||||
},
|
||||
async pull(controller): Promise<void> {
|
||||
const { value, done } = await it.next();
|
||||
if (done) {
|
||||
appendSourceData(data, sourceNodes);
|
||||
controller.close();
|
||||
data.close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sourceNodes) {
|
||||
// get source nodes from the first response
|
||||
sourceNodes = value.sourceNodes;
|
||||
}
|
||||
const text = trimStartOfStream(value.response ?? "");
|
||||
const text = trimStartOfStream(value.delta ?? "");
|
||||
if (text) {
|
||||
controller.enqueue(text);
|
||||
}
|
||||
@@ -52,21 +99,13 @@ function createParser(
|
||||
}
|
||||
|
||||
export function LlamaIndexStream(
|
||||
response: StreamingAgentChatResponse | AsyncIterable<Response>,
|
||||
response: AsyncIterable<EngineResponse>,
|
||||
data: StreamData,
|
||||
opts?: {
|
||||
callbacks?: AIStreamCallbacksAndOptions;
|
||||
parserOptions?: ParserOptions;
|
||||
},
|
||||
): { stream: ReadableStream; data: StreamData } {
|
||||
const res =
|
||||
response instanceof StreamingAgentChatResponse
|
||||
? response.response
|
||||
: response;
|
||||
return {
|
||||
stream: createParser(res, data, opts?.parserOptions)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer()),
|
||||
data,
|
||||
};
|
||||
): ReadableStream<Uint8Array> {
|
||||
return createParser(response, data)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer());
|
||||
}
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
import { StreamData } from "ai";
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
import {
|
||||
CallbackManager,
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "image",
|
||||
data: {
|
||||
url: imageUrl,
|
||||
},
|
||||
});
|
||||
function getNodeUrl(metadata: Metadata) {
|
||||
const url = metadata["URL"];
|
||||
if (url) return url;
|
||||
const fileName = metadata["file_name"];
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
console.warn(
|
||||
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
|
||||
);
|
||||
return undefined;
|
||||
}
|
||||
if (fileName) {
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/data/${fileName}`;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function appendSourceData(
|
||||
@@ -23,6 +35,7 @@ export function appendSourceData(
|
||||
...node.node.toMutableJSON(),
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
url: getNodeUrl(node.node.metadata),
|
||||
})),
|
||||
},
|
||||
});
|
||||
@@ -37,3 +50,72 @@ export function appendEventData(data: StreamData, title?: string) {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createStreamTimeout(stream: StreamData) {
|
||||
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
|
||||
const t = setTimeout(() => {
|
||||
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
|
||||
stream.close();
|
||||
}, timeout);
|
||||
return t;
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve-end", (data) => {
|
||||
const { nodes, query } = data.detail.payload;
|
||||
appendSourceData(stream, nodes);
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.payload.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail.payload;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
|
||||
export type CsvFile = {
|
||||
content: string;
|
||||
filename: string;
|
||||
filesize: number;
|
||||
id: string;
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import express, { Router } from "express";
|
||||
import { chatConfig } from "../controllers/chat-config.controller";
|
||||
import { chatRequest } from "../controllers/chat-request.controller";
|
||||
import { chat } from "../controllers/chat.controller";
|
||||
import { initSettings } from "../controllers/engine/settings";
|
||||
@@ -8,5 +9,6 @@ const llmRouter: Router = express.Router();
|
||||
initSettings();
|
||||
llmRouter.route("/").post(chat);
|
||||
llmRouter.route("/request").post(chatRequest);
|
||||
llmRouter.route("/config").get(chatConfig);
|
||||
|
||||
export default llmRouter;
|
||||
|
||||
@@ -1,154 +1,114 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
import os
|
||||
import logging
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine
|
||||
from llama_index.core.llms import MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.routers.messaging import EventCallbackHandler
|
||||
from aiostream import stream
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import (
|
||||
ChatData,
|
||||
ChatConfig,
|
||||
SourceNodes,
|
||||
Result,
|
||||
Message,
|
||||
)
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
|
||||
class _Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
class _ChatData(BaseModel):
|
||||
messages: List[_Message]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class _SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=source_node.node.metadata,
|
||||
score=source_node.score,
|
||||
text=source_node.node.text, # type: ignore
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
|
||||
class _Result(BaseModel):
|
||||
result: _Message
|
||||
nodes: List[_SourceNodes]
|
||||
|
||||
|
||||
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
|
||||
# check preconditions and get last message
|
||||
if len(data.messages) == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No messages provided",
|
||||
)
|
||||
last_message = data.messages.pop()
|
||||
if last_message.role != MessageRole.USER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Last message must be from user",
|
||||
)
|
||||
# convert messages coming from the request to type ChatMessage
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=m.role,
|
||||
content=m.content,
|
||||
)
|
||||
for m in data.messages
|
||||
]
|
||||
return last_message.content, messages
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
# streaming endpoint - delete if not needed
|
||||
@r.post("")
|
||||
async def chat(
|
||||
request: Request,
|
||||
data: _ChatData,
|
||||
data: ChatData,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
):
|
||||
last_message_content, messages = await parse_chat_data(data)
|
||||
try:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
async def content_generator():
|
||||
# Yield the text response
|
||||
async def _text_generator():
|
||||
async for token in response.async_response_gen():
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
async def content_generator():
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
response = await chat_engine.astream_chat(
|
||||
last_message_content, messages
|
||||
)
|
||||
async for token in response.async_response_gen():
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "events",
|
||||
"data": {"title": event.get_title()},
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
combine = stream.merge(_text_generator(), _event_generator())
|
||||
async with combine.stream() as streamer:
|
||||
async for item in streamer:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield item
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
_SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
# Stream a blank message to start the stream
|
||||
yield VercelStreamResponse.convert_text("")
|
||||
|
||||
return VercelStreamResponse(content=content_generator())
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
return VercelStreamResponse(content=content_generator())
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat engine", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error in chat engine: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@r.post("/request")
|
||||
async def chat_request(
|
||||
data: _ChatData,
|
||||
data: ChatData,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
) -> _Result:
|
||||
last_message_content, messages = await parse_chat_data(data)
|
||||
) -> Result:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
response = await chat_engine.achat(last_message_content, messages)
|
||||
return _Result(
|
||||
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
|
||||
return Result(
|
||||
result=Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=SourceNodes.from_source_nodes(response.source_nodes),
|
||||
)
|
||||
|
||||
|
||||
@r.get("/config")
|
||||
async def chat_config() -> ChatConfig:
|
||||
starter_questions = None
|
||||
conversation_starters = os.getenv("CONVERSATION_STARTERS")
|
||||
if conversation_starters and conversation_starters.strip():
|
||||
starter_questions = conversation_starters.strip().split("\n")
|
||||
return ChatConfig(starterQuestions=starter_questions)
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from llama_index.core.tools.types import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
else:
|
||||
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in converting event to response: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -1,86 +0,0 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_title(self) -> str | None:
|
||||
# Return as None for the unhandled event types
|
||||
# to avoid showing them in the UI
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
return f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
else:
|
||||
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
else:
|
||||
return None
|
||||
case _:
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -0,0 +1,170 @@
|
||||
import os
|
||||
import logging
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
from typing import List, Any, Optional, Dict
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class CsvFile(BaseModel):
|
||||
content: str
|
||||
filename: str
|
||||
filesize: int
|
||||
id: str
|
||||
|
||||
|
||||
class AnnotationData(BaseModel):
|
||||
csv_files: List[CsvFile] | None = Field(
|
||||
default=None,
|
||||
description="List of CSV files",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"csvFiles": [
|
||||
{
|
||||
"content": "Name, Age\nAlice, 25\nBob, 30",
|
||||
"filename": "example.csv",
|
||||
"filesize": 123,
|
||||
"id": "123",
|
||||
"type": "text/csv",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
alias_generator = to_camel
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
type: str
|
||||
data: AnnotationData
|
||||
|
||||
def to_content(self) -> str:
|
||||
if self.type == "csv":
|
||||
csv_files = self.data.csv_files
|
||||
if csv_files is not None and len(csv_files) > 0:
|
||||
return "Use data from following CSV raw contents\n" + "\n".join(
|
||||
[f"```csv\n{csv_file.content}\n```" for csv_file in csv_files]
|
||||
)
|
||||
raise ValueError(f"Unsupported annotation type: {self.type}")
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
annotations: List[Annotation] | None = None
|
||||
|
||||
|
||||
class ChatData(BaseModel):
|
||||
messages: List[Message]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@validator("messages")
|
||||
def messages_must_not_be_empty(cls, v):
|
||||
if len(v) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
return v
|
||||
|
||||
def get_last_message_content(self) -> str:
|
||||
"""
|
||||
Get the content of the last message along with the data content if available. Fallback to use data content from previous messages
|
||||
"""
|
||||
if len(self.messages) == 0:
|
||||
raise ValueError("There is not any message in the chat")
|
||||
last_message = self.messages[-1]
|
||||
message_content = last_message.content
|
||||
for message in reversed(self.messages):
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
annotation_contents = (
|
||||
annotation.to_content() for annotation in message.annotations
|
||||
)
|
||||
annotation_text = "\n".join(annotation_contents)
|
||||
message_content = f"{message_content}\n{annotation_text}"
|
||||
break
|
||||
return message_content
|
||||
|
||||
def get_history_messages(self) -> List[Message]:
|
||||
"""
|
||||
Get the history messages
|
||||
"""
|
||||
return [
|
||||
ChatMessage(role=message.role, content=message.content)
|
||||
for message in self.messages[:-1]
|
||||
]
|
||||
|
||||
def is_last_message_from_user(self) -> bool:
|
||||
return self.messages[-1].role == MessageRole.USER
|
||||
|
||||
|
||||
class SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
url: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
metadata = source_node.node.metadata
|
||||
url = metadata.get("URL")
|
||||
|
||||
if not url:
|
||||
file_name = metadata.get("file_name")
|
||||
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not url_prefix:
|
||||
logger.warning(
|
||||
"Warning: FILESERVER_URL_PREFIX not set in environment variables"
|
||||
)
|
||||
if file_name and url_prefix:
|
||||
url = f"{url_prefix}/data/{file_name}"
|
||||
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=metadata,
|
||||
score=source_node.score,
|
||||
text=source_node.node.text, # type: ignore
|
||||
url=url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Message
|
||||
nodes: List[SourceNodes]
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
starter_questions: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of starter questions",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"starterQuestions": [
|
||||
"What standards for letters exist?",
|
||||
"What are the requirements for a letter to be considered a letter?",
|
||||
]
|
||||
}
|
||||
}
|
||||
alias_generator = to_camel
|
||||
@@ -1 +0,0 @@
|
||||
STORAGE_DIR = "storage" # directory to save the stores to (document store and if used, the `SimpleVectorStore`)
|
||||
@@ -7,9 +7,8 @@ import logging
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.vector_stores import SimpleVectorStore
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from app.constants import STORAGE_DIR
|
||||
from llama_index.core.storage import StorageContext
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
@@ -18,26 +17,21 @@ from app.engine.vectordb import get_vector_store
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store():
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
docstore = SimpleDocumentStore()
|
||||
return docstore
|
||||
else:
|
||||
|
||||
# If the storage directory is there, load the document store from it.
|
||||
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
|
||||
if os.path.exists(STORAGE_DIR):
|
||||
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
return SimpleDocumentStore()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
docstore = get_doc_store()
|
||||
vector_store = get_vector_store()
|
||||
|
||||
# Create ingestion pipeline
|
||||
ingestion_pipeline = IngestionPipeline(
|
||||
def run_pipeline(docstore, vector_store, documents):
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
SentenceSplitter(
|
||||
chunk_size=Settings.chunk_size,
|
||||
@@ -47,23 +41,39 @@ def generate_datasource():
|
||||
],
|
||||
docstore=docstore,
|
||||
docstore_strategy="upserts_and_delete",
|
||||
vector_store=vector_store,
|
||||
)
|
||||
|
||||
# llama_index having an typing issue when passing vector_store to IngestionPipeline
|
||||
# so we need to set it manually after initialization
|
||||
ingestion_pipeline.vector_store = vector_store
|
||||
|
||||
# Run the ingestion pipeline and store the results
|
||||
ingestion_pipeline.run(show_progress=True, documents=documents)
|
||||
nodes = pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
# Default vector store only keeps data in memory, so we need to persist it
|
||||
# Can remove if using a different vector store
|
||||
if isinstance(vector_store, SimpleVectorStore):
|
||||
vector_store.persist(os.path.join(STORAGE_DIR, "vector_store.json"))
|
||||
# Persist the docstore to apply ingestion strategy
|
||||
docstore.persist(os.path.join(STORAGE_DIR, "docstore.json"))
|
||||
return nodes
|
||||
|
||||
logger.info("Finished creating new index.")
|
||||
|
||||
def persist_storage(docstore, vector_store):
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
vector_store=vector_store,
|
||||
)
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
# Get the stores and documents or create new ones
|
||||
documents = get_documents()
|
||||
docstore = get_doc_store()
|
||||
vector_store = get_vector_store()
|
||||
|
||||
# Run the ingestion pipeline
|
||||
_ = run_pipeline(docstore, vector_store, documents)
|
||||
|
||||
# Build the index and persist storage
|
||||
persist_storage(docstore, vector_store)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import logging
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Loading the index...")
|
||||
logger.info("Connecting vector store...")
|
||||
store = get_vector_store()
|
||||
# Load the index from the vector store
|
||||
# If you are using a vector store that doesn't store text,
|
||||
# you must load the index from both the vector store and the document store
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Loaded index successfully.")
|
||||
logger.info("Finished load index from vector store.")
|
||||
return index
|
||||
|
||||
@@ -1,32 +1,51 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
if model_provider == "openai":
|
||||
init_openai()
|
||||
elif model_provider == "ollama":
|
||||
init_ollama()
|
||||
else:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
match model_provider:
|
||||
case "openai":
|
||||
init_openai()
|
||||
case "groq":
|
||||
init_groq()
|
||||
case "ollama":
|
||||
init_ollama()
|
||||
case "anthropic":
|
||||
init_anthropic()
|
||||
case "gemini":
|
||||
init_gemini()
|
||||
case "azure-openai":
|
||||
init_azure_openai()
|
||||
case _:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
|
||||
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
Settings.llm = Ollama(model=os.getenv("MODEL"))
|
||||
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
request_timeout = float(
|
||||
os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
|
||||
)
|
||||
Settings.embed_model = OllamaEmbedding(
|
||||
base_url=base_url,
|
||||
model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
)
|
||||
Settings.llm = Ollama(
|
||||
base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
|
||||
)
|
||||
|
||||
|
||||
def init_openai():
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
@@ -36,9 +55,92 @@ def init_openai():
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
|
||||
dimension = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimension": int(dimension) if dimension is not None else None,
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
|
||||
def init_azure_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.llms.azure_openai import AzureOpenAI
|
||||
|
||||
llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
|
||||
embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
||||
llm_config = {
|
||||
"api_key": api_key,
|
||||
"deployment_name": llm_deployment,
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = AzureOpenAI(**llm_config)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
embedding_config = {
|
||||
"api_key": api_key,
|
||||
"deployment_name": embedding_deployment,
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
|
||||
|
||||
|
||||
def init_groq():
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.llms.groq import Groq
|
||||
|
||||
model_map: Dict[str, str] = {
|
||||
"llama3-8b": "llama3-8b-8192",
|
||||
"llama3-70b": "llama3-70b-8192",
|
||||
"mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
}
|
||||
|
||||
embed_model_map: Dict[str, str] = {
|
||||
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
|
||||
}
|
||||
|
||||
Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
|
||||
Settings.embed_model = HuggingFaceEmbedding(
|
||||
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
)
|
||||
|
||||
|
||||
def init_anthropic():
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.llms.anthropic import Anthropic
|
||||
|
||||
model_map: Dict[str, str] = {
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-2.1": "claude-2.1",
|
||||
"claude-instant-1.2": "claude-instant-1.2",
|
||||
}
|
||||
|
||||
embed_model_map: Dict[str, str] = {
|
||||
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
|
||||
}
|
||||
|
||||
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
|
||||
Settings.embed_model = HuggingFaceEmbedding(
|
||||
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
)
|
||||
|
||||
|
||||
def init_gemini():
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.llms.gemini import Gemini
|
||||
|
||||
model_name = f"models/{os.getenv('MODEL')}"
|
||||
embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
|
||||
|
||||
Settings.llm = Gemini(model=model_name)
|
||||
Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
__pycache__
|
||||
storage
|
||||
.env
|
||||
.env
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import RedirectResponse
|
||||
from app.api.routers.chat import chat_router
|
||||
from app.settings import init_settings
|
||||
from app.observability import init_observability
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@@ -20,7 +21,6 @@ init_observability()
|
||||
|
||||
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
|
||||
|
||||
|
||||
if environment == "dev":
|
||||
logger = logging.getLogger("uvicorn")
|
||||
logger.warning("Running in development mode - allowing CORS for all origins")
|
||||
@@ -38,6 +38,16 @@ if environment == "dev":
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
def mount_static_files(directory, path):
|
||||
if os.path.exists(directory):
|
||||
app.mount(path, StaticFiles(directory=directory), name=f"{directory}-static")
|
||||
|
||||
|
||||
# Mount the data files to serve the file viewer
|
||||
mount_static_files("data", "/api/files/data")
|
||||
# Mount the output files from tools
|
||||
mount_static_files("tool-output", "/api/files/tool-output")
|
||||
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,9 @@ fastapi = "^0.109.1"
|
||||
uvicorn = { extras = ["standard"], version = "^0.23.2" }
|
||||
python-dotenv = "^1.0.0"
|
||||
aiostream = "^0.5.2"
|
||||
llama-index = "0.10.28"
|
||||
llama-index-core = "0.10.28"
|
||||
llama-index = "0.10.50"
|
||||
llama-index-core = "0.10.50"
|
||||
cachetools = "^5.3.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
/**
|
||||
* This API is to get config from the backend envs and expose them to the frontend
|
||||
*/
|
||||
export async function GET() {
|
||||
const config = {
|
||||
starterQuestions: process.env.CONVERSATION_STARTERS?.trim().split("\n"),
|
||||
};
|
||||
return NextResponse.json(config, { status: 200 });
|
||||
}
|
||||
@@ -1,10 +1,18 @@
|
||||
import {
|
||||
Ollama,
|
||||
OllamaEmbedding,
|
||||
Anthropic,
|
||||
GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_MODEL,
|
||||
Gemini,
|
||||
GeminiEmbedding,
|
||||
Groq,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
|
||||
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
|
||||
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
|
||||
import { Ollama } from "llamaindex/llm/ollama";
|
||||
|
||||
const CHUNK_SIZE = 512;
|
||||
const CHUNK_OVERLAP = 20;
|
||||
@@ -12,10 +20,24 @@ const CHUNK_OVERLAP = 20;
|
||||
export const initSettings = async () => {
|
||||
// HINT: you can delete the initialization code for unused model providers
|
||||
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
|
||||
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
|
||||
}
|
||||
|
||||
switch (process.env.MODEL_PROVIDER) {
|
||||
case "ollama":
|
||||
initOllama();
|
||||
break;
|
||||
case "groq":
|
||||
initGroq();
|
||||
break;
|
||||
case "anthropic":
|
||||
initAnthropic();
|
||||
break;
|
||||
case "gemini":
|
||||
initGemini();
|
||||
break;
|
||||
default:
|
||||
initOpenAI();
|
||||
break;
|
||||
@@ -27,7 +49,9 @@ export const initSettings = async () => {
|
||||
function initOpenAI() {
|
||||
Settings.llm = new OpenAI({
|
||||
model: process.env.MODEL ?? "gpt-3.5-turbo",
|
||||
maxTokens: 512,
|
||||
maxTokens: process.env.LLM_MAX_TOKENS
|
||||
? Number(process.env.LLM_MAX_TOKENS)
|
||||
: undefined,
|
||||
});
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
@@ -38,15 +62,58 @@ function initOpenAI() {
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error(
|
||||
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
|
||||
);
|
||||
}
|
||||
const config = {
|
||||
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
|
||||
};
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
}
|
||||
|
||||
function initGroq() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
|
||||
};
|
||||
|
||||
const modelMap: Record<string, string> = {
|
||||
"llama3-8b": "llama3-8b-8192",
|
||||
"llama3-70b": "llama3-70b-8192",
|
||||
"mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
};
|
||||
|
||||
Settings.llm = new Groq({
|
||||
model: modelMap[process.env.MODEL!],
|
||||
});
|
||||
|
||||
Settings.embedModel = new HuggingFaceEmbedding({
|
||||
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
|
||||
});
|
||||
}
|
||||
|
||||
function initAnthropic() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
|
||||
};
|
||||
Settings.llm = new Anthropic({
|
||||
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
|
||||
});
|
||||
Settings.embedModel = new HuggingFaceEmbedding({
|
||||
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
|
||||
});
|
||||
}
|
||||
|
||||
function initGemini() {
|
||||
Settings.llm = new Gemini({
|
||||
model: process.env.MODEL as GEMINI_MODEL,
|
||||
});
|
||||
Settings.embedModel = new GeminiEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
|
||||
});
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user