Compare commits

..

7 Commits

Author SHA1 Message Date
leehuwuj c3215ccc7b better log 2024-05-02 15:23:06 +07:00
leehuwuj 18ca18123f split code to run_ingestion_pipeline and persist_storage 2024-05-02 15:18:40 +07:00
leehuwuj 5ecb0c9fb7 update comments and remove stores_index 2024-05-02 14:15:56 +07:00
leehuwuj 7e45f604e6 Fix dimensions typo in settings.py 2024-05-02 10:45:58 +07:00
leehuwuj bbacf0f199 refactor code and comments 2024-05-02 10:43:54 +07:00
leehuwuj c0c6df80c7 fix redundant stashed code 2024-05-02 09:25:05 +07:00
leehuwuj 3b39a12ad6 Refactor code to persist the docstore and index in the SimpleVectorStore case 2024-05-02 08:50:09 +07:00
112 changed files with 970 additions and 4106 deletions
+5
View File
@@ -0,0 +1,5 @@
---
"create-llama": patch
---
Use ingestion pipeline for Python
+5
View File
@@ -0,0 +1,5 @@
---
"create-llama": patch
---
Display events (e.g. retrieving nodes) per chat message
+1 -1
View File
@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
-64
View File
@@ -1,69 +1,5 @@
# create-llama
## 0.1.9
### Patch Changes
- a42fa53: Add CSV upload
- 563b51d: Fix Vercel streaming (python) to stream data events instantly
- d60b3c5: Add E2B code interpreter tool for FastAPI
- 956538e: Add OpenAPI action tool for FastAPI
## 0.1.8
### Patch Changes
- cd50a33: Add interpreter tool for TS using e2b.dev
## 0.1.7
### Patch Changes
- 260d37a: Add system prompt env variable for TS
- bbd5b8d: Fix postgres connection leaking issue
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
- 69c2e16: Fix streaming for Express
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
## 0.1.6
### Patch Changes
- 56537a1: Display PDF files in source nodes
## 0.1.5
### Patch Changes
- 84db798: feat: support display latex in chat markdown
## 0.1.4
### Patch Changes
- 0bc8e75: Use ingestion pipeline for dedicated vector stores (Python only)
- cb1001d: Add ChromaDB vector store
## 0.1.3
### Patch Changes
- 416073d: Directly import vector stores to work with NextJS
## 0.1.2
### Patch Changes
- 056e376: Add support for displaying tool outputs (including weather widget as example)
## 0.1.1
### Patch Changes
- 7bd3ed5: Support Anthropic and Gemini as model providers
- 7bd3ed5: Support new agents from LITS 0.3
- cfb5257: Display events (e.g. retrieving nodes) per chat message
## 0.1.0
### Minor Changes
+33 -147
View File
@@ -1,6 +1,5 @@
import fs from "fs/promises";
import path from "path";
import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools";
import {
ModelConfig,
TemplateDataSource,
@@ -8,7 +7,7 @@ import {
TemplateVectorDB,
} from "./types";
export type EnvVar = {
type EnvVar = {
name?: string;
description?: string;
value?: string;
@@ -30,20 +29,17 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
);
};
const getVectorDBEnvs = (
vectorDb?: TemplateVectorDB,
framework?: TemplateFramework,
): EnvVar[] => {
if (!vectorDb || !framework) {
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
if (!vectorDb) {
return [];
}
switch (vectorDb) {
case "mongo":
return [
{
name: "MONGODB_URI",
name: "MONGO_URI",
description:
"For generating a connection URI, see https://www.mongodb.com/docs/manual/reference/connection-string/ \nThe MongoDB connection URI.",
"For generating a connection URI, see https://docs.timescale.com/use-timescale/latest/services/create-a-service\nThe MongoDB connection URI.",
},
{
name: "MONGODB_DATABASE",
@@ -133,31 +129,6 @@ const getVectorDBEnvs = (
"Optional API key for authenticating requests to Qdrant.",
},
];
case "chroma":
const envs = [
{
name: "CHROMA_COLLECTION",
description: "The name of the collection in your Chroma database",
},
{
name: "CHROMA_HOST",
description: "The API endpoint for your Chroma database",
},
{
name: "CHROMA_PORT",
description: "The port for your Chroma database",
},
];
// TS Version doesn't support config local storage path
if (framework === "fastapi") {
envs.push({
name: "CHROMA_PATH",
description: `The local path to the Chroma database.
Specify this if you are using a local Chroma database.
Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
});
}
return envs;
default:
return [];
}
@@ -185,10 +156,6 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
description: "Dimension of the embedding model to use.",
value: modelConfig.dimensions.toString(),
},
{
name: "CONVERSATION_STARTERS",
description: "The questions to help users get started (multi-line).",
},
...(modelConfig.provider === "openai"
? [
{
@@ -206,70 +173,41 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "anthropic"
? [
{
name: "ANTHROPIC_API_KEY",
description: "The Anthropic API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "gemini"
? [
{
name: "GOOGLE_API_KEY",
description: "The Google API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "ollama"
? [
{
name: "OLLAMA_BASE_URL",
description:
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
},
]
: []),
];
};
const getFrameworkEnvs = (
framework: TemplateFramework,
framework?: TemplateFramework,
port?: number,
): EnvVar[] => {
const sPort = port?.toString() || "8000";
const result: EnvVar[] = [
if (framework !== "fastapi") {
return [];
}
return [
{
name: "FILESERVER_URL_PREFIX",
description:
"FILESERVER_URL_PREFIX is the URL prefix of the server storing the images generated by the interpreter.",
value:
framework === "nextjs"
? // FIXME: if we are using nextjs, port should be 3000
"http://localhost:3000/api/files"
: `http://localhost:${sPort}/api/files`,
name: "APP_HOST",
description: "The address to start the backend app.",
value: "0.0.0.0",
},
{
name: "APP_PORT",
description: "The port to start the backend app.",
value: port?.toString() || "8000",
},
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
{
name: "SYSTEM_PROMPT",
description: `Custom system prompt.
Example:
SYSTEM_PROMPT="
We have provided context information below.
---------------------
{context_str}
---------------------
Given this information, please answer the question: {query_str}
"`,
},
];
if (framework === "fastapi") {
result.push(
...[
{
name: "APP_HOST",
description: "The address to start the backend app.",
value: "0.0.0.0",
},
{
name: "APP_PORT",
description: "The port to start the backend app.",
value: sPort,
},
],
);
}
return result;
};
const getEngineEnvs = (): EnvVar[] => {
@@ -280,68 +218,18 @@ const getEngineEnvs = (): EnvVar[] => {
"The number of similar embeddings to return when retrieving documents.",
value: "3",
},
{
name: "STREAM_TIMEOUT",
description:
"The time in milliseconds to wait for the stream to return a response.",
value: "60000",
},
];
};
const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
if (!tools?.length) return [];
const toolEnvs: EnvVar[] = [];
tools.forEach((tool) => {
if (tool.envVars?.length) {
toolEnvs.push(
// Don't include the system prompt env var here
// It should be handled separately by merging with the default system prompt
...tool.envVars.filter(
(env) => env.name !== TOOL_SYSTEM_PROMPT_ENV_VAR,
),
);
}
});
return toolEnvs;
};
const getSystemPromptEnv = (tools?: Tool[]): EnvVar => {
const defaultSystemPrompt =
"You are a helpful assistant who helps users with their questions.";
// build tool system prompt by merging all tool system prompts
let toolSystemPrompt = "";
tools?.forEach((tool) => {
const toolSystemPromptEnv = tool.envVars?.find(
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
);
if (toolSystemPromptEnv) {
toolSystemPrompt += toolSystemPromptEnv.value + "\n";
}
});
const systemPrompt = toolSystemPrompt
? `\"${toolSystemPrompt}\"`
: defaultSystemPrompt;
return {
name: "SYSTEM_PROMPT",
description: "The system prompt for the AI model.",
value: systemPrompt,
};
};
export const createBackendEnvFile = async (
root: string,
opts: {
llamaCloudKey?: string;
vectorDb?: TemplateVectorDB;
modelConfig: ModelConfig;
framework: TemplateFramework;
framework?: TemplateFramework;
dataSources?: TemplateDataSource[];
port?: number;
tools?: Tool[];
},
) => {
// Init env values
@@ -357,10 +245,8 @@ export const createBackendEnvFile = async (
// Add engine environment variables
...getEngineEnvs(),
// Add vector database environment variables
...getVectorDBEnvs(opts.vectorDb, opts.framework),
...getVectorDBEnvs(opts.vectorDb),
...getFrameworkEnvs(opts.framework, opts.port),
...getToolEnvs(opts.tools),
getSystemPromptEnv(opts.tools),
];
// Render and write env file
const content = renderEnvVar(envVars);
+2 -7
View File
@@ -9,6 +9,7 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
import { PackageManager } from "./get-pkg-manager";
import { installLlamapackProject } from "./llama-pack";
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
import { isModelConfigured } from "./providers";
import { installPythonTemplate } from "./python";
import { downloadAndExtractRepo } from "./repo";
import { ConfigFileType, writeToolsConfig } from "./tools";
@@ -37,7 +38,7 @@ async function generateContextData(
? "poetry run generate"
: `${packageManager} run generate`,
)}`;
const modelConfigured = modelConfig.isConfigured();
const modelConfigured = isModelConfigured(modelConfig);
const llamaCloudKeyConfigured = useLlamaParse
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
: true;
@@ -148,7 +149,6 @@ export const installTemplate = async (
framework: props.framework,
dataSources: props.dataSources,
port: props.externalPort,
tools: props.tools,
});
if (props.dataSources.length > 0) {
@@ -171,11 +171,6 @@ export const installTemplate = async (
);
}
}
// Create tool-output directory
if (props.tools && props.tools.length > 0) {
await fsExtra.mkdir(path.join(props.root, "tool-output"));
}
} else {
// this is a frontend for a full-stack app, create .env file with model information
await createFrontendEnvFile(props.root, {
-106
View File
@@ -1,106 +0,0 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = [
"claude-3-opus",
"claude-3-sonnet",
"claude-3-haiku",
"claude-2.1",
"claude-instant-1.2",
];
const DEFAULT_MODEL = MODELS[0];
// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported)
// Use huggingface embedding models for now
enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
}
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
dimensions: 384,
},
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
dimensions: 768,
},
};
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type AnthropicQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askAnthropicQuestions({
askModels,
apiKey,
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["ANTHROPIC_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.ANTHROPIC_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions =
EMBEDDING_MODELS[
embeddingModel as HuggingFaceEmbeddingModelType
].dimensions;
}
return config;
}
-87
View File
@@ -1,87 +0,0 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"];
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<string, ModelData> = {
"embedding-001": { dimensions: 768 },
"text-embedding-004": { dimensions: 768 },
};
const DEFAULT_MODEL = MODELS[0];
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type GeminiQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askGeminiQuestions({
askModels,
apiKey,
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["GOOGLE_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.GOOGLE_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
}
return config;
}
+10 -11
View File
@@ -2,10 +2,8 @@ import ciInfo from "ci-info";
import prompts from "prompts";
import { questionHandlers } from "../../questions";
import { ModelConfig, ModelProvider } from "../types";
import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions } from "./openai";
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
const DEFAULT_MODEL_PROVIDER = "openai";
@@ -33,8 +31,6 @@ export async function askModelConfig({
value: "openai",
},
{ title: "Ollama", value: "ollama" },
{ title: "Anthropic", value: "anthropic" },
{ title: "Gemini", value: "gemini" },
],
initial: 0,
},
@@ -48,12 +44,6 @@ export async function askModelConfig({
case "ollama":
modelConfig = await askOllamaQuestions({ askModels });
break;
case "anthropic":
modelConfig = await askAnthropicQuestions({ askModels });
break;
case "gemini":
modelConfig = await askGeminiQuestions({ askModels });
break;
default:
modelConfig = await askOpenAIQuestions({
openAiKey,
@@ -65,3 +55,12 @@ export async function askModelConfig({
provider: modelProvider,
};
}
export function isModelConfigured(modelConfig: ModelConfig): boolean {
switch (modelConfig.provider) {
case "openai":
return isOpenAIConfigured(modelConfig);
default:
return true;
}
}
-3
View File
@@ -29,9 +29,6 @@ export async function askOllamaQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
isConfigured(): boolean {
return true;
},
};
// use default model values in CI or if user should not be asked
+12 -10
View File
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
const OPENAI_API_URL = "https://api.openai.com/v1";
const DEFAULT_MODEL = "gpt-3.5-turbo";
const DEFAULT_MODEL = "gpt-4-turbo";
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
export async function askOpenAIQuestions({
@@ -20,15 +20,6 @@ export async function askOpenAIQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
@@ -40,6 +31,7 @@ export async function askOpenAIQuestions({
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
: "Please provide your OpenAI API key (leave blank to skip):",
validate: (value: string) => {
console.log(value);
if (askModels && !value) {
if (process.env.OPENAI_API_KEY) {
return true;
@@ -86,6 +78,16 @@ export async function askOpenAIQuestions({
return config;
}
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
if (params.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
}
async function getAvailableModelChoices(
selectEmbedding: boolean,
apiKey?: string,
-8
View File
@@ -1,8 +0,0 @@
/* Function to conditionally load the global-agent/bootstrap module */
export async function initializeGlobalAgent() {
if (process.env.GLOBAL_AGENT_HTTP_PROXY) {
/* Dynamically import global-agent/bootstrap */
await import("global-agent/bootstrap");
console.log("Proxy enabled via global-agent.");
}
}
+34 -79
View File
@@ -24,7 +24,7 @@ interface Dependency {
const getAdditionalDependencies = (
modelConfig: ModelConfig,
vectorDb?: TemplateVectorDB,
dataSources?: TemplateDataSource[],
dataSource?: TemplateDataSource,
tools?: Tool[],
) => {
const dependencies: Dependency[] = [];
@@ -43,7 +43,6 @@ const getAdditionalDependencies = (
name: "llama-index-vector-stores-postgres",
version: "^0.1.1",
});
break;
}
case "pinecone": {
dependencies.push({
@@ -70,60 +69,41 @@ const getAdditionalDependencies = (
});
break;
}
case "qdrant": {
dependencies.push({
name: "llama-index-vector-stores-qdrant",
version: "^0.2.8",
});
break;
}
case "chroma": {
dependencies.push({
name: "llama-index-vector-stores-chroma",
version: "^0.1.8",
});
break;
}
}
// Add data source dependencies
if (dataSources) {
for (const ds of dataSources) {
const dsType = ds?.type;
switch (dsType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
}
}
const dataSourceType = dataSource?.type;
switch (dataSourceType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
}
// Add tools dependencies
console.log("Adding tools dependencies");
tools?.forEach((tool) => {
tool.dependencies?.forEach((dep) => {
dependencies.push(dep);
@@ -144,27 +124,7 @@ const getAdditionalDependencies = (
case "openai":
dependencies.push({
name: "llama-index-agent-openai",
version: "0.2.6",
});
break;
case "anthropic":
dependencies.push({
name: "llama-index-llms-anthropic",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-huggingface",
version: "0.2.0",
});
break;
case "gemini":
dependencies.push({
name: "llama-index-llms-gemini",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-gemini",
version: "0.1.6",
version: "0.2.2",
});
break;
}
@@ -318,14 +278,9 @@ export const installPythonTemplate = async ({
cwd: path.join(compPath, "engines", "python", engine),
});
console.log("Adding additional dependencies");
const addOnDependencies = getAdditionalDependencies(
modelConfig,
vectorDb,
dataSources,
tools,
);
const addOnDependencies = dataSources
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
.flat();
if (observability === "opentelemetry") {
addOnDependencies.push({
+5 -121
View File
@@ -2,25 +2,15 @@ import fs from "fs/promises";
import path from "path";
import { red } from "picocolors";
import yaml from "yaml";
import { EnvVar } from "./env-variables";
import { makeDir } from "./make-dir";
import { TemplateFramework } from "./types";
export const TOOL_SYSTEM_PROMPT_ENV_VAR = "TOOL_SYSTEM_PROMPT";
export enum ToolType {
LLAMAHUB = "llamahub",
LOCAL = "local",
}
export type Tool = {
display: string;
name: string;
config?: Record<string, any>;
dependencies?: ToolDependencies[];
supportedFrameworks?: Array<TemplateFramework>;
type: ToolType;
envVars?: EnvVar[];
};
export type ToolDependencies = {
@@ -30,7 +20,7 @@ export type ToolDependencies = {
export const supportedTools: Tool[] = [
{
display: "Google Search",
display: "Google Search (configuration required after installation)",
name: "google.GoogleSearchToolSpec",
config: {
engine:
@@ -45,14 +35,6 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi"],
type: ToolType.LLAMAHUB,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for google search tool.",
value: `You are a Google search agent. You help users to get information from Google search.`,
},
],
},
{
display: "Wikipedia",
@@ -64,89 +46,6 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LLAMAHUB,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for wiki tool.",
value: `You are a Wikipedia agent. You help users to get information from Wikipedia.`,
},
],
},
{
display: "Weather",
name: "weather",
dependencies: [],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for weather tool.",
value: `You are a weather forecast agent. You help users to get the weather forecast for a given location.`,
},
],
},
{
display: "Code Interpreter",
name: "interpreter",
dependencies: [
{
name: "e2b_code_interpreter",
version: "0.0.7",
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: "E2B_API_KEY",
description:
"E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
},
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for code interpreter tool.",
value: `You are a Python interpreter.
- You are given tasks to complete and you run python code to solve them.
- The python code runs in a Jupyter notebook. Every time you call \`interpreter\` tool, the python code is executed in a separate cell. It's okay to make multiple calls to \`interpreter\`.
- Display visualizations using matplotlib or any other visualization library directly in the notebook. Shouldn't save the visualizations to a file, just return the base64 encoded data.
- You can install any pip package (if it exists) if you need to but the usual packages for data analysis are already preinstalled.
- You can run any python code you want in a secure environment.
- Use absolute url from result to display images or any other media.`,
},
],
},
{
display: "OpenAPI action",
name: "openapi_action.OpenAPIActionToolSpec",
dependencies: [
{
name: "llama-index-tools-openapi",
version: "0.1.3",
},
{
name: "jsonschema",
version: "^4.22.0",
},
{
name: "llama-index-tools-requests",
version: "0.1.3",
},
],
config: {
openapi_uri: "The URL or file path of the OpenAPI schema",
},
supportedFrameworks: ["fastapi"],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for openapi action tool.",
value:
"You are an OpenAPI action agent. You help users to make requests to the provided OpenAPI schema.",
},
],
},
];
@@ -173,15 +72,9 @@ export const getTools = (toolsName: string[]): Tool[] => {
return tools;
};
export const toolRequiresConfig = (tool: Tool): boolean => {
const hasConfig = Object.keys(tool.config || {}).length > 0;
const hasEmptyEnvVar = tool.envVars?.some((envVar) => !envVar.value) ?? false;
return hasConfig || hasEmptyEnvVar;
};
export const toolsRequireConfig = (tools?: Tool[]): boolean => {
if (tools) {
return tools?.some(toolRequiresConfig);
return tools?.some((tool) => Object.keys(tool.config || {}).length > 0);
}
return false;
};
@@ -196,19 +89,10 @@ export const writeToolsConfig = async (
tools: Tool[] = [],
type: ConfigFileType = ConfigFileType.YAML,
) => {
const configContent: {
[key in ToolType]: Record<string, any>;
} = {
local: {},
llamahub: {},
};
if (tools.length === 0) return; // no tools selected, no config need
const configContent: Record<string, any> = {};
tools.forEach((tool) => {
if (tool.type === ToolType.LLAMAHUB) {
configContent.llamahub[tool.name] = tool.config ?? {};
}
if (tool.type === ToolType.LOCAL) {
configContent.local[tool.name] = tool.config ?? {};
}
configContent[tool.name] = tool.config ?? {};
});
const configPath = path.join(root, "config");
await makeDir(configPath);
+2 -4
View File
@@ -1,14 +1,13 @@
import { PackageManager } from "../helpers/get-pkg-manager";
import { Tool } from "./tools";
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
export type ModelProvider = "openai" | "ollama";
export type ModelConfig = {
provider: ModelProvider;
apiKey?: string;
model: string;
embeddingModel: string;
dimensions: number;
isConfigured(): boolean;
};
export type TemplateType = "streaming" | "community" | "llamapack";
export type TemplateFramework = "nextjs" | "express" | "fastapi";
@@ -20,8 +19,7 @@ export type TemplateVectorDB =
| "pinecone"
| "milvus"
| "astra"
| "qdrant"
| "chroma";
| "qdrant";
export type TemplatePostInstallAction =
| "none"
| "VSCode"
+1 -1
View File
@@ -105,7 +105,7 @@ export const installTSTemplate = async ({
const enginePath = path.join(root, relativeEngineDestPath, "engine");
// copy vector db component
console.log("\nUsing vector DB:", vectorDb ?? "none", "\n");
console.log("\nUsing vector DB:", vectorDb, "\n");
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
-4
View File
@@ -12,16 +12,12 @@ import { createApp } from "./create-app";
import { getDataSources } from "./helpers/datasources";
import { getPkgManager } from "./helpers/get-pkg-manager";
import { isFolderEmpty } from "./helpers/is-folder-empty";
import { initializeGlobalAgent } from "./helpers/proxy";
import { runApp } from "./helpers/run-app";
import { getTools } from "./helpers/tools";
import { validateNpmName } from "./helpers/validate-pkg";
import packageJson from "./package.json";
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
// Run the initialization function
initializeGlobalAgent();
let projectPath: string = "";
const handleSigTerm = () => process.exit(0);
+1 -2
View File
@@ -1,6 +1,6 @@
{
"name": "create-llama",
"version": "0.1.9",
"version": "0.1.0",
"description": "Create LlamaIndex-powered apps with one command",
"keywords": [
"rag",
@@ -52,7 +52,6 @@
"cross-spawn": "7.0.3",
"fast-glob": "3.3.1",
"fs-extra": "11.2.0",
"global-agent": "^3.0.0",
"got": "10.7.0",
"ollama": "^0.5.0",
"ora": "^8.0.1",
+147 -267
View File
File diff suppressed because it is too large Load Diff
+7 -11
View File
@@ -14,13 +14,9 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
import { EXAMPLE_FILE } from "./helpers/datasources";
import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { askModelConfig } from "./helpers/providers";
import { askModelConfig, isModelConfigured } from "./helpers/providers";
import { getProjectOptions } from "./helpers/repo";
import {
supportedTools,
toolRequiresConfig,
toolsRequireConfig,
} from "./helpers/tools";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
export type QuestionArgs = Omit<
InstallAppArgs,
@@ -101,7 +97,6 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
{ title: "Milvus", value: "milvus" },
{ title: "Astra", value: "astra" },
{ title: "Qdrant", value: "qdrant" },
{ title: "ChromaDB", value: "chroma" },
];
const vectordbLang = framework === "fastapi" ? "python" : "typescript";
@@ -262,8 +257,7 @@ export const askQuestions = async (
},
];
const modelConfigured =
!program.llamapack && program.modelConfig.isConfigured();
const modelConfigured = isModelConfigured(program.modelConfig);
// If using LlamaParse, require LlamaCloud API key
const llamaCloudKeyConfigured = program.useLlamaParse
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
@@ -274,7 +268,8 @@ export const askQuestions = async (
!hasVectorDb &&
modelConfigured &&
llamaCloudKeyConfigured &&
!toolsRequireConfig(program.tools)
!toolsRequireConfig(program.tools) &&
!program.llamapack
) {
actionChoices.push({
title:
@@ -403,6 +398,7 @@ export const askQuestions = async (
if (program.framework === "express" || program.framework === "fastapi") {
// if a backend-only framework is selected, ask whether we should create a frontend
// (only for streaming backends)
if (program.frontend === undefined) {
if (ciInfo.isCI) {
program.frontend = getPrefOrDefault("frontend");
@@ -656,7 +652,7 @@ export const askQuestions = async (
t.supportedFrameworks?.includes(program.framework),
);
const toolChoices = options.map((tool) => ({
title: `${tool.display}${toolRequiresConfig(tool) ? "" : " (no config needed)"}`,
title: tool.display,
value: tool.name,
}));
const { toolsName } = await prompts({
@@ -0,0 +1,35 @@
import os
import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolFactory:
@staticmethod
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
try:
tool_package, tool_cls_name = tool_name.split(".")
module_name = f"llama_index.tools.{tool_package}"
module = importlib.import_module(module_name)
tool_class = getattr(module, tool_cls_name)
tool_spec: BaseToolSpec = tool_class(**kwargs)
return tool_spec.to_tool_list()
except (ImportError, AttributeError) as e:
raise ValueError(f"Unsupported tool: {tool_name}") from e
except TypeError as e:
raise ValueError(
f"Could not create tool: {tool_name}. With config: {kwargs}"
) from e
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for name, config in tool_configs.items():
tools += ToolFactory.create_tool(name, **config)
return tools
@@ -1,65 +0,0 @@
import os
import yaml
import json
import importlib
from cachetools import cached, LRUCache
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolType:
LLAMAHUB = "llamahub"
LOCAL = "local"
class ToolFactory:
TOOL_SOURCE_PACKAGE_MAP = {
ToolType.LLAMAHUB: "llama_index.tools",
ToolType.LOCAL: "app.engine.tools",
}
@staticmethod
@cached(
LRUCache(maxsize=100),
key=lambda tool_type, tool_name, config: (
tool_type,
tool_name,
json.dumps(config, sort_keys=True),
),
)
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
try:
if "ToolSpec" in tool_name:
tool_package, tool_cls_name = tool_name.split(".")
module_name = f"{source_package}.{tool_package}"
module = importlib.import_module(module_name)
tool_class = getattr(module, tool_cls_name)
tool_spec: BaseToolSpec = tool_class(**config)
return tool_spec.to_tool_list()
else:
module = importlib.import_module(f"{source_package}.{tool_name}")
tools = getattr(module, "tools")
if not all(isinstance(tool, FunctionTool) for tool in tools):
raise ValueError(
f"The module {module} does not contain valid tools"
)
return tools
except ImportError as e:
raise ValueError(f"Failed to import tool {tool_name}: {e}")
except AttributeError as e:
raise ValueError(f"Failed to load tool {tool_name}: {e}")
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for tool_type, config_entries in tool_configs.items():
for tool_name, config in config_entries.items():
tools.extend(
ToolFactory.load_tools(tool_type, tool_name, config)
)
return tools
@@ -1,150 +0,0 @@
import os
import logging
import base64
import uuid
from pydantic import BaseModel
from typing import List, Tuple, Dict, Optional
from llama_index.core.tools import FunctionTool
from e2b_code_interpreter import CodeInterpreter
from e2b_code_interpreter.models import Logs
logger = logging.getLogger(__name__)
class InterpreterExtraResult(BaseModel):
type: str
content: Optional[str] = None
filename: Optional[str] = None
url: Optional[str] = None
class E2BToolOutput(BaseModel):
is_error: bool
logs: Logs
results: List[InterpreterExtraResult] = []
class E2BCodeInterpreter:
output_dir = "tool-output"
def __init__(self, api_key: str, filesever_url_prefix: str):
self.api_key = api_key
self.filesever_url_prefix = filesever_url_prefix
def get_output_path(self, filename: str) -> str:
# if output directory doesn't exist, create it
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir, exist_ok=True)
return os.path.join(self.output_dir, filename)
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
buffer = base64.b64decode(base64_data)
output_path = self.get_output_path(filename)
try:
with open(output_path, "wb") as file:
file.write(buffer)
except IOError as e:
logger.error(f"Failed to write to file {output_path}: {str(e)}")
raise e
logger.info(f"Saved file to {output_path}")
return {
"outputPath": output_path,
"filename": filename,
}
def get_file_url(self, filename: str) -> str:
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
def parse_result(self, result) -> List[InterpreterExtraResult]:
"""
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
We save each result to disk and return saved file metadata (extension, filename, url)
"""
if not result:
return []
output = []
try:
formats = result.formats()
results = [result[format] for format in formats]
for ext, data in zip(formats, results):
match ext:
case "png" | "svg" | "jpeg" | "pdf":
result = self.save_to_disk(data, ext)
filename = result["filename"]
output.append(
InterpreterExtraResult(
type=ext,
filename=filename,
url=self.get_file_url(filename),
)
)
case _:
output.append(
InterpreterExtraResult(
type=ext,
content=data,
)
)
except Exception as error:
logger.exception(error, exc_info=True)
logger.error("Error when parsing output from E2b interpreter tool", error)
return output
def interpret(self, code: str) -> E2BToolOutput:
with CodeInterpreter(api_key=self.api_key) as interpreter:
logger.info(
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
)
exec = interpreter.notebook.exec_cell(code)
if exec.error:
logger.error("Error when executing code", exec.error)
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
else:
if len(exec.results) == 0:
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
else:
results = self.parse_result(exec.results[0])
output = E2BToolOutput(
is_error=False, logs=exec.logs, results=results
)
return output
def code_interpret(code: str) -> Dict:
"""
Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.
Parameters:
code (str): The python code to be executed in a single cell.
"""
api_key = os.getenv("E2B_API_KEY")
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not api_key:
raise ValueError(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
)
if not filesever_url_prefix:
raise ValueError(
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
)
interpreter = E2BCodeInterpreter(
api_key=api_key, filesever_url_prefix=filesever_url_prefix
)
output = interpreter.interpret(code)
return output.dict()
# Specify as functions tools to be loaded by the ToolFactory
tools = [FunctionTool.from_defaults(code_interpret)]
@@ -1,71 +0,0 @@
from typing import Dict, List, Tuple
from llama_index.tools.openapi import OpenAPIToolSpec
from llama_index.tools.requests import RequestsToolSpec
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
"""
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
openapi_uri: str: The file path or URL to the OpenAPI spec.
domain_headers: dict: Whitelist domains and the headers to use.
"""
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
def __init__(self, openapi_uri: str, domain_headers: dict = {}, **kwargs):
# Load the OpenAPI spec
openapi_spec, servers = self.load_openapi_spec(openapi_uri)
# Add the servers to the domain headers if they are not already present
for server in servers:
if server not in domain_headers:
domain_headers[server] = {}
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
RequestsToolSpec.__init__(self, domain_headers)
@staticmethod
def load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
"""
Load an OpenAPI spec from a URI.
Args:
uri (str): A file path or URL to the OpenAPI spec.
Returns:
List[Document]: A list of Document objects.
"""
import yaml
from urllib.parse import urlparse
if uri.startswith("http"):
import requests
response = requests.get(uri)
if response.status_code != 200:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: "
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
)
spec = yaml.safe_load(response.text)
elif uri.startswith("file"):
filepath = urlparse(uri).path
with open(filepath, "r") as file:
spec = yaml.safe_load(file)
else:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
"Only HTTP and file path are supported."
)
# Add the servers to the whitelist
try:
servers = [
urlparse(server["url"]).netloc for server in spec.get("servers", [])
]
except KeyError as e:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
"Could not get `servers` from the spec."
) from e
return spec, servers
@@ -1,72 +0,0 @@
"""Open Meteo weather map tool spec."""
import logging
import requests
import pytz
from llama_index.core.tools import FunctionTool
logger = logging.getLogger(__name__)
class OpenMeteoWeather:
geo_api = "https://geocoding-api.open-meteo.com/v1"
weather_api = "https://api.open-meteo.com/v1"
@classmethod
def _get_geo_location(cls, location: str) -> dict:
"""Get geo location from location name."""
params = {"name": location, "count": 10, "language": "en", "format": "json"}
response = requests.get(f"{cls.geo_api}/search", params=params)
if response.status_code != 200:
raise Exception(f"Failed to fetch geo location: {response.status_code}")
else:
data = response.json()
result = data["results"][0]
geo_location = {
"id": result["id"],
"name": result["name"],
"latitude": result["latitude"],
"longitude": result["longitude"],
}
return geo_location
@classmethod
def get_weather_information(cls, location: str) -> dict:
"""Use this function to get the weather of any given location.
Note that the weather code should follow WMO Weather interpretation codes (WW):
0: Clear sky
1, 2, 3: Mainly clear, partly cloudy, and overcast
45, 48: Fog and depositing rime fog
51, 53, 55: Drizzle: Light, moderate, and dense intensity
56, 57: Freezing Drizzle: Light and dense intensity
61, 63, 65: Rain: Slight, moderate and heavy intensity
66, 67: Freezing Rain: Light and heavy intensity
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
77: Snow grains
80, 81, 82: Rain showers: Slight, moderate, and violent
85, 86: Snow showers slight and heavy
95: Thunderstorm: Slight or moderate
96, 99: Thunderstorm with slight and heavy hail
"""
logger.info(
f"Calling open-meteo api to get weather information of location: {location}"
)
geo_location = cls._get_geo_location(location)
timezone = pytz.timezone("UTC").zone
params = {
"latitude": geo_location["latitude"],
"longitude": geo_location["longitude"],
"current": "temperature_2m,weather_code",
"hourly": "temperature_2m,weather_code",
"daily": "weather_code",
"timezone": timezone,
}
response = requests.get(f"{cls.weather_api}/forecast", params=params)
if response.status_code != 200:
raise Exception(
f"Failed to fetch weather information: {response.status_code}"
)
return response.json()
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
@@ -1,12 +1,12 @@
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import fs from "node:fs/promises";
import path from "node:path";
import { getDataSource } from "./index";
import { STORAGE_CACHE_DIR } from "./shared";
import { createTools } from "./tools";
export async function createChatEngine() {
const tools: BaseToolWithCall[] = [];
let tools: BaseTool[] = [];
// Add a query engine tool if we have a data source
// Delete this code if you don't have a data source
@@ -23,20 +23,15 @@ export async function createChatEngine() {
);
}
const configFile = path.join("config", "tools.json");
let toolConfig: any;
try {
// add tools from config file if it exists
toolConfig = JSON.parse(await fs.readFile(configFile, "utf8"));
} catch (e) {
console.info(`Could not read ${configFile} file. Using no tools.`);
}
if (toolConfig) {
tools.push(...(await createTools(toolConfig)));
}
const config = JSON.parse(
await fs.readFile(path.join("config", "tools.json"), "utf8"),
);
tools = tools.concat(await ToolsFactory.createTools(config));
} catch {}
return new OpenAIAgent({
tools,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
@@ -1,42 +0,0 @@
import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
import { WeatherTool, WeatherToolParams } from "./weather";
type ToolCreator = (config: unknown) => BaseToolWithCall;
export async function createTools(toolConfig: {
local: Record<string, unknown>;
llamahub: any;
}): Promise<BaseToolWithCall[]> {
// add local tools from the 'tools' folder (if configured)
const tools = createLocalTools(toolConfig.local);
// add tools from LlamaIndexTS (if configured)
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
return tools;
}
const toolFactory: Record<string, ToolCreator> = {
weather: (config: unknown) => {
return new WeatherTool(config as WeatherToolParams);
},
interpreter: (config: unknown) => {
return new InterpreterTool(config as InterpreterToolParams);
},
};
function createLocalTools(
localConfig: Record<string, unknown>,
): BaseToolWithCall[] {
const tools: BaseToolWithCall[] = [];
Object.keys(localConfig).forEach((key) => {
if (key in toolFactory) {
const toolConfig = localConfig[key];
const tool = toolFactory[key](toolConfig);
tools.push(tool);
}
});
return tools;
}
@@ -1,189 +0,0 @@
import { CodeInterpreter, Logs, Result } from "@e2b/code-interpreter";
import type { JSONSchemaType } from "ajv";
import fs from "fs";
import { BaseTool, ToolMetadata } from "llamaindex";
import crypto from "node:crypto";
import path from "node:path";
export type InterpreterParameter = {
code: string;
};
export type InterpreterToolParams = {
metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
apiKey?: string;
fileServerURLPrefix?: string;
};
export type InterpreterToolOutput = {
isError: boolean;
logs: Logs;
extraResult: InterpreterExtraResult[];
};
type InterpreterExtraType =
| "html"
| "markdown"
| "svg"
| "png"
| "jpeg"
| "pdf"
| "latex"
| "json"
| "javascript";
export type InterpreterExtraResult = {
type: InterpreterExtraType;
content?: string;
filename?: string;
url?: string;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
name: "interpreter",
description:
"Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.",
parameters: {
type: "object",
properties: {
code: {
type: "string",
description: "The python code to execute in a single cell.",
},
},
required: ["code"],
},
};
export class InterpreterTool implements BaseTool<InterpreterParameter> {
private readonly outputDir = "tool-output";
private apiKey?: string;
private fileServerURLPrefix?: string;
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
codeInterpreter?: CodeInterpreter;
constructor(params?: InterpreterToolParams) {
this.metadata = params?.metadata || DEFAULT_META_DATA;
this.apiKey = params?.apiKey || process.env.E2B_API_KEY;
this.fileServerURLPrefix =
params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX;
if (!this.apiKey) {
throw new Error(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key",
);
}
if (!this.fileServerURLPrefix) {
throw new Error(
"FILESERVER_URL_PREFIX is required to display file output from sandbox",
);
}
}
public async initInterpreter() {
if (!this.codeInterpreter) {
this.codeInterpreter = await CodeInterpreter.create({
apiKey: this.apiKey,
});
}
return this.codeInterpreter;
}
public async codeInterpret(code: string): Promise<InterpreterToolOutput> {
console.log(
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
);
const interpreter = await this.initInterpreter();
const exec = await interpreter.notebook.execCell(code);
if (exec.error) console.error("[Code Interpreter error]", exec.error);
const extraResult = await this.getExtraResult(exec.results[0]);
const result: InterpreterToolOutput = {
isError: !!exec.error,
logs: exec.logs,
extraResult,
};
return result;
}
async call(input: InterpreterParameter): Promise<InterpreterToolOutput> {
const result = await this.codeInterpret(input.code);
return result;
}
async close() {
await this.codeInterpreter?.close();
}
private async getExtraResult(
res?: Result,
): Promise<InterpreterExtraResult[]> {
if (!res) return [];
const output: InterpreterExtraResult[] = [];
try {
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
// save base64 data to file and return the url
for (let i = 0; i < formats.length; i++) {
const ext = formats[i];
const data = results[i];
switch (ext) {
case "png":
case "jpeg":
case "svg":
case "pdf":
const { filename } = this.saveToDisk(data, ext);
output.push({
type: ext as InterpreterExtraType,
filename,
url: this.getFileUrl(filename),
});
break;
default:
output.push({
type: ext as InterpreterExtraType,
content: data,
});
break;
}
}
} catch (error) {
console.error("Error when parsing e2b response", error);
}
return output;
}
// Consider saving to cloud storage instead but it may cost more for you
// See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file
private saveToDisk(
base64Data: string,
ext: string,
): {
outputPath: string;
filename: string;
} {
const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename
const buffer = Buffer.from(base64Data, "base64");
const outputPath = this.getOutputPath(filename);
fs.writeFileSync(outputPath, buffer);
console.log(`Saved file to ${outputPath}`);
return {
outputPath,
filename,
};
}
private getOutputPath(filename: string): string {
// if outputDir doesn't exist, create it
if (!fs.existsSync(this.outputDir)) {
fs.mkdirSync(this.outputDir, { recursive: true });
}
return path.join(this.outputDir, filename);
}
private getFileUrl(filename: string): string {
return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`;
}
}
@@ -1,81 +0,0 @@
import type { JSONSchemaType } from "ajv";
import { BaseTool, ToolMetadata } from "llamaindex";
interface GeoLocation {
id: string;
name: string;
latitude: number;
longitude: number;
}
export type WeatherParameter = {
location: string;
};
export type WeatherToolParams = {
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
name: "get_weather_information",
description: `
Use this function to get the weather of any given location.
Note that the weather code should follow WMO Weather interpretation codes (WW):
0: Clear sky
1, 2, 3: Mainly clear, partly cloudy, and overcast
45, 48: Fog and depositing rime fog
51, 53, 55: Drizzle: Light, moderate, and dense intensity
56, 57: Freezing Drizzle: Light and dense intensity
61, 63, 65: Rain: Slight, moderate and heavy intensity
66, 67: Freezing Rain: Light and heavy intensity
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
77: Snow grains
80, 81, 82: Rain showers: Slight, moderate, and violent
85, 86: Snow showers slight and heavy
95: Thunderstorm: Slight or moderate
96, 99: Thunderstorm with slight and heavy hail
`,
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The location to get the weather information",
},
},
required: ["location"],
},
};
export class WeatherTool implements BaseTool<WeatherParameter> {
metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;
private getGeoLocation = async (location: string): Promise<GeoLocation> => {
const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`;
const response = await fetch(apiUrl);
const data = await response.json();
const { id, name, latitude, longitude } = data.results[0];
return { id, name, latitude, longitude };
};
private getWeatherByLocation = async (location: string) => {
console.log(
"Calling open-meteo api to get weather information of location:",
location,
);
const { latitude, longitude } = await this.getGeoLocation(location);
const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&current=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`;
const response = await fetch(apiUrl);
const data = await response.json();
return data;
};
constructor(params?: WeatherToolParams) {
this.metadata = params?.metadata || DEFAULT_META_DATA;
}
async call(input: WeatherParameter) {
return await this.getWeatherByLocation(input.location);
}
}
@@ -16,6 +16,5 @@ export async function createChatEngine() {
return new ContextChatEngine({
chatModel: Settings.llm,
retriever,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
+5 -28
View File
@@ -1,10 +1,7 @@
import os
import logging
from llama_parse import LlamaParse
from pydantic import BaseModel, validator
logger = logging.getLogger(__name__)
class FileLoaderConfig(BaseModel):
data_dir: str = "data"
@@ -30,28 +27,8 @@ def llama_parse_parser():
def get_file_documents(config: FileLoaderConfig):
from llama_index.core.readers import SimpleDirectoryReader
try:
reader = SimpleDirectoryReader(
config.data_dir,
recursive=True,
filename_as_id=True,
)
if config.use_llama_parse:
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
return reader.load_data()
except ValueError as e:
import sys, traceback
# Catch the error if the data dir is empty
# and return as empty document list
_, _, exc_traceback = sys.exc_info()
function_name = traceback.extract_tb(exc_traceback)[-1].name
if function_name == "_add_files":
logger.warning(
f"Failed to load file documents, error message: {e} . Return as empty document list."
)
return []
else:
# Raise the error if it is not the case of empty data dir
raise e
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
if config.use_llama_parse:
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
return reader.load_data()
@@ -1,7 +1,5 @@
"use client";
import { Message } from "./chat-messages";
export interface ChatInputProps {
/** The current value of the input */
input?: string;
@@ -14,8 +12,7 @@ export interface ChatInputProps {
/** Form submission handler to automatically reset input and append a user message */
handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
isLoading: boolean;
messages: Message[];
setInput?: (input: string) => void;
multiModal?: boolean;
}
export default function ChatInput(props: ChatInputProps) {
@@ -19,9 +19,6 @@ export default function ChatMessages({
isLoading?: boolean;
stop?: () => void;
reload?: () => void;
append?: (
message: Message | Omit<Message, "id">,
) => Promise<string | null | undefined>;
}) {
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
@@ -1,30 +0,0 @@
"use client";
import { useEffect, useMemo, useState } from "react";
export interface ChatConfig {
chatAPI?: string;
starterQuestions?: string[];
}
export function useClientConfig() {
const API_ROUTE = "/api/chat/config";
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
const [config, setConfig] = useState<ChatConfig>({
chatAPI,
});
const configAPI = useMemo(() => {
const backendOrigin = chatAPI ? new URL(chatAPI).origin : "";
return `${backendOrigin}${API_ROUTE}`;
}, [chatAPI]);
useEffect(() => {
fetch(configAPI)
.then((response) => response.json())
.then((data) => setConfig({ ...data, chatAPI }))
.catch((error) => console.error("Error fetching config", error));
}, [chatAPI, configAPI]);
return config;
}
@@ -3,18 +3,10 @@ from llama_index.vector_stores.astra_db import AstraDBVectorStore
def get_vector_store():
endpoint = os.getenv("ASTRA_DB_ENDPOINT")
token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
collection = os.getenv("ASTRA_DB_COLLECTION")
if not endpoint or not token or not collection:
raise ValueError(
"Please config ASTRA_DB_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_COLLECTION"
" to your environment variables or config them in the .env file"
)
store = AstraDBVectorStore(
token=token,
api_endpoint=endpoint,
collection_name=collection,
embedding_dimension=int(os.getenv("EMBEDDING_DIM")),
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
collection_name=os.environ["ASTRA_DB_COLLECTION"],
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
)
return store
@@ -1,24 +0,0 @@
import os
from llama_index.vector_stores.chroma import ChromaVectorStore
def get_vector_store():
collection_name = os.getenv("CHROMA_COLLECTION", "default")
chroma_path = os.getenv("CHROMA_PATH")
# if CHROMA_PATH is set, use a local ChromaVectorStore from the path
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
if chroma_path:
store = ChromaVectorStore.from_params(
persist_dir=chroma_path, collection_name=collection_name
)
else:
if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
raise ValueError(
"Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
)
store = ChromaVectorStore.from_params(
host=os.getenv("CHROMA_HOST"),
port=int(os.getenv("CHROMA_PORT")),
collection_name=collection_name,
)
return store
@@ -3,18 +3,11 @@ from llama_index.vector_stores.milvus import MilvusVectorStore
def get_vector_store():
address = os.getenv("MILVUS_ADDRESS")
collection = os.getenv("MILVUS_COLLECTION")
if not address or not collection:
raise ValueError(
"Please set MILVUS_ADDRESS and MILVUS_COLLECTION to your environment variables"
" or config them in the .env file"
)
store = MilvusVectorStore(
uri=address,
uri=os.environ["MILVUS_ADDRESS"],
user=os.getenv("MILVUS_USERNAME"),
password=os.getenv("MILVUS_PASSWORD"),
collection_name=collection,
collection_name=os.getenv("MILVUS_COLLECTION"),
dim=int(os.getenv("EMBEDDING_DIM")),
)
return store
@@ -3,18 +3,9 @@ from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
def get_vector_store():
db_uri = os.getenv("MONGODB_URI")
db_name = os.getenv("MONGODB_DATABASE")
collection_name = os.getenv("MONGODB_VECTORS")
index_name = os.getenv("MONGODB_VECTOR_INDEX")
if not db_uri or not db_name or not collection_name or not index_name:
raise ValueError(
"Please set MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, and MONGODB_VECTOR_INDEX"
" to your environment variables or config them in .env file"
)
store = MongoDBAtlasVectorSearch(
db_name=db_name,
collection_name=collection_name,
index_name=index_name,
db_name=os.environ["MONGODB_DATABASE"],
collection_name=os.environ["MONGODB_VECTORS"],
index_name=os.environ["MONGODB_VECTOR_INDEX"],
)
return store
@@ -1,33 +0,0 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.indices import (
VectorStoreIndex,
)
from app.engine.loaders import get_documents
from app.settings import init_settings
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
storage_dir = os.environ.get("STORAGE_DIR", "storage")
# load the documents and create the index
documents = get_documents()
index = VectorStoreIndex.from_documents(
documents,
)
# store it for later
index.storage_context.persist(storage_dir)
logger.info(f"Finished creating new index. Stored in {storage_dir}")
if __name__ == "__main__":
generate_datasource()
@@ -1,30 +0,0 @@
import os
import logging
from datetime import timedelta
from cachetools import cached, TTLCache
from llama_index.core.storage import StorageContext
from llama_index.core.indices import load_index_from_storage
logger = logging.getLogger("uvicorn")
@cached(
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
key=lambda *args, **kwargs: "global_storage_context",
)
def get_storage_context(persist_dir: str) -> StorageContext:
return StorageContext.from_defaults(persist_dir=persist_dir)
def get_index():
storage_dir = os.getenv("STORAGE_DIR", "storage")
# check if storage already exists
if not os.path.exists(storage_dir):
return None
# load the existing index
logger.info(f"Loading index from {storage_dir}...")
storage_context = get_storage_context(storage_dir)
index = load_index_from_storage(storage_context)
logger.info(f"Finished loading index from {storage_dir}")
return index
@@ -0,0 +1,16 @@
import os
from llama_index.core.vector_stores import SimpleVectorStore
from app.constants import STORAGE_DIR
def get_vector_store():
if not os.path.exists(STORAGE_DIR):
# Vector store hasn't been persisted before, create a new one
vector_store = SimpleVectorStore()
else:
# Vector store has already been persisted before at STORAGE_DIR - load it
vector_store = SimpleVectorStore.from_persist_dir(
STORAGE_DIR, namespace="default"
)
return vector_store
@@ -2,36 +2,30 @@ import os
from llama_index.vector_stores.postgres import PGVectorStore
from urllib.parse import urlparse
STORAGE_DIR = "storage"
PGVECTOR_SCHEMA = "public"
PGVECTOR_TABLE = "llamaindex_embedding"
vector_store: PGVectorStore = None
def get_vector_store():
global vector_store
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
if vector_store is None:
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
vector_store = PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
)
return vector_store
return PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
)
@@ -3,17 +3,9 @@ from llama_index.vector_stores.pinecone import PineconeVectorStore
def get_vector_store():
api_key = os.getenv("PINECONE_API_KEY")
index_name = os.getenv("PINECONE_INDEX_NAME")
environment = os.getenv("PINECONE_ENVIRONMENT")
if not api_key or not index_name or not environment:
raise ValueError(
"Please set PINECONE_API_KEY, PINECONE_INDEX_NAME, and PINECONE_ENVIRONMENT"
" to your environment variables or config them in the .env file"
)
store = PineconeVectorStore(
api_key=api_key,
index_name=index_name,
environment=environment,
api_key=os.environ["PINECONE_API_KEY"],
index_name=os.environ["PINECONE_INDEX_NAME"],
environment=os.environ["PINECONE_ENVIRONMENT"],
)
return store
@@ -3,17 +3,9 @@ from llama_index.vector_stores.qdrant import QdrantVectorStore
def get_vector_store():
collection_name = os.getenv("QDRANT_COLLECTION")
url = os.getenv("QDRANT_URL")
api_key = os.getenv("QDRANT_API_KEY")
if not collection_name or not url:
raise ValueError(
"Please set QDRANT_COLLECTION, QDRANT_URL"
" to your environment variables or config them in the .env file"
)
store = QdrantVectorStore(
collection_name=collection_name,
url=url,
api_key=api_key,
collection_name=os.getenv("QDRANT_COLLECTION"),
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
return store
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
import {
AstraDBVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
@@ -1,6 +1,5 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
import { AstraDBVectorStore, VectorStoreIndex } from "llamaindex";
import { checkRequiredEnvVars } from "./shared";
export async function getDataSource() {
@@ -1,37 +0,0 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
dotenv.config();
async function loadAndIndex() {
// load objects from storage and convert them into LlamaIndex Document objects
const documents = await getDocuments();
// create vector store
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
const vectorStore = new ChromaVectorStore({
collectionName: process.env.CHROMA_COLLECTION,
chromaClientParams: { path: chromaUri },
});
// create index from all the Documentss and store them in Pinecone
console.log("Start creating embeddings...");
const storageContext = await storageContextFromDefaults({ vectorStore });
await VectorStoreIndex.fromDocuments(documents, { storageContext });
console.log(
"Successfully created embeddings and save to your ChromaDB index.",
);
}
(async () => {
checkRequiredEnvVars();
initSettings();
await loadAndIndex();
console.log("Finished generating storage.");
})();
@@ -1,16 +0,0 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
import { checkRequiredEnvVars } from "./shared";
export async function getDataSource() {
checkRequiredEnvVars();
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
const store = new ChromaVectorStore({
collectionName: process.env.CHROMA_COLLECTION,
chromaClientParams: { path: chromaUri },
});
return await VectorStoreIndex.fromVectorStore(store);
}
@@ -1,18 +0,0 @@
const REQUIRED_ENV_VARS = ["CHROMA_COLLECTION", "CHROMA_HOST", "CHROMA_PORT"];
export function checkRequiredEnvVars() {
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
return !process.env[envVar];
});
if (missingEnvVars.length > 0) {
console.log(
`The following environment variables are required but missing: ${missingEnvVars.join(
", ",
)}`,
);
throw new Error(
`Missing environment variables: ${missingEnvVars.join(", ")}`,
);
}
}
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
import {
MilvusVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
@@ -1,5 +1,4 @@
import { VectorStoreIndex } from "llamaindex";
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
import { MilvusVectorStore, VectorStoreIndex } from "llamaindex";
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
export async function getDataSource() {
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
import {
MongoDBAtlasVectorSearch,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { MongoClient } from "mongodb";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
@@ -9,7 +12,7 @@ import { checkRequiredEnvVars } from "./shared";
dotenv.config();
const mongoUri = process.env.MONGODB_URI!;
const mongoUri = process.env.MONGO_URI!;
const databaseName = process.env.MONGODB_DATABASE!;
const vectorCollectionName = process.env.MONGODB_VECTORS!;
const indexName = process.env.MONGODB_VECTOR_INDEX;
@@ -1,6 +1,5 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
import { MongoDBAtlasVectorSearch, VectorStoreIndex } from "llamaindex";
import { MongoClient } from "mongodb";
import { checkRequiredEnvVars } from "./shared";
@@ -1,5 +1,5 @@
const REQUIRED_ENV_VARS = [
"MONGODB_URI",
"MONGO_URI",
"MONGODB_DATABASE",
"MONGODB_VECTORS",
"MONGODB_VECTOR_INDEX",
@@ -1,5 +1,4 @@
import { VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import * as dotenv from "dotenv";
@@ -1,5 +1,8 @@
import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import {
SimpleDocumentStore,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { STORAGE_CACHE_DIR } from "./shared";
export async function getDataSource() {
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
import {
PGVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import {
@@ -1,6 +1,5 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
import { PGVectorStore, VectorStoreIndex } from "llamaindex";
import {
PGVECTOR_SCHEMA,
PGVECTOR_TABLE,
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
import {
PineconeVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
@@ -1,6 +1,5 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
import { PineconeVectorStore, VectorStoreIndex } from "llamaindex";
import { checkRequiredEnvVars } from "./shared";
export async function getDataSource() {
@@ -1,7 +1,10 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
import {
QdrantVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
@@ -15,10 +18,7 @@ async function loadAndIndex() {
const documents = await getDocuments();
// Connect to Qdrant
const vectorStore = new QdrantVectorStore({
collectionName,
client: getQdrantClient(),
});
const vectorStore = new QdrantVectorStore(collectionName, getQdrantClient());
const storageContext = await storageContextFromDefaults({ vectorStore });
await VectorStoreIndex.fromDocuments(documents, {
@@ -1,6 +1,5 @@
import * as dotenv from "dotenv";
import { VectorStoreIndex } from "llamaindex";
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
import { QdrantVectorStore, VectorStoreIndex } from "llamaindex";
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
dotenv.config();
@@ -8,10 +7,7 @@ dotenv.config();
export async function getDataSource() {
checkRequiredEnvVars();
const collectionName = process.env.QDRANT_COLLECTION;
const store = new QdrantVectorStore({
collectionName,
client: getQdrantClient(),
});
const store = new QdrantVectorStore(collectionName, getQdrantClient());
return await VectorStoreIndex.fromVectorStore(store);
}
+1 -3
View File
@@ -1,5 +1,3 @@
# local env files
.env
node_modules/
tool-output/
node_modules/
@@ -31,8 +31,6 @@ if (isDevelopment) {
console.warn("Production CORS origin not set, defaulting to no CORS.");
}
app.use("/api/files/data", express.static("data"));
app.use("/api/files/tool-output", express.static("tool-output"));
app.use(express.text());
app.get("/", (req: Request, res: Response) => {
@@ -1,23 +1,20 @@
{
"name": "llama-index-express-streaming",
"version": "1.0.0",
"main": "dist/index.js",
"main": "dist/index.mjs",
"scripts": {
"format": "prettier --ignore-unknown --cache --check .",
"format:write": "prettier --ignore-unknown --write .",
"build": "tsup index.ts --format cjs --dts",
"start": "node dist/index.js",
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
"build": "tsup index.ts --format esm --dts",
"start": "node dist/index.mjs",
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
},
"dependencies": {
"ai": "^3.0.21",
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "0.3.16",
"pdf2json": "3.0.5",
"ajv": "^8.12.0",
"@e2b/code-interpreter": "^0.0.5"
"llamaindex": "0.2.10"
},
"devDependencies": {
"@types/cors": "^2.8.16",
@@ -1,14 +0,0 @@
import { Request, Response } from "express";
export const chatConfig = async (_req: Request, res: Response) => {
let starterQuestions = undefined;
if (
process.env.CONVERSATION_STARTERS &&
process.env.CONVERSATION_STARTERS.trim()
) {
starterQuestions = process.env.CONVERSATION_STARTERS.trim().split("\n");
}
return res.status(200).json({
starterQuestions,
});
};
@@ -1,16 +1,32 @@
import { Message, StreamData, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, Settings } from "llamaindex";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream, convertMessageContent } from "./llamaindex-stream";
import { createCallbackManager, createStreamTimeout } from "./stream-helper";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
const convertMessageContent = (
textMessage: string,
imageUrl: string | undefined,
): MessageContent => {
if (!imageUrl) return textMessage;
return [
{
type: "text",
text: textMessage,
},
{
type: "image_url",
image_url: {
url: imageUrl,
},
},
];
};
export const chat = async (req: Request, res: Response) => {
// Init Vercel AI StreamData and timeout
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
try {
const { messages }: { messages: Message[] } = req.body;
const { messages, data }: { messages: Message[]; data: any } = req.body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return res.status(400).json({
@@ -21,47 +37,58 @@ export const chat = async (req: Request, res: Response) => {
const chatEngine = await createChatEngine();
let annotations = userMessage.annotations;
if (!annotations) {
// the user didn't send any new annotations with the last message
// so use the annotations from the last user message that has annotations
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
annotations = messages
.slice()
.reverse()
.find(
(message) => message.role === "user" && message.annotations,
)?.annotations;
}
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
const userMessageContent = convertMessageContent(
userMessage.content,
annotations,
data?.imageUrl,
);
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
// Return a stream, which can be consumed by the Vercel/AI client
const stream = LlamaIndexStream(response, vercelStreamData);
const { stream } = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
return streamToResponse(stream, res, {}, vercelStreamData);
// Pipe LlamaIndexStream to response
const processedStream = stream.pipeThrough(vercelStreamData.stream);
return streamToResponse(processedStream, res, {
headers: {
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
// so that the client uses the correct parsing logic, see
// https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server
"X-Experimental-Stream-Data": "true",
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
},
});
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
detail: (error as Error).message,
});
} finally {
clearTimeout(streamTimeout);
}
};
@@ -1,17 +1,10 @@
import {
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Ollama,
OllamaEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -19,21 +12,10 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -45,9 +27,7 @@ export const initSettings = async () => {
function initOpenAI() {
Settings.llm = new OpenAI({
model: process.env.MODEL ?? "gpt-3.5-turbo",
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
maxTokens: 512,
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
@@ -58,38 +38,15 @@ function initOpenAI() {
}
function initOllama() {
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
@@ -1,5 +1,4 @@
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
@@ -7,118 +6,44 @@ import {
type AIStreamCallbacksAndOptions,
} from "ai";
import {
MessageContent,
MessageContentDetail,
Metadata,
NodeWithScore,
Response,
ToolCallLLMMessageOptions,
StreamingAgentChatResponse,
} from "llamaindex";
import { appendImageData, appendSourceData } from "./stream-helper";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { CsvFile, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
export const convertMessageContent = (
content: string,
annotations?: JSONValue[],
): MessageContent => {
if (!annotations) return content;
return [
{
type: "text",
text: content,
},
...convertAnnotations(annotations),
];
};
const convertAnnotations = (
annotations: JSONValue[],
): MessageContentDetail[] => {
const content: MessageContentDetail[] = [];
annotations.forEach((annotation: JSONValue) => {
// first skip invalid annotation
if (
!(
annotation &&
typeof annotation === "object" &&
"type" in annotation &&
"data" in annotation &&
annotation.data &&
typeof annotation.data === "object"
)
) {
console.log(
"Client sent invalid annotation. Missing data and type",
annotation,
);
return;
}
const { type, data } = annotation;
// convert image
if (type === "image" && "url" in data && typeof data.url === "string") {
content.push({
type: "image_url",
image_url: {
url: data.url,
},
});
}
// convert CSV files to text
if (type === "csv" && "csvFiles" in data && Array.isArray(data.csvFiles)) {
const rawContents = data.csvFiles.map((csv) => {
return "```csv\n" + (csv as CsvFile).content + "\n```";
});
const csvContent =
"Use data from following CSV raw contents:\n" +
rawContents.join("\n\n");
content.push({
type: "text",
text: csvContent,
});
}
});
return content;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<LlamaIndexResponse>,
res: AsyncIterable<Response>,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
appendSourceData(data, sourceNodes);
controller.close();
data.close();
return;
}
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
const text = trimStartOfStream(delta ?? "");
const text = trimStartOfStream(value.response ?? "");
if (text) {
controller.enqueue(text);
}
@@ -127,13 +52,21 @@ function createParser(
}
export function LlamaIndexStream(
response: AsyncIterable<LlamaIndexResponse>,
response: StreamingAgentChatResponse | AsyncIterable<Response>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): ReadableStream<Uint8Array> {
return createParser(response, data)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
}
@@ -1,26 +1,14 @@
import { StreamData } from "ai";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
import { Metadata, NodeWithScore } from "llamaindex";
function getNodeUrl(metadata: Metadata) {
const url = metadata["URL"];
if (url) return url;
const fileName = metadata["file_name"];
if (!process.env.FILESERVER_URL_PREFIX) {
console.warn(
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
);
return undefined;
}
if (fileName) {
return `${process.env.FILESERVER_URL_PREFIX}/data/${fileName}`;
}
return undefined;
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
data.appendMessageAnnotation({
type: "image",
data: {
url: imageUrl,
},
});
}
export function appendSourceData(
@@ -35,7 +23,6 @@ export function appendSourceData(
...node.node.toMutableJSON(),
id: node.node.id_,
score: node.score ?? null,
url: getNodeUrl(node.node.metadata),
})),
},
});
@@ -50,71 +37,3 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
export function appendToolData(
data: StreamData,
toolCall: ToolCall,
toolOutput: ToolOutput,
) {
data.appendMessageAnnotation({
type: "tools",
data: {
toolCall: {
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
},
toolOutput: {
output: toolOutput.output,
isError: toolOutput.isError,
},
},
});
}
export function createStreamTimeout(stream: StreamData) {
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
const t = setTimeout(() => {
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
stream.close();
}, timeout);
return t;
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
callbackManager.on("retrieve", (data) => {
const { nodes, query } = data.detail;
appendEventData(stream, `Retrieving context for query: '${query}'`);
appendEventData(
stream,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
callbackManager.on("llm-tool-call", (event) => {
const { name, input } = event.detail.payload.toolCall;
const inputString = Object.entries(input)
.map(([key, value]) => `${key}: ${value}`)
.join(", ");
appendEventData(
stream,
`Using tool: '${name}' with inputs: '${inputString}'`,
);
});
callbackManager.on("llm-tool-result", (event) => {
const { toolCall, toolResult } = event.detail.payload;
appendToolData(stream, toolCall, toolResult);
});
return callbackManager;
}
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
@@ -1,5 +1,4 @@
import express, { Router } from "express";
import { chatConfig } from "../controllers/chat-config.controller";
import { chatRequest } from "../controllers/chat-request.controller";
import { chat } from "../controllers/chat.controller";
import { initSettings } from "../controllers/engine/settings";
@@ -9,6 +8,5 @@ const llmRouter: Router = express.Router();
initSettings();
llmRouter.route("/").post(chat);
llmRouter.route("/request").post(chatRequest);
llmRouter.route("/config").get(chatConfig);
export default llmRouter;
@@ -1,114 +1,154 @@
import os
import logging
from aiostream import stream
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.llms import MessageRole
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import (
ChatData,
ChatConfig,
SourceNodes,
Result,
Message,
)
from app.api.routers.messaging import EventCallbackHandler
from aiostream import stream
chat_router = r = APIRouter()
logger = logging.getLogger("uvicorn")
class _Message(BaseModel):
role: MessageRole
content: str
class _ChatData(BaseModel):
messages: List[_Message]
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
class _SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
return cls(
id=source_node.node.node_id,
metadata=source_node.node.metadata,
score=source_node.score,
text=source_node.node.text, # type: ignore
)
@classmethod
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
return [cls.from_source_node(node) for node in source_nodes]
class _Result(BaseModel):
result: _Message
nodes: List[_SourceNodes]
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
# check preconditions and get last message
if len(data.messages) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No messages provided",
)
last_message = data.messages.pop()
if last_message.role != MessageRole.USER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Last message must be from user",
)
# convert messages coming from the request to type ChatMessage
messages = [
ChatMessage(
role=m.role,
content=m.content,
)
for m in data.messages
]
return last_message.content, messages
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
request: Request,
data: ChatData,
data: _ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
try:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
last_message_content, messages = await parse_chat_data(data)
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
response = await chat_engine.astream_chat(last_message_content, messages)
async def content_generator():
# Yield the text response
async def _chat_response_generator():
response = await chat_engine.astream_chat(
last_message_content, messages
)
async for token in response.async_response_gen():
yield VercelStreamResponse.convert_text(token)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
async def content_generator():
# Yield the text response
async def _text_generator():
async for token in response.async_response_gen():
yield VercelStreamResponse.convert_text(token)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# Yield the source nodes
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
yield VercelStreamResponse.convert_data(
{
"type": "sources",
"data": {
"nodes": [
SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
"type": "events",
"data": {"title": event.get_title()},
}
)
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_text_generator(), _event_generator())
async with combine.stream() as streamer:
async for item in streamer:
if await request.is_disconnected():
break
yield item
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
async with combine.stream() as streamer:
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# Stream a blank message to start the stream
yield VercelStreamResponse.convert_text("")
# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"type": "sources",
"data": {
"nodes": [
_SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
}
)
yield output
if await request.is_disconnected():
break
return VercelStreamResponse(content=content_generator())
except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
) from e
return VercelStreamResponse(content=content_generator())
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
data: ChatData,
data: _ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> Result:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
) -> _Result:
last_message_content, messages = await parse_chat_data(data)
response = await chat_engine.achat(last_message_content, messages)
return Result(
result=Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=SourceNodes.from_source_nodes(response.source_nodes),
return _Result(
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
)
@r.get("/config")
async def chat_config() -> ChatConfig:
starter_questions = None
conversation_starters = os.getenv("CONVERSATION_STARTERS")
if conversation_starters and conversation_starters.strip():
starter_questions = conversation_starters.strip().split("\n")
return ChatConfig(starterQuestions=starter_questions)
@@ -1,149 +0,0 @@
import json
import asyncio
import logging
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class CallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
else:
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"Error in converting event to response: {e}")
return None
class EventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> str:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
try:
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
except asyncio.TimeoutError:
pass
@@ -0,0 +1,86 @@
import asyncio
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
from pydantic import BaseModel
class CallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_title(self) -> str | None:
# Return as None for the unhandled event types
# to avoid showing them in the UI
match self.event_type:
case "retrieve":
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
return f"Retrieved {len(nodes)} sources to use as context for the query"
else:
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
else:
return None
case _:
return None
class EventCallbackHandler(BaseCallbackHandler):
_aqueue: asyncio.Queue
is_done: bool = False
def __init__(
self,
):
"""Initialize the base callback handler."""
ignored_events = [
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.EMBEDDING,
CBEventType.LLM,
CBEventType.TEMPLATING,
]
super().__init__(ignored_events, ignored_events)
self._aqueue = asyncio.Queue()
def on_event_start(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> str:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
self,
event_type: CBEventType,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
"""No-op."""
def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""No-op."""
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
while not self._aqueue.empty() or not self.is_done:
try:
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
except asyncio.TimeoutError:
pass
@@ -1,170 +0,0 @@
import os
import logging
from pydantic import BaseModel, Field, validator
from pydantic.alias_generators import to_camel
from typing import List, Any, Optional, Dict
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
logger = logging.getLogger("uvicorn")
class CsvFile(BaseModel):
content: str
filename: str
filesize: int
id: str
class AnnotationData(BaseModel):
csv_files: List[CsvFile] | None = Field(
default=None,
description="List of CSV files",
)
class Config:
json_schema_extra = {
"example": {
"csvFiles": [
{
"content": "Name, Age\nAlice, 25\nBob, 30",
"filename": "example.csv",
"filesize": 123,
"id": "123",
"type": "text/csv",
}
]
}
}
alias_generator = to_camel
class Annotation(BaseModel):
type: str
data: AnnotationData
def to_content(self) -> str:
if self.type == "csv":
csv_files = self.data.csv_files
if csv_files is not None and len(csv_files) > 0:
return "Use data from following CSV raw contents\n" + "\n".join(
[f"```csv\n{csv_file.content}\n```" for csv_file in csv_files]
)
raise ValueError(f"Unsupported annotation type: {self.type}")
class Message(BaseModel):
role: MessageRole
content: str
annotations: List[Annotation] | None = None
class ChatData(BaseModel):
messages: List[Message]
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
@validator("messages")
def messages_must_not_be_empty(cls, v):
if len(v) == 0:
raise ValueError("Messages must not be empty")
return v
def get_last_message_content(self) -> str:
"""
Get the content of the last message along with the data content if available. Fallback to use data content from previous messages
"""
if len(self.messages) == 0:
raise ValueError("There is not any message in the chat")
last_message = self.messages[-1]
message_content = last_message.content
for message in reversed(self.messages):
if message.role == MessageRole.USER and message.annotations is not None:
annotation_contents = (
annotation.to_content() for annotation in message.annotations
)
annotation_text = "\n".join(annotation_contents)
message_content = f"{message_content}\n{annotation_text}"
break
return message_content
def get_history_messages(self) -> List[Message]:
"""
Get the history messages
"""
return [
ChatMessage(role=message.role, content=message.content)
for message in self.messages[:-1]
]
def is_last_message_from_user(self) -> bool:
return self.messages[-1].role == MessageRole.USER
class SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
url: Optional[str]
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
metadata = source_node.node.metadata
url = metadata.get("URL")
if not url:
file_name = metadata.get("file_name")
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not url_prefix:
logger.warning(
"Warning: FILESERVER_URL_PREFIX not set in environment variables"
)
if file_name and url_prefix:
url = f"{url_prefix}/data/{file_name}"
return cls(
id=source_node.node.node_id,
metadata=metadata,
score=source_node.score,
text=source_node.node.text, # type: ignore
url=url,
)
@classmethod
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
return [cls.from_source_node(node) for node in source_nodes]
class Result(BaseModel):
result: Message
nodes: List[SourceNodes]
class ChatConfig(BaseModel):
starter_questions: Optional[List[str]] = Field(
default=None,
description="List of starter questions",
)
class Config:
json_schema_extra = {
"example": {
"starterQuestions": [
"What standards for letters exist?",
"What are the requirements for a letter to be considered a letter?",
]
}
}
alias_generator = to_camel
@@ -0,0 +1 @@
STORAGE_DIR = "storage" # directory to save the stores to (document store and if used, the `SimpleVectorStore`)
@@ -7,8 +7,11 @@ import logging
from llama_index.core.settings import Settings
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import SimpleVectorStore
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage import StorageContext
from llama_index.core import VectorStoreIndex
from app.constants import STORAGE_DIR
from app.settings import init_settings
from app.engine.loaders import get_documents
from app.engine.vectordb import get_vector_store
@@ -17,21 +20,18 @@ from app.engine.vectordb import get_vector_store
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
def get_doc_store():
# If the storage directory is there, load the document store from it.
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
if os.path.exists(STORAGE_DIR):
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
if not os.path.exists(STORAGE_DIR):
docstore = SimpleDocumentStore()
return docstore
else:
return SimpleDocumentStore()
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
def run_pipeline(docstore, vector_store, documents):
pipeline = IngestionPipeline(
def run_ingestion_pipeline(docstore, vector_store, documents):
# Create ingestion pipeline
ingestion_pipeline = IngestionPipeline(
transformations=[
SentenceSplitter(
chunk_size=Settings.chunk_size,
@@ -41,20 +41,32 @@ def run_pipeline(docstore, vector_store, documents):
],
docstore=docstore,
docstore_strategy="upserts_and_delete",
vector_store=vector_store,
)
# llama_index having an typing issue when passing vector_store to IngestionPipeline
# so we need to set it manually after initialization
ingestion_pipeline.vector_store = vector_store
# Run the ingestion pipeline and store the results
nodes = pipeline.run(show_progress=True, documents=documents)
nodes = ingestion_pipeline.run(show_progress=True, documents=documents)
return nodes
def persist_storage(docstore, vector_store):
def persist_storage(docstore, vector_store, nodes):
storage_context = StorageContext.from_defaults(
docstore=docstore,
vector_store=vector_store,
)
# SimpleVectorStore does not include index by default
# so we need to create the index manually
# can be removed if using other vector store
if isinstance(vector_store, SimpleVectorStore):
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
store_nodes_override=True, # Need enable this to store the nodes and index's id
)
storage_context.persist(STORAGE_DIR)
@@ -68,10 +80,14 @@ def generate_datasource():
vector_store = get_vector_store()
# Run the ingestion pipeline
_ = run_pipeline(docstore, vector_store, documents)
nodes = run_ingestion_pipeline(
docstore=docstore,
vector_store=vector_store,
documents=documents,
)
# Build the index and persist storage
persist_storage(docstore, vector_store)
persist_storage(docstore, vector_store, nodes)
logger.info("Finished generating the index")
@@ -1,17 +1,27 @@
import logging
from llama_index.core.indices import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.core.storage import StorageContext
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.vector_stores.simple import SimpleVectorStore
from app.constants import STORAGE_DIR
from app.engine.vectordb import get_vector_store
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting vector store...")
logger.info("Loading the index...")
store = get_vector_store()
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
# If the store is a SimpleVectorStore, we need to load the index from the storage
if isinstance(store, SimpleVectorStore):
index = load_index_from_storage(
StorageContext.from_defaults(
vector_store=store,
persist_dir=STORAGE_DIR,
)
)
else:
index = VectorStoreIndex.from_vector_store(store)
logger.info("Loaded index successfully.")
return index
@@ -5,19 +5,12 @@ from llama_index.core.settings import Settings
def init_settings():
model_provider = os.getenv("MODEL_PROVIDER")
match model_provider:
case "openai":
init_openai()
case "ollama":
init_ollama()
case "anthropic":
init_anthropic()
case "gemini":
init_gemini()
case "azure-openai":
init_azure_openai()
case _:
raise ValueError(f"Invalid model provider: {model_provider}")
if model_provider == "openai":
init_openai()
elif model_provider == "ollama":
init_ollama()
else:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
@@ -26,12 +19,8 @@ def init_ollama():
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
Settings.embed_model = OllamaEmbedding(
base_url=base_url,
model_name=os.getenv("EMBEDDING_MODEL"),
)
Settings.llm = Ollama(base_url=base_url, model=os.getenv("MODEL"))
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
Settings.llm = Ollama(model=os.getenv("MODEL"))
def init_openai():
@@ -53,75 +42,3 @@ def init_openai():
"dimensions": int(dimensions) if dimensions is not None else None,
}
Settings.embed_model = OpenAIEmbedding(**config)
def init_azure_openai():
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.core.constants import DEFAULT_TEMPERATURE
llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
max_tokens = os.getenv("LLM_MAX_TOKENS")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
llm_config = {
"api_key": api_key,
"deployment_name": llm_deployment,
"model": os.getenv("MODEL"),
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
"max_tokens": int(max_tokens) if max_tokens is not None else None,
}
Settings.llm = AzureOpenAI(**llm_config)
dimensions = os.getenv("EMBEDDING_DIM")
embedding_config = {
"api_key": api_key,
"deployment_name": embedding_deployment,
"model": os.getenv("EMBEDDING_MODEL"),
"dimensions": int(dimensions) if dimensions is not None else None,
}
Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
def init_anthropic():
from llama_index.llms.anthropic import Anthropic
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
model_map: Dict[str, str] = {
"claude-3-opus": "claude-3-opus-20240229",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-2.1": "claude-2.1",
"claude-instant-1.2": "claude-instant-1.2",
}
embed_model_map: Dict[str, str] = {
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
}
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
Settings.embed_model = HuggingFaceEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
def init_gemini():
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
model_map: Dict[str, str] = {
"gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest",
"gemini-pro": "models/gemini-pro",
"gemini-pro-vision": "models/gemini-pro-vision",
}
embed_model_map: Dict[str, str] = {
"embedding-001": "models/embedding-001",
"text-embedding-004": "models/text-embedding-004",
}
Settings.llm = Gemini(model=model_map[os.getenv("MODEL")])
Settings.embed_model = GeminiEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
+1 -1
View File
@@ -1,3 +1,3 @@
__pycache__
storage
.env
.env
+1 -11
View File
@@ -11,7 +11,6 @@ from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
app = FastAPI()
@@ -21,6 +20,7 @@ init_observability()
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
if environment == "dev":
logger = logging.getLogger("uvicorn")
logger.warning("Running in development mode - allowing CORS for all origins")
@@ -38,16 +38,6 @@ if environment == "dev":
return RedirectResponse(url="/docs")
def mount_static_files(directory, path):
if os.path.exists(directory):
app.mount(path, StaticFiles(directory=directory), name=f"{directory}-static")
# Mount the data files to serve the file viewer
mount_static_files("data", "/api/files/data")
# Mount the output files from tools
mount_static_files("tool-output", "/api/files/tool-output")
app.include_router(chat_router, prefix="/api/chat")
@@ -14,9 +14,8 @@ fastapi = "^0.109.1"
uvicorn = { extras = ["standard"], version = "^0.23.2" }
python-dotenv = "^1.0.0"
aiostream = "^0.5.2"
llama-index = "0.10.41"
llama-index-core = "0.10.41"
cachetools = "^5.3.3"
llama-index = "0.10.28"
llama-index-core = "0.10.28"
[build-system]
requires = ["poetry-core"]
@@ -1,11 +0,0 @@
import { NextResponse } from "next/server";
/**
* This API is to get config from the backend envs and expose them to the frontend
*/
export async function GET() {
const config = {
starterQuestions: process.env.CONVERSATION_STARTERS?.trim().split("\n"),
};
return NextResponse.json(config, { status: 200 });
}
@@ -1,17 +1,10 @@
import {
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Ollama,
OllamaEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -19,21 +12,10 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -45,9 +27,7 @@ export const initSettings = async () => {
function initOpenAI() {
Settings.llm = new OpenAI({
model: process.env.MODEL ?? "gpt-3.5-turbo",
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
maxTokens: 512,
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
@@ -58,37 +38,15 @@ function initOpenAI() {
}
function initOllama() {
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
@@ -1,5 +1,4 @@
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
@@ -7,118 +6,44 @@ import {
type AIStreamCallbacksAndOptions,
} from "ai";
import {
MessageContent,
MessageContentDetail,
Metadata,
NodeWithScore,
Response,
ToolCallLLMMessageOptions,
StreamingAgentChatResponse,
} from "llamaindex";
import { appendImageData, appendSourceData } from "./stream-helper";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { CsvFile, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
export const convertMessageContent = (
content: string,
annotations?: JSONValue[],
): MessageContent => {
if (!annotations) return content;
return [
{
type: "text",
text: content,
},
...convertAnnotations(annotations),
];
};
const convertAnnotations = (
annotations: JSONValue[],
): MessageContentDetail[] => {
const content: MessageContentDetail[] = [];
annotations.forEach((annotation: JSONValue) => {
// first skip invalid annotation
if (
!(
annotation &&
typeof annotation === "object" &&
"type" in annotation &&
"data" in annotation &&
annotation.data &&
typeof annotation.data === "object"
)
) {
console.log(
"Client sent invalid annotation. Missing data and type",
annotation,
);
return;
}
const { type, data } = annotation;
// convert image
if (type === "image" && "url" in data && typeof data.url === "string") {
content.push({
type: "image_url",
image_url: {
url: data.url,
},
});
}
// convert CSV files to text
if (type === "csv" && "csvFiles" in data && Array.isArray(data.csvFiles)) {
const rawContents = data.csvFiles.map((csv) => {
return "```csv\n" + (csv as CsvFile).content + "\n```";
});
const csvContent =
"Use data from following CSV raw contents:\n" +
rawContents.join("\n\n");
content.push({
type: "text",
text: csvContent,
});
}
});
return content;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<LlamaIndexResponse>,
res: AsyncIterable<Response>,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
appendSourceData(data, sourceNodes);
controller.close();
data.close();
return;
}
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
const text = trimStartOfStream(delta ?? "");
const text = trimStartOfStream(value.response ?? "");
if (text) {
controller.enqueue(text);
}
@@ -127,13 +52,21 @@ function createParser(
}
export function LlamaIndexStream(
response: AsyncIterable<LlamaIndexResponse>,
response: StreamingAgentChatResponse | AsyncIterable<Response>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): ReadableStream<Uint8Array> {
return createParser(response, data)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
}
@@ -1,11 +1,11 @@
import { initObservability } from "@/app/observability";
import { Message, StreamData, StreamingTextResponse } from "ai";
import { ChatMessage, Settings } from "llamaindex";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "./engine/chat";
import { initSettings } from "./engine/settings";
import { LlamaIndexStream, convertMessageContent } from "./llamaindex-stream";
import { createCallbackManager, createStreamTimeout } from "./stream-helper";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
initObservability();
initSettings();
@@ -13,14 +13,29 @@ initSettings();
export const runtime = "nodejs";
export const dynamic = "force-dynamic";
export async function POST(request: NextRequest) {
// Init Vercel AI StreamData and timeout
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
const convertMessageContent = (
textMessage: string,
imageUrl: string | undefined,
): MessageContent => {
if (!imageUrl) return textMessage;
return [
{
type: "text",
text: textMessage,
},
{
type: "image_url",
image_url: {
url: imageUrl,
},
},
];
};
export async function POST(request: NextRequest) {
try {
const body = await request.json();
const { messages }: { messages: Message[] } = body;
const { messages, data }: { messages: Message[]; data: any } = body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return NextResponse.json(
@@ -34,39 +49,41 @@ export async function POST(request: NextRequest) {
const chatEngine = await createChatEngine();
let annotations = userMessage.annotations;
if (!annotations) {
// the user didn't send any new annotations with the last message
// so use the annotations from the last user message that has annotations
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
annotations = messages
.slice()
.reverse()
.find(
(message) => message.role === "user" && message.annotations,
)?.annotations;
}
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
const userMessageContent = convertMessageContent(
userMessage.content,
annotations,
data?.imageUrl,
);
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
// Transform LlamaIndex stream to Vercel/AI format
const stream = LlamaIndexStream(response, vercelStreamData);
const { stream } = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
// Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
return new StreamingTextResponse(stream, {}, vercelStreamData);
@@ -80,7 +97,5 @@ export async function POST(request: NextRequest) {
status: 500,
},
);
} finally {
clearTimeout(streamTimeout);
}
}
@@ -1,26 +1,14 @@
import { StreamData } from "ai";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
import { Metadata, NodeWithScore } from "llamaindex";
function getNodeUrl(metadata: Metadata) {
const url = metadata["URL"];
if (url) return url;
const fileName = metadata["file_name"];
if (!process.env.FILESERVER_URL_PREFIX) {
console.warn(
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
);
return undefined;
}
if (fileName) {
return `${process.env.FILESERVER_URL_PREFIX}/data/${fileName}`;
}
return undefined;
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
data.appendMessageAnnotation({
type: "image",
data: {
url: imageUrl,
},
});
}
export function appendSourceData(
@@ -35,7 +23,6 @@ export function appendSourceData(
...node.node.toMutableJSON(),
id: node.node.id_,
score: node.score ?? null,
url: getNodeUrl(node.node.metadata),
})),
},
});
@@ -50,71 +37,3 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
export function appendToolData(
data: StreamData,
toolCall: ToolCall,
toolOutput: ToolOutput,
) {
data.appendMessageAnnotation({
type: "tools",
data: {
toolCall: {
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
},
toolOutput: {
output: toolOutput.output,
isError: toolOutput.isError,
},
},
});
}
export function createStreamTimeout(stream: StreamData) {
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
const t = setTimeout(() => {
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
stream.close();
}, timeout);
return t;
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
callbackManager.on("retrieve", (data) => {
const { nodes, query } = data.detail;
appendEventData(stream, `Retrieving context for query: '${query}'`);
appendEventData(
stream,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
callbackManager.on("llm-tool-call", (event) => {
const { name, input } = event.detail.payload.toolCall;
const inputString = Object.entries(input)
.map(([key, value]) => `${key}: ${value}`)
.join(", ");
appendEventData(
stream,
`Using tool: '${name}' with inputs: '${inputString}'`,
);
});
callbackManager.on("llm-tool-result", (event) => {
const { toolCall, toolResult } = event.detail.payload;
appendToolData(stream, toolCall, toolResult);
});
return callbackManager;
}
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
@@ -1,45 +0,0 @@
import { readFile } from "fs/promises";
import { NextRequest, NextResponse } from "next/server";
import path from "path";
/**
* This API is to get file data from allowed folders
* It receives path slug and response file data like serve static file
*/
export async function GET(
_request: NextRequest,
{ params }: { params: { slug: string[] } },
) {
const slug = params.slug;
if (!slug) {
return NextResponse.json({ detail: "Missing file slug" }, { status: 400 });
}
if (slug.includes("..") || path.isAbsolute(path.join(...slug))) {
return NextResponse.json({ detail: "Invalid file path" }, { status: 400 });
}
const [folder, ...pathTofile] = params.slug; // data, file.pdf
const allowedFolders = ["data", "tool-output"];
if (!allowedFolders.includes(folder)) {
return NextResponse.json({ detail: "No permission" }, { status: 400 });
}
try {
const filePath = path.join(process.cwd(), folder, path.join(...pathTofile));
const blob = await readFile(filePath);
return new NextResponse(blob, {
status: 200,
statusText: "OK",
headers: {
"Content-Length": blob.byteLength.toString(),
},
});
} catch (error) {
console.error(error);
return NextResponse.json({ detail: "File not found" }, { status: 404 });
}
}
@@ -2,10 +2,8 @@
import { useChat } from "ai/react";
import { ChatInput, ChatMessages } from "./ui/chat";
import { useClientConfig } from "./ui/chat/use-config";
export default function ChatSection() {
const { chatAPI } = useClientConfig();
const {
messages,
input,
@@ -14,15 +12,12 @@ export default function ChatSection() {
handleInputChange,
reload,
stop,
append,
setInput,
} = useChat({
api: chatAPI,
api: process.env.NEXT_PUBLIC_CHAT_API,
headers: {
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
},
onError: (error: unknown) => {
if (!(error instanceof Error)) throw error;
onError: (error) => {
const message = JSON.parse(error.message);
alert(message.detail);
},
@@ -35,16 +30,13 @@ export default function ChatSection() {
isLoading={isLoading}
reload={reload}
stop={stop}
append={append}
/>
<ChatInput
input={input}
handleSubmit={handleSubmit}
handleInputChange={handleInputChange}
isLoading={isLoading}
messages={messages}
append={append}
setInput={setInput}
multiModal={true}
/>
</div>
);
@@ -7,7 +7,7 @@ export default function Header() {
Get started by editing&nbsp;
<code className="font-mono font-bold">app/page.tsx</code>
</p>
<div className="fixed bottom-0 left-0 mb-4 flex h-auto w-full items-end justify-center bg-gradient-to-t from-white via-white dark:from-black dark:via-black lg:static lg:w-auto lg:bg-none lg:mb-0">
<div className="fixed bottom-0 left-0 flex h-48 w-full items-end justify-center bg-gradient-to-t from-white via-white dark:from-black dark:via-black lg:static lg:h-auto lg:w-auto lg:bg-none">
<a
href="https://www.llamaindex.ai/"
className="flex items-center justify-center font-nunito text-lg font-bold gap-2"
@@ -38,9 +38,7 @@ export function ChatEvents({
<CollapsibleContent asChild>
<div className="mt-4 text-sm space-y-2">
{data.map((eventItem, index) => (
<div className="whitespace-break-spaces" key={index}>
{eventItem.title}
</div>
<div key={index}>{eventItem.title}</div>
))}
</div>
</CollapsibleContent>
@@ -1,14 +1,9 @@
import { JSONValue } from "ai";
import { useState } from "react";
import { v4 as uuidv4 } from "uuid";
import { MessageAnnotation, MessageAnnotationType } from ".";
import { Button } from "../button";
import FileUploader from "../file-uploader";
import { Input } from "../input";
import UploadCsvPreview from "../upload-csv-preview";
import UploadImagePreview from "../upload-image-preview";
import { ChatHandler } from "./chat.interface";
import { useCsv } from "./use-csv";
export default function ChatInput(
props: Pick<
@@ -19,61 +14,18 @@ export default function ChatInput(
| "onFileError"
| "handleSubmit"
| "handleInputChange"
| "messages"
| "setInput"
| "append"
>,
> & {
multiModal?: boolean;
},
) {
const [imageUrl, setImageUrl] = useState<string | null>(null);
const { files: csvFiles, upload, remove, reset } = useCsv();
const getAnnotations = () => {
if (!imageUrl && csvFiles.length === 0) return undefined;
const annotations: MessageAnnotation[] = [];
if (imageUrl) {
annotations.push({
type: MessageAnnotationType.IMAGE,
data: { url: imageUrl },
});
}
if (csvFiles.length > 0) {
annotations.push({
type: MessageAnnotationType.CSV,
data: {
csvFiles: csvFiles.map((file) => ({
id: file.id,
content: file.content,
filename: file.filename,
filesize: file.filesize,
})),
},
});
}
return annotations as JSONValue[];
};
// default submit function does not handle including annotations in the message
// so we need to use append function to submit new message with annotations
const handleSubmitWithAnnotations = (
e: React.FormEvent<HTMLFormElement>,
annotations: JSONValue[] | undefined,
) => {
e.preventDefault();
props.append!({
content: props.input,
role: "user",
createdAt: new Date(),
annotations,
});
props.setInput!("");
};
const onSubmit = (e: React.FormEvent<HTMLFormElement>) => {
const annotations = getAnnotations();
if (annotations) {
handleSubmitWithAnnotations(e, annotations);
imageUrl && setImageUrl(null);
csvFiles.length && reset();
if (imageUrl) {
props.handleSubmit(e, {
data: { imageUrl: imageUrl },
});
setImageUrl(null);
return;
}
props.handleSubmit(e);
@@ -91,36 +43,11 @@ export default function ChatInput(
setImageUrl(base64);
};
const handleUploadCsvFile = async (file: File) => {
const content = await new Promise<string>((resolve, reject) => {
const reader = new FileReader();
reader.readAsText(file);
reader.onload = () => resolve(reader.result as string);
reader.onerror = (error) => reject(error);
});
const isSuccess = upload({
id: uuidv4(),
content,
filename: file.name,
filesize: file.size,
});
if (!isSuccess) {
alert("File already exists in the list.");
}
};
const handleUploadFile = async (file: File) => {
try {
if (file.type.startsWith("image/")) {
if (props.multiModal && file.type.startsWith("image/")) {
return await handleUploadImageFile(file);
}
if (file.type === "text/csv") {
if (csvFiles.length > 0) {
alert("You can only upload one csv file at a time.");
return;
}
return await handleUploadCsvFile(file);
}
props.onFileUpload?.(file);
} catch (error: any) {
props.onFileError?.(error.message);
@@ -135,19 +62,6 @@ export default function ChatInput(
{imageUrl && (
<UploadImagePreview url={imageUrl} onRemove={onRemovePreviewImage} />
)}
{csvFiles.length > 0 && (
<div className="flex gap-4 w-full overflow-auto py-2">
{csvFiles.map((csv) => {
return (
<UploadCsvPreview
key={csv.id}
csv={csv}
onRemove={() => remove(csv)}
/>
);
})}
</div>
)}
<div className="flex w-full items-start justify-between gap-4 ">
<Input
autoFocus
@@ -161,7 +75,7 @@ export default function ChatInput(
onFileUpload={handleUploadFile}
onFileError={props.onFileError}
/>
<Button type="submit" disabled={props.isLoading || !props.input.trim()}>
<Button type="submit" disabled={props.isLoading}>
Send message
</Button>
</div>
@@ -7,17 +7,13 @@ import ChatAvatar from "./chat-avatar";
import { ChatEvents } from "./chat-events";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import ChatTools from "./chat-tools";
import CsvContent from "./csv-content";
import {
CsvData,
AnnotationData,
EventData,
ImageData,
MessageAnnotation,
MessageAnnotationType,
SourceData,
ToolData,
getAnnotationData,
} from "./index";
import Markdown from "./markdown";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
@@ -27,6 +23,13 @@ type ContentDisplayConfig = {
component: JSX.Element | null;
};
function getAnnotationData<T extends AnnotationData>(
annotations: MessageAnnotation[],
type: MessageAnnotationType,
): T[] {
return annotations.filter((a) => a.type === type).map((a) => a.data as T);
}
function ChatMessageContent({
message,
isLoading,
@@ -41,10 +44,6 @@ function ChatMessageContent({
annotations,
MessageAnnotationType.IMAGE,
);
const csvData = getAnnotationData<CsvData>(
annotations,
MessageAnnotationType.CSV,
);
const eventData = getAnnotationData<EventData>(
annotations,
MessageAnnotationType.EVENTS,
@@ -53,37 +52,25 @@ function ChatMessageContent({
annotations,
MessageAnnotationType.SOURCES,
);
const toolData = getAnnotationData<ToolData>(
annotations,
MessageAnnotationType.TOOLS,
);
const contents: ContentDisplayConfig[] = [
{
order: 1,
order: -2,
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
},
{
order: -3,
order: -1,
component:
eventData.length > 0 ? (
<ChatEvents isLoading={isLoading} data={eventData} />
) : null,
},
{
order: 2,
component: csvData[0] ? <CsvContent data={csvData[0]} /> : null,
},
{
order: -1,
component: toolData[0] ? <ChatTools data={toolData[0]} /> : null,
},
{
order: 0,
component: <Markdown content={message.content} />,
},
{
order: 3,
order: 1,
component: sourceData[0] ? <ChatSources data={sourceData[0]} /> : null,
},
];
@@ -1,19 +1,13 @@
import { Loader2 } from "lucide-react";
import { useEffect, useRef } from "react";
import { Button } from "../button";
import ChatActions from "./chat-actions";
import ChatMessage from "./chat-message";
import { ChatHandler } from "./chat.interface";
import { useClientConfig } from "./use-config";
export default function ChatMessages(
props: Pick<
ChatHandler,
"messages" | "isLoading" | "reload" | "stop" | "append"
>,
props: Pick<ChatHandler, "messages" | "isLoading" | "reload" | "stop">,
) {
const { starterQuestions } = useClientConfig();
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
const messageLength = props.messages.length;
const lastMessage = props.messages[messageLength - 1];
@@ -41,21 +35,14 @@ export default function ChatMessages(
}, [messageLength, lastMessage]);
return (
<div className="w-full rounded-xl bg-white p-4 shadow-xl pb-0 relative">
<div className="w-full rounded-xl bg-white p-4 shadow-xl pb-0">
<div
className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4"
ref={scrollableChatContainerRef}
>
{props.messages.map((m, i) => {
const isLoadingMessage = i === messageLength - 1 && props.isLoading;
return (
<ChatMessage
key={m.id}
chatMessage={m}
isLoading={isLoadingMessage}
/>
);
})}
{props.messages.map((m) => (
<ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} />
))}
{isPending && (
<div className="flex justify-center items-center pt-10">
<Loader2 className="h-4 w-4 animate-spin" />
@@ -70,23 +57,6 @@ export default function ChatMessages(
showStop={showStop}
/>
</div>
{!messageLength && starterQuestions?.length && props.append && (
<div className="absolute bottom-6 left-0 w-full">
<div className="grid grid-cols-2 gap-2 mx-20">
{starterQuestions.map((question, i) => (
<Button
variant="outline"
key={i}
onClick={() =>
props.append!({ role: "user", content: question })
}
>
{question}
</Button>
))}
</div>
</div>
)}
</div>
);
}
@@ -1,46 +1,20 @@
import { Check, Copy } from "lucide-react";
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
import { useMemo } from "react";
import { Button } from "../button";
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
import { SourceData } from "./index";
import { SourceData, SourceNode } from "./index";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
import PdfDialog from "./widgets/PdfDialog";
const SCORE_THRESHOLD = 0.3;
function SourceNumberButton({ index }: { index: number }) {
return (
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
{index + 1}
</div>
);
}
type NodeInfo = {
id: string;
url?: string;
};
const SCORE_THRESHOLD = 0.5;
export function ChatSources({ data }: { data: SourceData }) {
const sources: NodeInfo[] = useMemo(() => {
// aggregate nodes by url or file_path (get the highest one by score)
const nodesByPath: { [path: string]: NodeInfo } = {};
data.nodes
.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1))
.forEach((node) => {
const nodeInfo = {
id: node.id,
url: node.url,
};
const key = nodeInfo.url ?? nodeInfo.id; // use id as key for UNKNOWN type
if (!nodesByPath[key]) {
nodesByPath[key] = nodeInfo;
}
});
return Object.values(nodesByPath);
const sources = useMemo(() => {
return (
data.nodes
?.filter((node) => Object.keys(node.metadata).length > 0)
?.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) || []
);
}, [data.nodes]);
if (sources.length === 0) return null;
@@ -49,51 +23,55 @@ export function ChatSources({ data }: { data: SourceData }) {
<div className="space-x-2 text-sm">
<span className="font-semibold">Sources:</span>
<div className="inline-flex gap-1 items-center">
{sources.map((nodeInfo: NodeInfo, index: number) => {
if (nodeInfo.url?.endsWith(".pdf")) {
return (
<PdfDialog
key={nodeInfo.id}
documentId={nodeInfo.id}
url={nodeInfo.url!}
trigger={<SourceNumberButton index={index} />}
/>
);
}
return (
<div key={nodeInfo.id}>
<HoverCard>
<HoverCardTrigger>
<SourceNumberButton index={index} />
</HoverCardTrigger>
<HoverCardContent className="w-[320px]">
<NodeInfo nodeInfo={nodeInfo} />
</HoverCardContent>
</HoverCard>
</div>
);
})}
{sources.map((node: SourceNode, index: number) => (
<div key={node.id}>
<HoverCard>
<HoverCardTrigger>
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
{index + 1}
</div>
</HoverCardTrigger>
<HoverCardContent>
<NodeInfo node={node} />
</HoverCardContent>
</HoverCard>
</div>
))}
</div>
</div>
);
}
function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
function NodeInfo({ node }: { node: SourceNode }) {
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
if (nodeInfo.url) {
// this is a node generated by the web loader or file loader,
// add a link to view its URL and a button to copy the URL to the clipboard
if (typeof node.metadata["URL"] === "string") {
// this is a node generated by the web loader, it contains an external URL
// add a link to view this URL
return (
<div className="flex items-center my-2">
<a className="hover:text-blue-900" href={nodeInfo.url} target="_blank">
<span>{nodeInfo.url}</span>
</a>
<a
className="space-x-2 flex items-center my-2 hover:text-blue-900"
href={node.metadata["URL"]}
target="_blank"
>
<span>{node.metadata["URL"]}</span>
<ArrowUpRightSquare className="w-4 h-4" />
</a>
);
}
if (typeof node.metadata["file_path"] === "string") {
// this is a node generated by the file loader, it contains file path
// add a button to copy the path to the clipboard
const filePath = node.metadata["file_path"];
return (
<div className="flex items-center px-2 py-1 justify-between my-2">
<span>{filePath}</span>
<Button
onClick={() => copyToClipboard(nodeInfo.url!)}
onClick={() => copyToClipboard(filePath)}
size="icon"
variant="ghost"
className="h-12 w-12 shrink-0"
className="h-12 w-12"
>
{isCopied ? (
<Check className="h-4 w-4" />
@@ -106,6 +84,7 @@ function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
}
// node generated by unknown loader, implement renderer by analyzing logged out metadata
console.log("Node metadata", node.metadata);
return (
<p>
Sorry, unknown node type. Please add a new renderer in the NodeInfo
@@ -1,26 +0,0 @@
import { ToolData } from "./index";
import { WeatherCard, WeatherData } from "./widgets/WeatherCard";
// TODO: If needed, add displaying more tool outputs here
export default function ChatTools({ data }: { data: ToolData }) {
if (!data) return null;
const { toolCall, toolOutput } = data;
if (toolOutput.isError) {
return (
<div className="border-l-2 border-red-400 pl-2">
There was an error when calling the tool {toolCall.name} with input:{" "}
<br />
{JSON.stringify(toolCall.input)}
</div>
);
}
switch (toolCall.name) {
case "get_weather_information":
const weatherData = toolOutput.output as unknown as WeatherData;
return <WeatherCard data={weatherData} />;
default:
return null;
}
}
@@ -15,11 +15,4 @@ export interface ChatHandler {
stop?: () => void;
onFileUpload?: (file: File) => Promise<void>;
onFileError?: (errMsg: string) => void;
setInput?: (input: string) => void;
append?: (
message: Message | Omit<Message, "id">,
ops?: {
data: any;
},
) => Promise<string | null | undefined>;
}
@@ -1,13 +0,0 @@
import { CsvData } from ".";
import UploadCsvPreview from "../upload-csv-preview";
export default function CsvContent({ data }: { data: CsvData }) {
if (!data.csvFiles.length) return null;
return (
<div className="flex gap-2 items-center">
{data.csvFiles.map((csv, index) => (
<UploadCsvPreview key={index} csv={csv} />
))}
</div>
);
}
@@ -1,4 +1,3 @@
import { JSONValue } from "ai";
import ChatInput from "./chat-input";
import ChatMessages from "./chat-messages";
@@ -6,34 +5,20 @@ export { type ChatHandler } from "./chat.interface";
export { ChatInput, ChatMessages };
export enum MessageAnnotationType {
CSV = "csv",
IMAGE = "image",
SOURCES = "sources",
EVENTS = "events",
TOOLS = "tools",
}
export type ImageData = {
url: string;
};
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
export type CsvData = {
csvFiles: CsvFile[];
};
export type SourceNode = {
id: string;
metadata: Record<string, unknown>;
score?: number;
text: string;
url?: string;
};
export type SourceData = {
@@ -45,35 +30,9 @@ export type EventData = {
isCollapsed: boolean;
};
export type ToolData = {
toolCall: {
id: string;
name: string;
input: {
[key: string]: JSONValue;
};
};
toolOutput: {
output: JSONValue;
isError: boolean;
};
};
export type AnnotationData =
| ImageData
| CsvData
| SourceData
| EventData
| ToolData;
export type AnnotationData = ImageData | SourceData | EventData;
export type MessageAnnotation = {
type: MessageAnnotationType;
data: AnnotationData;
};
export function getAnnotationData<T extends AnnotationData>(
annotations: MessageAnnotation[],
type: MessageAnnotationType,
): T[] {
return annotations.filter((a) => a.type === type).map((a) => a.data as T);
}

Some files were not shown because too many files have changed in this diff Show More