Compare commits

..

31 Commits

Author SHA1 Message Date
github-actions[bot] ed114856d9 Release 0.1.7 (#93)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-22 18:30:49 +07:00
Marcus Schiesser 69c2e16c82 fix: streaming for express 2024-05-22 13:04:35 +02:00
Marcus Schiesser f5da6623cf fix: update llamaindex, use 127.0.0.1 for ollama as default 2024-05-22 12:42:34 +02:00
Marcus Schiesser 0950cb90f2 fix: global-agent types 2024-05-22 11:50:34 +02:00
Mohammad Amir bb53425b4b Proxy support added via global agent (#76) 2024-05-22 16:35:03 +07:00
Huu Le (Lee) bbd5b8ddd6 fix: Reuse PG vector store to avoid recreating sqlalchemy engine (#95) 2024-05-22 16:12:44 +07:00
Thuc Pham 260d37a3f1 feat(ts): add system prompt for chat engine (#92) 2024-05-20 16:12:19 +07:00
Huu Le (Lee) 7873bfb030 chore: Add Ollama API base URL environment variable (#91) 2024-05-17 17:01:06 +07:00
github-actions[bot] 0c7c41ee3b Release 0.1.6 (#90)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-16 19:08:40 +07:00
Thuc Pham 56537a1473 feat: host local files and add viewer for PDFs (#85) 2024-05-16 18:06:26 +07:00
github-actions[bot] d8dfc29edd Release 0.1.5 (#89)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-16 16:12:40 +07:00
Thuc Pham 84db798353 feat: support display latex in chat markdown (#88) 2024-05-16 15:25:53 +07:00
github-actions[bot] 67a062af14 Release 0.1.4 (#86)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-14 20:08:48 +07:00
Marcus Schiesser 0bc8e75c64 docs: add changeset for ingestion pipeline 2024-05-14 15:07:40 +02:00
Huu Le (Lee) 6bd5e7b77a using ingestion pipeline for chromadb (#87) 2024-05-14 20:02:47 +07:00
Huu Le (Lee) 38bc1d1350 Use ingestion pipeline for dedicated vector stores (#74) 2024-05-14 18:58:07 +07:00
Huu Le (Lee) cb1001de95 feat: add support for ChromaDB vector store (#82) 2024-05-14 15:42:01 +07:00
github-actions[bot] 78776ac51e Release 0.1.3 (#84)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-13 20:27:42 +07:00
Marcus Schiesser 416073db1d fix: use CJS for express (otherwise qdrant doesn't work) and upgrade to 0.3.9 2024-05-13 15:18:45 +02:00
Huu Le (Lee) 84929de8b2 chore: Update vector store imports in vectordbs components (#83) 2024-05-13 19:55:23 +07:00
Huu Le (Lee) 6fe240b854 Merge pull request #81 from sagech/fix/store-qdrant-init
fix: qdrant store init parameters
2024-05-13 16:52:53 +07:00
Sam Cheng Hung 8bb1024d0f fix: qdrant store init parameters 2024-05-12 04:10:47 +08:00
github-actions[bot] 988bfc2a60 Release 0.1.2 (#79)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-10 14:12:31 +07:00
Thuc Pham 056e376ee0 feat: add weather widget and weather tool (#72)
---------
Co-authored-by: leehuwuj <leehuwuj@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-10 14:00:16 +07:00
Thuc Pham 819cccb11a feat: use 3.5 as default model (#77) 2024-05-09 15:48:25 +07:00
Huu Le (Lee) 8a5ece10c2 chores: update wrong example system prompt and fix missing switch breaking (#75) 2024-05-08 10:14:34 +07:00
github-actions[bot] 63bb0505d6 Release 0.1.1 (#60)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-03 10:38:01 +07:00
Huu Le (Lee) 2e80ef47ee Fix typo in settings.py (#73) 2024-05-03 10:36:12 +07:00
Marcus Schiesser a1feb524e9 Revert "Use ingestion pipeline in Python code (#61)"
This reverts commit c094b0c6bf.
2024-05-03 11:06:02 +08:00
Marcus Schiesser 06823da849 fix: stream type 2024-05-02 17:25:49 +08:00
Thuc Pham 7bd3ed551f feat: support anthropic and gemini model providers and update to LITS 0.3.3 (#63)
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-02 16:13:31 +07:00
89 changed files with 2369 additions and 626 deletions
-5
View File
@@ -1,5 +0,0 @@
---
"create-llama": patch
---
Use ingestion pipeline for Python
-5
View File
@@ -1,5 +0,0 @@
---
"create-llama": patch
---
Display events (e.g. retrieving nodes) per chat message
+49
View File
@@ -1,5 +1,54 @@
# create-llama
## 0.1.7
### Patch Changes
- 260d37a: Add system prompt env variable for TS
- bbd5b8d: Fix postgres connection leaking issue
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
- 69c2e16: Fix streaming for Express
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
## 0.1.6
### Patch Changes
- 56537a1: Display PDF files in source nodes
## 0.1.5
### Patch Changes
- 84db798: feat: support display latex in chat markdown
## 0.1.4
### Patch Changes
- 0bc8e75: Use ingestion pipeline for dedicated vector stores (Python only)
- cb1001d: Add ChromaDB vector store
## 0.1.3
### Patch Changes
- 416073d: Directly import vector stores to work with NextJS
## 0.1.2
### Patch Changes
- 056e376: Add support for displaying tool outputs (including weather widget as example)
## 0.1.1
### Patch Changes
- 7bd3ed5: Support Anthropic and Gemini as model providers
- 7bd3ed5: Support new agents from LITS 0.3
- cfb5257: Display events (e.g. retrieving nodes) per chat message
## 0.1.0
### Minor Changes
+66 -18
View File
@@ -29,17 +29,20 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
);
};
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
if (!vectorDb) {
const getVectorDBEnvs = (
vectorDb?: TemplateVectorDB,
framework?: TemplateFramework,
): EnvVar[] => {
if (!vectorDb || !framework) {
return [];
}
switch (vectorDb) {
case "mongo":
return [
{
name: "MONGO_URI",
name: "MONGODB_URI",
description:
"For generating a connection URI, see https://docs.timescale.com/use-timescale/latest/services/create-a-service\nThe MongoDB connection URI.",
"For generating a connection URI, see https://www.mongodb.com/docs/manual/reference/connection-string/ \nThe MongoDB connection URI.",
},
{
name: "MONGODB_DATABASE",
@@ -129,6 +132,31 @@ const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
"Optional API key for authenticating requests to Qdrant.",
},
];
case "chroma":
const envs = [
{
name: "CHROMA_COLLECTION",
description: "The name of the collection in your Chroma database",
},
{
name: "CHROMA_HOST",
description: "The API endpoint for your Chroma database",
},
{
name: "CHROMA_PORT",
description: "The port for your Chroma database",
},
];
// TS Version doesn't support config local storage path
if (framework === "fastapi") {
envs.push({
name: "CHROMA_PATH",
description: `The local path to the Chroma database.
Specify this if you are using a local Chroma database.
Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
});
}
return envs;
default:
return [];
}
@@ -173,6 +201,33 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "anthropic"
? [
{
name: "ANTHROPIC_API_KEY",
description: "The Anthropic API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "gemini"
? [
{
name: "GOOGLE_API_KEY",
description: "The Google API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "ollama"
? [
{
name: "OLLAMA_BASE_URL",
description:
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
},
]
: []),
];
};
@@ -194,19 +249,6 @@ const getFrameworkEnvs = (
description: "The port to start the backend app.",
value: port?.toString() || "8000",
},
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
{
name: "SYSTEM_PROMPT",
description: `Custom system prompt.
Example:
SYSTEM_PROMPT="
We have provided context information below.
---------------------
{context_str}
---------------------
Given this information, please answer the question: {query_str}
"`,
},
];
};
@@ -218,6 +260,12 @@ const getEngineEnvs = (): EnvVar[] => {
"The number of similar embeddings to return when retrieving documents.",
value: "3",
},
{
name: "SYSTEM_PROMPT",
description: `Custom system prompt.
Example:
SYSTEM_PROMPT="You are a helpful assistant who helps users with their questions."`,
},
];
};
@@ -245,7 +293,7 @@ export const createBackendEnvFile = async (
// Add engine environment variables
...getEngineEnvs(),
// Add vector database environment variables
...getVectorDBEnvs(opts.vectorDb),
...getVectorDBEnvs(opts.vectorDb, opts.framework),
...getFrameworkEnvs(opts.framework, opts.port),
];
// Render and write env file
+1 -2
View File
@@ -9,7 +9,6 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
import { PackageManager } from "./get-pkg-manager";
import { installLlamapackProject } from "./llama-pack";
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
import { isModelConfigured } from "./providers";
import { installPythonTemplate } from "./python";
import { downloadAndExtractRepo } from "./repo";
import { ConfigFileType, writeToolsConfig } from "./tools";
@@ -38,7 +37,7 @@ async function generateContextData(
? "poetry run generate"
: `${packageManager} run generate`,
)}`;
const modelConfigured = isModelConfigured(modelConfig);
const modelConfigured = modelConfig.isConfigured();
const llamaCloudKeyConfigured = useLlamaParse
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
: true;
+106
View File
@@ -0,0 +1,106 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
/** Anthropic chat models offered in the model prompt; the first entry is the default. */
const MODELS = [
  "claude-3-opus",
  "claude-3-sonnet",
  "claude-3-haiku",
  "claude-2.1",
  "claude-instant-1.2",
];
const [DEFAULT_MODEL] = MODELS;

// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported)
// Use huggingface embedding models for now
enum HuggingFaceEmbeddingModelType {
  XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
  XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
}

/** Static facts recorded per embedding model. */
type ModelData = {
  /** Size of the vectors the embedding model produces. */
  dimensions: number;
};

/** Embedding models selectable for this provider, keyed by model name. */
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: { dimensions: 384 },
  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: { dimensions: 768 },
};

// Defaults are taken from the first entry of the table above.
const [DEFAULT_EMBEDDING_MODEL] = Object.keys(EMBEDDING_MODELS);
const [{ dimensions: DEFAULT_DIMENSIONS }] = Object.values(EMBEDDING_MODELS);
/** Arguments accepted by `askAnthropicQuestions`. */
type AnthropicQuestionsParams = {
  /** API key already known to the caller; when missing the user is prompted for one. */
  apiKey?: string;
  /** Whether the user should be asked to pick the LLM and embedding model. */
  askModels: boolean;
};

/**
 * Collects the Anthropic model configuration, prompting the user where needed.
 *
 * The API key is requested only when none was passed in; a blank answer falls
 * back to the ANTHROPIC_API_KEY environment variable. The model prompts are
 * skipped in CI or when `askModels` is false, leaving the defaults in place.
 *
 * @returns The configuration, including an `isConfigured()` check that reports
 *   whether an API key is available (explicitly or via the environment).
 */
export async function askAnthropicQuestions({
  askModels,
  apiKey,
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
  const config: ModelConfigParams = {
    apiKey,
    model: DEFAULT_MODEL,
    embeddingModel: DEFAULT_EMBEDDING_MODEL,
    dimensions: DEFAULT_DIMENSIONS,
    isConfigured(): boolean {
      // Reads the live config so a key entered at the prompt below counts too.
      return Boolean(config.apiKey || process.env["ANTHROPIC_API_KEY"]);
    },
  };

  if (!config.apiKey) {
    const keyAnswer = await prompts(
      {
        type: "text",
        name: "key",
        message:
          "Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
      },
      questionHandlers,
    );
    config.apiKey = keyAnswer.key || process.env.ANTHROPIC_API_KEY;
  }

  // use default model values in CI or if user should not be asked
  if (ciInfo.isCI || !askModels) {
    return config;
  }

  const modelAnswer = await prompts(
    {
      type: "select",
      name: "model",
      message: "Which LLM model would you like to use?",
      choices: MODELS.map(toChoice),
      initial: 0,
    },
    questionHandlers,
  );
  config.model = modelAnswer.model;

  const embeddingAnswer = await prompts(
    {
      type: "select",
      name: "embeddingModel",
      message: "Which embedding model would you like to use?",
      choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
      initial: 0,
    },
    questionHandlers,
  );
  const chosenEmbedding = embeddingAnswer.embeddingModel;
  config.embeddingModel = chosenEmbedding;
  // Keep the vector size in sync with the embedding model that was picked.
  const chosenData =
    EMBEDDING_MODELS[chosenEmbedding as HuggingFaceEmbeddingModelType];
  config.dimensions = chosenData.dimensions;

  return config;
}
+87
View File
@@ -0,0 +1,87 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
/** Gemini chat models offered in the model prompt; the first entry is the default. */
const MODELS = [
  "gemini-1.5-pro-latest",
  "gemini-pro",
  "gemini-pro-vision",
];

/** Static facts recorded per embedding model. */
type ModelData = {
  /** Size of the vectors the embedding model produces. */
  dimensions: number;
};

/** Embedding models selectable for this provider, keyed by model name. */
const EMBEDDING_MODELS: Record<string, ModelData> = {
  "embedding-001": { dimensions: 768 },
  "text-embedding-004": { dimensions: 768 },
};

// Defaults are taken from the first entry of each list above.
const [DEFAULT_MODEL] = MODELS;
const [DEFAULT_EMBEDDING_MODEL] = Object.keys(EMBEDDING_MODELS);
const [{ dimensions: DEFAULT_DIMENSIONS }] = Object.values(EMBEDDING_MODELS);
/** Arguments accepted by `askGeminiQuestions`. */
type GeminiQuestionsParams = {
  /** API key already known to the caller; when missing the user is prompted for one. */
  apiKey?: string;
  /** Whether the user should be asked to pick the LLM and embedding model. */
  askModels: boolean;
};

/**
 * Collects the Gemini model configuration, prompting the user where needed.
 *
 * The API key is requested only when none was passed in; a blank answer falls
 * back to the GOOGLE_API_KEY environment variable. The model prompts are
 * skipped in CI or when `askModels` is false, leaving the defaults in place.
 *
 * @returns The configuration, including an `isConfigured()` check that reports
 *   whether an API key is available (explicitly or via the environment).
 */
export async function askGeminiQuestions({
  askModels,
  apiKey,
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
  const config: ModelConfigParams = {
    apiKey,
    model: DEFAULT_MODEL,
    embeddingModel: DEFAULT_EMBEDDING_MODEL,
    dimensions: DEFAULT_DIMENSIONS,
    isConfigured(): boolean {
      // Reads the live config so a key entered at the prompt below counts too.
      return Boolean(config.apiKey || process.env["GOOGLE_API_KEY"]);
    },
  };

  if (!config.apiKey) {
    const keyAnswer = await prompts(
      {
        type: "text",
        name: "key",
        message:
          "Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
      },
      questionHandlers,
    );
    config.apiKey = keyAnswer.key || process.env.GOOGLE_API_KEY;
  }

  // use default model values in CI or if user should not be asked
  if (ciInfo.isCI || !askModels) {
    return config;
  }

  const modelAnswer = await prompts(
    {
      type: "select",
      name: "model",
      message: "Which LLM model would you like to use?",
      choices: MODELS.map(toChoice),
      initial: 0,
    },
    questionHandlers,
  );
  config.model = modelAnswer.model;

  const embeddingAnswer = await prompts(
    {
      type: "select",
      name: "embeddingModel",
      message: "Which embedding model would you like to use?",
      choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
      initial: 0,
    },
    questionHandlers,
  );
  const chosenEmbedding = embeddingAnswer.embeddingModel;
  config.embeddingModel = chosenEmbedding;
  // Keep the vector size in sync with the embedding model that was picked.
  config.dimensions = EMBEDDING_MODELS[chosenEmbedding].dimensions;

  return config;
}
+11 -10
View File
@@ -2,8 +2,10 @@ import ciInfo from "ci-info";
import prompts from "prompts";
import { questionHandlers } from "../../questions";
import { ModelConfig, ModelProvider } from "../types";
import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
import { askOpenAIQuestions } from "./openai";
const DEFAULT_MODEL_PROVIDER = "openai";
@@ -31,6 +33,8 @@ export async function askModelConfig({
value: "openai",
},
{ title: "Ollama", value: "ollama" },
{ title: "Anthropic", value: "anthropic" },
{ title: "Gemini", value: "gemini" },
],
initial: 0,
},
@@ -44,6 +48,12 @@ export async function askModelConfig({
case "ollama":
modelConfig = await askOllamaQuestions({ askModels });
break;
case "anthropic":
modelConfig = await askAnthropicQuestions({ askModels });
break;
case "gemini":
modelConfig = await askGeminiQuestions({ askModels });
break;
default:
modelConfig = await askOpenAIQuestions({
openAiKey,
@@ -55,12 +65,3 @@ export async function askModelConfig({
provider: modelProvider,
};
}
export function isModelConfigured(modelConfig: ModelConfig): boolean {
switch (modelConfig.provider) {
case "openai":
return isOpenAIConfigured(modelConfig);
default:
return true;
}
}
+3
View File
@@ -29,6 +29,9 @@ export async function askOllamaQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
isConfigured(): boolean {
return true;
},
};
// use default model values in CI or if user should not be asked
+10 -12
View File
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
const OPENAI_API_URL = "https://api.openai.com/v1";
const DEFAULT_MODEL = "gpt-4-turbo";
const DEFAULT_MODEL = "gpt-3.5-turbo";
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
export async function askOpenAIQuestions({
@@ -20,6 +20,15 @@ export async function askOpenAIQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
@@ -31,7 +40,6 @@ export async function askOpenAIQuestions({
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
: "Please provide your OpenAI API key (leave blank to skip):",
validate: (value: string) => {
console.log(value);
if (askModels && !value) {
if (process.env.OPENAI_API_KEY) {
return true;
@@ -78,16 +86,6 @@ export async function askOpenAIQuestions({
return config;
}
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
if (params.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
}
async function getAvailableModelChoices(
selectEmbedding: boolean,
apiKey?: string,
+8
View File
@@ -0,0 +1,8 @@
/**
 * Loads the `global-agent/bootstrap` module on demand.
 *
 * Nothing happens unless the GLOBAL_AGENT_HTTP_PROXY environment variable is
 * set to a non-empty value; in that case the module is imported dynamically
 * and a confirmation line is printed. A failing import rejects the returned
 * promise.
 */
export async function initializeGlobalAgent() {
  const proxySetting = process.env.GLOBAL_AGENT_HTTP_PROXY;
  if (!proxySetting) {
    // No proxy requested: skip loading the module entirely.
    return;
  }

  await import("global-agent/bootstrap");
  console.log("Proxy enabled via global-agent.");
}
+78 -33
View File
@@ -24,7 +24,7 @@ interface Dependency {
const getAdditionalDependencies = (
modelConfig: ModelConfig,
vectorDb?: TemplateVectorDB,
dataSource?: TemplateDataSource,
dataSources?: TemplateDataSource[],
tools?: Tool[],
) => {
const dependencies: Dependency[] = [];
@@ -43,6 +43,7 @@ const getAdditionalDependencies = (
name: "llama-index-vector-stores-postgres",
version: "^0.1.1",
});
break;
}
case "pinecone": {
dependencies.push({
@@ -69,41 +70,60 @@ const getAdditionalDependencies = (
});
break;
}
case "qdrant": {
dependencies.push({
name: "llama-index-vector-stores-qdrant",
version: "^0.2.8",
});
break;
}
case "chroma": {
dependencies.push({
name: "llama-index-vector-stores-chroma",
version: "^0.1.8",
});
break;
}
}
// Add data source dependencies
const dataSourceType = dataSource?.type;
switch (dataSourceType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
if (dataSources) {
for (const ds of dataSources) {
const dsType = ds?.type;
switch (dsType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
}
}
}
// Add tools dependencies
console.log("Adding tools dependencies");
tools?.forEach((tool) => {
tool.dependencies?.forEach((dep) => {
dependencies.push(dep);
@@ -127,6 +147,26 @@ const getAdditionalDependencies = (
version: "0.2.2",
});
break;
case "anthropic":
dependencies.push({
name: "llama-index-llms-anthropic",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-huggingface",
version: "0.2.0",
});
break;
case "gemini":
dependencies.push({
name: "llama-index-llms-gemini",
version: "0.1.7",
});
dependencies.push({
name: "llama-index-embeddings-gemini",
version: "0.1.6",
});
break;
}
return dependencies;
@@ -278,9 +318,14 @@ export const installPythonTemplate = async ({
cwd: path.join(compPath, "engines", "python", engine),
});
const addOnDependencies = dataSources
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
.flat();
console.log("Adding additional dependencies");
const addOnDependencies = getAdditionalDependencies(
modelConfig,
vectorDb,
dataSources,
tools,
);
if (observability === "opentelemetry") {
addOnDependencies.push({
+27 -2
View File
@@ -5,12 +5,18 @@ import yaml from "yaml";
import { makeDir } from "./make-dir";
import { TemplateFramework } from "./types";
export enum ToolType {
LLAMAHUB = "llamahub",
LOCAL = "local",
}
export type Tool = {
display: string;
name: string;
config?: Record<string, any>;
dependencies?: ToolDependencies[];
supportedFrameworks?: Array<TemplateFramework>;
type: ToolType;
};
export type ToolDependencies = {
@@ -35,6 +41,7 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi"],
type: ToolType.LLAMAHUB,
},
{
display: "Wikipedia",
@@ -46,6 +53,14 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LLAMAHUB,
},
{
display: "Weather",
name: "weather",
dependencies: [],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
},
];
@@ -90,9 +105,19 @@ export const writeToolsConfig = async (
type: ConfigFileType = ConfigFileType.YAML,
) => {
if (tools.length === 0) return; // no tools selected, no config need
const configContent: Record<string, any> = {};
const configContent: {
[key in ToolType]: Record<string, any>;
} = {
local: {},
llamahub: {},
};
tools.forEach((tool) => {
configContent[tool.name] = tool.config ?? {};
if (tool.type === ToolType.LLAMAHUB) {
configContent.llamahub[tool.name] = tool.config ?? {};
}
if (tool.type === ToolType.LOCAL) {
configContent.local[tool.name] = tool.config ?? {};
}
});
const configPath = path.join(root, "config");
await makeDir(configPath);
+4 -2
View File
@@ -1,13 +1,14 @@
import { PackageManager } from "../helpers/get-pkg-manager";
import { Tool } from "./tools";
export type ModelProvider = "openai" | "ollama";
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
export type ModelConfig = {
provider: ModelProvider;
apiKey?: string;
model: string;
embeddingModel: string;
dimensions: number;
isConfigured(): boolean;
};
export type TemplateType = "streaming" | "community" | "llamapack";
export type TemplateFramework = "nextjs" | "express" | "fastapi";
@@ -19,7 +20,8 @@ export type TemplateVectorDB =
| "pinecone"
| "milvus"
| "astra"
| "qdrant";
| "qdrant"
| "chroma";
export type TemplatePostInstallAction =
| "none"
| "VSCode"
+1 -1
View File
@@ -105,7 +105,7 @@ export const installTSTemplate = async ({
const enginePath = path.join(root, relativeEngineDestPath, "engine");
// copy vector db component
console.log("\nUsing vector DB:", vectorDb, "\n");
console.log("\nUsing vector DB:", vectorDb ?? "none", "\n");
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
+4
View File
@@ -12,12 +12,16 @@ import { createApp } from "./create-app";
import { getDataSources } from "./helpers/datasources";
import { getPkgManager } from "./helpers/get-pkg-manager";
import { isFolderEmpty } from "./helpers/is-folder-empty";
import { initializeGlobalAgent } from "./helpers/proxy";
import { runApp } from "./helpers/run-app";
import { getTools } from "./helpers/tools";
import { validateNpmName } from "./helpers/validate-pkg";
import packageJson from "./package.json";
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
// Run the initialization function
initializeGlobalAgent();
let projectPath: string = "";
const handleSigTerm = () => process.exit(0);
+2 -1
View File
@@ -1,6 +1,6 @@
{
"name": "create-llama",
"version": "0.1.0",
"version": "0.1.7",
"description": "Create LlamaIndex-powered apps with one command",
"keywords": [
"rag",
@@ -52,6 +52,7 @@
"cross-spawn": "7.0.3",
"fast-glob": "3.3.1",
"fs-extra": "11.2.0",
"global-agent": "^3.0.0",
"got": "10.7.0",
"ollama": "^0.5.0",
"ora": "^8.0.1",
+265 -145
View File
File diff suppressed because it is too large Load Diff
+5 -5
View File
@@ -14,7 +14,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
import { EXAMPLE_FILE } from "./helpers/datasources";
import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { askModelConfig, isModelConfigured } from "./helpers/providers";
import { askModelConfig } from "./helpers/providers";
import { getProjectOptions } from "./helpers/repo";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
@@ -97,6 +97,7 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
{ title: "Milvus", value: "milvus" },
{ title: "Astra", value: "astra" },
{ title: "Qdrant", value: "qdrant" },
{ title: "ChromaDB", value: "chroma" },
];
const vectordbLang = framework === "fastapi" ? "python" : "typescript";
@@ -257,7 +258,8 @@ export const askQuestions = async (
},
];
const modelConfigured = isModelConfigured(program.modelConfig);
const modelConfigured =
!program.llamapack && program.modelConfig.isConfigured();
// If using LlamaParse, require LlamaCloud API key
const llamaCloudKeyConfigured = program.useLlamaParse
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
@@ -268,8 +270,7 @@ export const askQuestions = async (
!hasVectorDb &&
modelConfigured &&
llamaCloudKeyConfigured &&
!toolsRequireConfig(program.tools) &&
!program.llamapack
!toolsRequireConfig(program.tools)
) {
actionChoices.push({
title:
@@ -398,7 +399,6 @@ export const askQuestions = async (
if (program.framework === "express" || program.framework === "fastapi") {
// if a backend-only framework is selected, ask whether we should create a frontend
// (only for streaming backends)
if (program.frontend === undefined) {
if (ciInfo.isCI) {
program.frontend = getPrefOrDefault("frontend");
@@ -1,35 +0,0 @@
import os
import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolFactory:
@staticmethod
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
try:
tool_package, tool_cls_name = tool_name.split(".")
module_name = f"llama_index.tools.{tool_package}"
module = importlib.import_module(module_name)
tool_class = getattr(module, tool_cls_name)
tool_spec: BaseToolSpec = tool_class(**kwargs)
return tool_spec.to_tool_list()
except (ImportError, AttributeError) as e:
raise ValueError(f"Unsupported tool: {tool_name}") from e
except TypeError as e:
raise ValueError(
f"Could not create tool: {tool_name}. With config: {kwargs}"
) from e
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for name, config in tool_configs.items():
tools += ToolFactory.create_tool(name, **config)
return tools
@@ -0,0 +1,56 @@
import os
import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolType:
    """Names of the top-level sections used in the tools config file."""

    LLAMAHUB = "llamahub"
    LOCAL = "local"


class ToolFactory:
    """Creates FunctionTool instances from the entries in config/tools.yaml."""

    # Python package that is searched for each kind of tool.
    TOOL_SOURCE_PACKAGE_MAP = {
        ToolType.LLAMAHUB: "llama_index.tools",
        ToolType.LOCAL: "app.engine.tools",
    }

    @staticmethod
    def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
        """Load the tools registered under ``tool_name``.

        Names containing "ToolSpec" are read as ``<package>.<ClassName>``: the
        class is imported from the source package of ``tool_type``, created
        with ``config`` as keyword arguments and expanded via ``to_tool_list()``.
        Any other name is imported as a module of the source package that must
        expose a module-level ``tools`` list of FunctionTool objects.

        Args:
            tool_type: Key of TOOL_SOURCE_PACKAGE_MAP (see ToolType).
            tool_name: Tool identifier as described above.
            config: Keyword arguments for a tool spec class; a missing (None)
                config is treated as empty.

        Returns:
            The list of loaded tools.

        Raises:
            KeyError: If ``tool_type`` is not in TOOL_SOURCE_PACKAGE_MAP.
            ValueError: If the tool cannot be imported or resolved, or if a
                module's ``tools`` attribute holds non-FunctionTool items.
        """
        source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
        try:
            if "ToolSpec" in tool_name:
                tool_package, tool_cls_name = tool_name.split(".")
                module_name = f"{source_package}.{tool_package}"
                module = importlib.import_module(module_name)
                tool_class = getattr(module, tool_cls_name)
                # A null config entry in the YAML file would break the ** expansion.
                tool_spec: BaseToolSpec = tool_class(**(config or {}))
                return tool_spec.to_tool_list()
            else:
                module = importlib.import_module(f"{source_package}.{tool_name}")
                tools = getattr(module, "tools")
                if not all(isinstance(tool, FunctionTool) for tool in tools):
                    raise ValueError(
                        f"The module {module} does not contain valid tools"
                    )
                return tools
        except ImportError as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Failed to import tool {tool_name}: {e}") from e
        except AttributeError as e:
            raise ValueError(f"Failed to load tool {tool_name}: {e}") from e

    @staticmethod
    def from_env() -> list[FunctionTool]:
        """Load every tool listed in config/tools.yaml.

        The file maps a tool type to a mapping of tool name -> config.
        A missing file, an empty file or an empty section yields no tools.

        Returns:
            All loaded tools, in file order.
        """
        tools = []
        if os.path.exists("config/tools.yaml"):
            with open("config/tools.yaml", "r") as f:
                # safe_load returns None for an empty document.
                tool_configs = yaml.safe_load(f) or {}
            for tool_type, config_entries in tool_configs.items():
                # A section written without entries (e.g. "local:") parses as None.
                for tool_name, config in (config_entries or {}).items():
                    tools.extend(
                        ToolFactory.load_tools(tool_type, tool_name, config or {})
                    )
        return tools
@@ -0,0 +1,72 @@
"""Open Meteo weather map tool spec."""
import logging
import requests
import pytz
from llama_index.core.tools import FunctionTool
logger = logging.getLogger(__name__)
class OpenMeteoWeather:
    # Base URLs of the geocoding and forecast endpoints used below.
    geo_api = "https://geocoding-api.open-meteo.com/v1"
    weather_api = "https://api.open-meteo.com/v1"

    @classmethod
    def _get_geo_location(cls, location: str) -> dict:
        """Get geo location from location name.

        Queries the geocoding search endpoint and uses the first result.

        Args:
            location: Free-text place name to look up.

        Returns:
            A dict with the keys "id", "name", "latitude" and "longitude".

        Raises:
            ValueError: If the response contains no result for ``location``.
            Exception: If the endpoint answers with a non-200 status code.
        """
        params = {"name": location, "count": 10, "language": "en", "format": "json"}
        response = requests.get(f"{cls.geo_api}/search", params=params)
        if response.status_code != 200:
            raise Exception(f"Failed to fetch geo location: {response.status_code}")
        data = response.json()
        # Guard against a missing or empty "results" list so an unknown place
        # produces a readable error instead of a bare KeyError/IndexError.
        results = data.get("results")
        if not results:
            raise ValueError(f"Could not find geo location for: {location}")
        result = results[0]
        geo_location = {
            "id": result["id"],
            "name": result["name"],
            "latitude": result["latitude"],
            "longitude": result["longitude"],
        }
        return geo_location

    @classmethod
    def get_weather_information(cls, location: str) -> dict:
        """Use this function to get the weather of any given location.
        Note that the weather code should follow WMO Weather interpretation codes (WW):
        0: Clear sky
        1, 2, 3: Mainly clear, partly cloudy, and overcast
        45, 48: Fog and depositing rime fog
        51, 53, 55: Drizzle: Light, moderate, and dense intensity
        56, 57: Freezing Drizzle: Light and dense intensity
        61, 63, 65: Rain: Slight, moderate and heavy intensity
        66, 67: Freezing Rain: Light and heavy intensity
        71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
        77: Snow grains
        80, 81, 82: Rain showers: Slight, moderate, and violent
        85, 86: Snow showers slight and heavy
        95: Thunderstorm: Slight or moderate
        96, 99: Thunderstorm with slight and heavy hail
        """
        # NOTE: the docstring above is passed to FunctionTool.from_defaults
        # (see `tools` below), so its wording is left untouched here.
        logger.info(
            f"Calling open-meteo api to get weather information of location: {location}"
        )
        geo_location = cls._get_geo_location(location)
        # Timestamps in the forecast are requested in UTC.
        timezone = pytz.timezone("UTC").zone
        params = {
            "latitude": geo_location["latitude"],
            "longitude": geo_location["longitude"],
            "current": "temperature_2m,weather_code",
            "hourly": "temperature_2m,weather_code",
            "daily": "weather_code",
            "timezone": timezone,
        }
        response = requests.get(f"{cls.weather_api}/forecast", params=params)
        if response.status_code != 200:
            raise Exception(
                f"Failed to fetch weather information: {response.status_code}"
            )
        return response.json()


# Module-level tool list; ToolFactory-style loaders read this attribute by name.
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
@@ -1,12 +1,13 @@
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import fs from "node:fs/promises";
import path from "node:path";
import { getDataSource } from "./index";
import { STORAGE_CACHE_DIR } from "./shared";
import { createLocalTools } from "./tools";
export async function createChatEngine() {
let tools: BaseTool[] = [];
const tools: BaseToolWithCall[] = [];
// Add a query engine tool if we have a data source
// Delete this code if you don't have a data source
@@ -28,10 +29,18 @@ export async function createChatEngine() {
const config = JSON.parse(
await fs.readFile(path.join("config", "tools.json"), "utf8"),
);
tools = tools.concat(await ToolsFactory.createTools(config));
// add local tools from the 'tools' folder (if configured)
const localTools = createLocalTools(config.local);
tools.push(...localTools);
// add tools from LlamaIndexTS (if configured)
const llamaTools = await ToolsFactory.createTools(config.llamahub);
tools.push(...llamaTools);
} catch {}
return new OpenAIAgent({
tools,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
@@ -0,0 +1,26 @@
import { BaseToolWithCall } from "llamaindex";
import { WeatherTool, WeatherToolParams } from "./weather";
/** Signature of a function that builds a tool from its raw (untyped) configuration. */
type ToolCreator = (config: unknown) => BaseToolWithCall;

/** Registry of local tools: maps a tool name to the creator that instantiates it. */
const toolFactory: Record<string, ToolCreator> = {
  weather: (rawConfig) => new WeatherTool(rawConfig as WeatherToolParams),
};
/**
 * Creates the locally defined tools that are enabled in the configuration.
 *
 * @param localConfig - Maps a tool name to that tool's configuration. May be
 *   missing when no local tools are configured, in which case no tools are
 *   created.
 * @returns One tool per configured name that has a creator registered in
 *   `toolFactory`; names without a registered creator are ignored.
 */
export function createLocalTools(
  localConfig: Record<string, unknown> | undefined,
): BaseToolWithCall[] {
  const tools: BaseToolWithCall[] = [];
  // A missing section yields an empty list instead of a TypeError.
  for (const [key, toolConfig] of Object.entries(localConfig ?? {})) {
    // Own-property check: `key in toolFactory` would also match inherited
    // names such as "toString" or "constructor" and invoke them as creators.
    if (Object.prototype.hasOwnProperty.call(toolFactory, key)) {
      tools.push(toolFactory[key](toolConfig));
    }
  }
  return tools;
}
@@ -0,0 +1,81 @@
import type { JSONSchemaType } from "ajv";
import { BaseTool, ToolMetadata } from "llamaindex";
/** Fields of a geocoding search result that the weather tool keeps. */
interface GeoLocation {
id: string;
name: string;
latitude: number;
longitude: number;
}
/** Input accepted by the weather tool: the name of the place to look up. */
export type WeatherParameter = {
location: string;
};
/** Constructor options for WeatherTool; `metadata` replaces the default tool metadata when given. */
export type WeatherToolParams = {
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
};
/**
 * Metadata used by WeatherTool when the caller does not supply its own:
 * the tool name, a description that includes the WMO weather-code legend,
 * and a JSON schema declaring the single required `location` string parameter.
 */
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
name: "get_weather_information",
description: `
Use this function to get the weather of any given location.
Note that the weather code should follow WMO Weather interpretation codes (WW):
0: Clear sky
1, 2, 3: Mainly clear, partly cloudy, and overcast
45, 48: Fog and depositing rime fog
51, 53, 55: Drizzle: Light, moderate, and dense intensity
56, 57: Freezing Drizzle: Light and dense intensity
61, 63, 65: Rain: Slight, moderate and heavy intensity
66, 67: Freezing Rain: Light and heavy intensity
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
77: Snow grains
80, 81, 82: Rain showers: Slight, moderate, and violent
85, 86: Snow showers slight and heavy
95: Thunderstorm: Slight or moderate
96, 99: Thunderstorm with slight and heavy hail
`,
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The location to get the weather information",
},
},
required: ["location"],
},
};
/**
 * Tool that returns weather data for a named location by first geocoding the
 * name and then querying the Open-Meteo forecast endpoint.
 */
export class WeatherTool implements BaseTool<WeatherParameter> {
  metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;

  /**
   * Resolves a free-text location name to its first geocoding match.
   *
   * @param location - Name of the place to look up.
   * @returns Id, name and coordinates of the first match.
   * @throws Error when the request fails or no location matches.
   */
  private getGeoLocation = async (location: string): Promise<GeoLocation> => {
    // The name comes from the caller (typically the LLM); encode it so that
    // spaces, `&`, `#` and similar characters cannot break the query string.
    const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${encodeURIComponent(location)}&count=10&language=en&format=json`;
    const response = await fetch(apiUrl);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch geo location for '${location}': ${response.status}`,
      );
    }
    const data = await response.json();
    // `results` may be missing or empty when nothing matches the name.
    const match = data?.results?.[0];
    if (!match) {
      throw new Error(`No location found for '${location}'`);
    }
    const { id, name, latitude, longitude } = match;
    return { id, name, latitude, longitude };
  };

  /**
   * Fetches current, hourly and daily weather data for a location name.
   *
   * @param location - Name of the place to get the weather for.
   * @returns The parsed JSON body of the forecast response.
   * @throws Error when geocoding fails or the forecast request fails.
   */
  private getWeatherByLocation = async (location: string) => {
    console.log(
      "Calling open-meteo api to get weather information of location:",
      location,
    );
    const { latitude, longitude } = await this.getGeoLocation(location);
    // Use the timezone the process resolves for itself.
    const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
    const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&current=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${encodeURIComponent(timezone)}`;
    const response = await fetch(apiUrl);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch weather information: ${response.status}`,
      );
    }
    return await response.json();
  };

  /**
   * @param params - Optional settings; `metadata` overrides the default
   *   tool metadata.
   */
  constructor(params?: WeatherToolParams) {
    this.metadata = params?.metadata || DEFAULT_META_DATA;
  }

  /**
   * Runs the tool.
   *
   * @param input - Object carrying the `location` to look up.
   * @returns The weather data for that location.
   */
  async call(input: WeatherParameter) {
    return await this.getWeatherByLocation(input.location);
  }
}
@@ -16,5 +16,6 @@ export async function createChatEngine() {
return new ContextChatEngine({
chatModel: Settings.llm,
retriever,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
+28 -5
View File
@@ -1,7 +1,10 @@
import os
import logging
from llama_parse import LlamaParse
from pydantic import BaseModel, validator
logger = logging.getLogger(__name__)
class FileLoaderConfig(BaseModel):
data_dir: str = "data"
@@ -27,8 +30,28 @@ def llama_parse_parser():
def get_file_documents(config: FileLoaderConfig):
from llama_index.core.readers import SimpleDirectoryReader
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
if config.use_llama_parse:
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
return reader.load_data()
try:
reader = SimpleDirectoryReader(
config.data_dir,
recursive=True,
filename_as_id=True,
)
if config.use_llama_parse:
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
return reader.load_data()
except ValueError as e:
import sys, traceback
# Catch the error if the data dir is empty
# and return as empty document list
_, _, exc_traceback = sys.exc_info()
function_name = traceback.extract_tb(exc_traceback)[-1].name
if function_name == "_add_files":
logger.warning(
f"Failed to load file documents, error message: {e} . Return as empty document list."
)
return []
else:
# Raise the error if it is not the case of empty data dir
raise e
@@ -3,10 +3,18 @@ from llama_index.vector_stores.astra_db import AstraDBVectorStore
def get_vector_store():
endpoint = os.getenv("ASTRA_DB_ENDPOINT")
token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
collection = os.getenv("ASTRA_DB_COLLECTION")
if not endpoint or not token or not collection:
raise ValueError(
"Please config ASTRA_DB_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_COLLECTION"
" to your environment variables or config them in the .env file"
)
store = AstraDBVectorStore(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
collection_name=os.environ["ASTRA_DB_COLLECTION"],
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
token=token,
api_endpoint=endpoint,
collection_name=collection,
embedding_dimension=int(os.getenv("EMBEDDING_DIM")),
)
return store
@@ -0,0 +1,24 @@
import os
from llama_index.vector_stores.chroma import ChromaVectorStore
# Build the Chroma vector store from environment configuration.
# Reads CHROMA_COLLECTION (defaulting to "default") and CHROMA_PATH; when
# CHROMA_PATH is not set, CHROMA_HOST and CHROMA_PORT are read instead.
# Raises ValueError when CHROMA_PATH is unset and CHROMA_HOST or CHROMA_PORT
# is missing.
def get_vector_store():
collection_name = os.getenv("CHROMA_COLLECTION", "default")
chroma_path = os.getenv("CHROMA_PATH")
# if CHROMA_PATH is set, use a local ChromaVectorStore from the path
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
if chroma_path:
store = ChromaVectorStore.from_params(
persist_dir=chroma_path, collection_name=collection_name
)
else:
# Both host and port must be present before connecting remotely.
if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
raise ValueError(
"Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
)
store = ChromaVectorStore.from_params(
host=os.getenv("CHROMA_HOST"),
# The port is read as a string from the environment and converted here.
port=int(os.getenv("CHROMA_PORT")),
collection_name=collection_name,
)
return store
@@ -3,11 +3,18 @@ from llama_index.vector_stores.milvus import MilvusVectorStore
def get_vector_store():
address = os.getenv("MILVUS_ADDRESS")
collection = os.getenv("MILVUS_COLLECTION")
if not address or not collection:
raise ValueError(
"Please set MILVUS_ADDRESS and MILVUS_COLLECTION to your environment variables"
" or config them in the .env file"
)
store = MilvusVectorStore(
uri=os.environ["MILVUS_ADDRESS"],
uri=address,
user=os.getenv("MILVUS_USERNAME"),
password=os.getenv("MILVUS_PASSWORD"),
collection_name=os.getenv("MILVUS_COLLECTION"),
collection_name=collection,
dim=int(os.getenv("EMBEDDING_DIM")),
)
return store
@@ -3,9 +3,18 @@ from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
def get_vector_store():
db_uri = os.getenv("MONGODB_URI")
db_name = os.getenv("MONGODB_DATABASE")
collection_name = os.getenv("MONGODB_VECTORS")
index_name = os.getenv("MONGODB_VECTOR_INDEX")
if not db_uri or not db_name or not collection_name or not index_name:
raise ValueError(
"Please set MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, and MONGODB_VECTOR_INDEX"
" to your environment variables or config them in .env file"
)
store = MongoDBAtlasVectorSearch(
db_name=os.environ["MONGODB_DATABASE"],
collection_name=os.environ["MONGODB_VECTORS"],
index_name=os.environ["MONGODB_VECTOR_INDEX"],
db_name=db_name,
collection_name=collection_name,
index_name=index_name,
)
return store
@@ -0,0 +1,33 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.indices import (
VectorStoreIndex,
)
from app.engine.loaders import get_documents
from app.settings import init_settings
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# Build a vector index from the loaded documents and persist it to disk.
# The target directory comes from STORAGE_DIR (default "storage").
def generate_datasource():
# Settings are initialized before any documents are loaded or indexed.
init_settings()
logger.info("Creating new index")
storage_dir = os.environ.get("STORAGE_DIR", "storage")
# load the documents and create the index
documents = get_documents()
index = VectorStoreIndex.from_documents(
documents,
)
# store it for later
index.storage_context.persist(storage_dir)
logger.info(f"Finished creating new index. Stored in {storage_dir}")
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,30 @@
import os
import logging
from datetime import timedelta
from cachetools import cached, TTLCache
from llama_index.core.storage import StorageContext
from llama_index.core.indices import load_index_from_storage
logger = logging.getLogger("uvicorn")
# Cached loader for the storage context persisted at `persist_dir`.
# Results are kept in a TTL cache whose entries expire after 5 minutes.
# NOTE(review): the cache key is a single constant, so the `persist_dir`
# argument does not take part in the lookup - a call with a different
# directory inside the TTL window gets the previously cached context.
# Looks intentional (one global context); confirm.
@cached(
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
key=lambda *args, **kwargs: "global_storage_context",
)
def get_storage_context(persist_dir: str) -> StorageContext:
return StorageContext.from_defaults(persist_dir=persist_dir)
# Load the index persisted in STORAGE_DIR (default "storage").
# Returns None when that directory does not exist; otherwise returns the
# index loaded from the (cached) storage context.
def get_index():
storage_dir = os.getenv("STORAGE_DIR", "storage")
# check if storage already exists
if not os.path.exists(storage_dir):
return None
# load the existing index
logger.info(f"Loading index from {storage_dir}...")
storage_context = get_storage_context(storage_dir)
index = load_index_from_storage(storage_context)
logger.info(f"Finished loading index from {storage_dir}")
return index
@@ -1,16 +0,0 @@
import os
from llama_index.core.vector_stores import SimpleVectorStore
from app.constants import STORAGE_DIR
def get_vector_store():
if not os.path.exists(STORAGE_DIR):
# Vector store hasn't been persisted before, create a new one
vector_store = SimpleVectorStore()
else:
# Vector store has already been persisted before at STORAGE_DIR - load it
vector_store = SimpleVectorStore.from_persist_dir(
STORAGE_DIR, namespace="default"
)
return vector_store
@@ -2,30 +2,36 @@ import os
from llama_index.vector_stores.postgres import PGVectorStore
from urllib.parse import urlparse
STORAGE_DIR = "storage"
PGVECTOR_SCHEMA = "public"
PGVECTOR_TABLE = "llamaindex_embedding"
vector_store: PGVectorStore = None
def get_vector_store():
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
global vector_store
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
if vector_store is None:
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
return PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
)
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
vector_store = PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
)
return vector_store
@@ -3,9 +3,17 @@ from llama_index.vector_stores.pinecone import PineconeVectorStore
def get_vector_store():
api_key = os.getenv("PINECONE_API_KEY")
index_name = os.getenv("PINECONE_INDEX_NAME")
environment = os.getenv("PINECONE_ENVIRONMENT")
if not api_key or not index_name or not environment:
raise ValueError(
"Please set PINECONE_API_KEY, PINECONE_INDEX_NAME, and PINECONE_ENVIRONMENT"
" to your environment variables or config them in the .env file"
)
store = PineconeVectorStore(
api_key=os.environ["PINECONE_API_KEY"],
index_name=os.environ["PINECONE_INDEX_NAME"],
environment=os.environ["PINECONE_ENVIRONMENT"],
api_key=api_key,
index_name=index_name,
environment=environment,
)
return store
@@ -3,9 +3,17 @@ from llama_index.vector_stores.qdrant import QdrantVectorStore
def get_vector_store():
collection_name = os.getenv("QDRANT_COLLECTION")
url = os.getenv("QDRANT_URL")
api_key = os.getenv("QDRANT_API_KEY")
if not collection_name or not url:
raise ValueError(
"Please set QDRANT_COLLECTION, QDRANT_URL"
" to your environment variables or config them in the .env file"
)
store = QdrantVectorStore(
collection_name=os.getenv("QDRANT_COLLECTION"),
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
collection_name=collection_name,
url=url,
api_key=api_key,
)
return store
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
AstraDBVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
@@ -1,5 +1,6 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { AstraDBVectorStore, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
import { checkRequiredEnvVars } from "./shared";
export async function getDataSource() {
@@ -0,0 +1,37 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
dotenv.config();
/**
 * Loads the documents, embeds them and stores the result in the ChromaDB
 * collection configured through CHROMA_HOST, CHROMA_PORT and
 * CHROMA_COLLECTION.
 */
async function loadAndIndex() {
  // Turn the stored objects into LlamaIndex Document objects.
  const documents = await getDocuments();

  // Point the vector store at the configured ChromaDB server.
  const {
    CHROMA_HOST: host,
    CHROMA_PORT: port,
    CHROMA_COLLECTION: collectionName,
  } = process.env;
  const vectorStore = new ChromaVectorStore({
    collectionName,
    chromaClientParams: { path: `http://${host}:${port}` },
  });

  // Build the index from all documents and store it in ChromaDB.
  console.log("Start creating embeddings...");
  const storageContext = await storageContextFromDefaults({ vectorStore });
  await VectorStoreIndex.fromDocuments(documents, { storageContext });
  console.log(
    "Successfully created embeddings and save to your ChromaDB index.",
  );
}
// Script entry point: validate the environment, initialize the settings,
// then build and store the index.
(async function main() {
  checkRequiredEnvVars();
  initSettings();
  await loadAndIndex();
  console.log("Finished generating storage.");
})();
@@ -0,0 +1,16 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { VectorStoreIndex } from "llamaindex";
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
import { checkRequiredEnvVars } from "./shared";
/**
 * Returns a vector index backed by the ChromaDB collection configured through
 * CHROMA_HOST, CHROMA_PORT and CHROMA_COLLECTION.
 * The required environment variables are checked first.
 */
export async function getDataSource() {
  checkRequiredEnvVars();
  const {
    CHROMA_HOST: host,
    CHROMA_PORT: port,
    CHROMA_COLLECTION: collectionName,
  } = process.env;
  const store = new ChromaVectorStore({
    collectionName,
    chromaClientParams: { path: `http://${host}:${port}` },
  });
  return await VectorStoreIndex.fromVectorStore(store);
}
@@ -0,0 +1,18 @@
/** Environment variables that must be set to connect to ChromaDB. */
const REQUIRED_ENV_VARS = ["CHROMA_COLLECTION", "CHROMA_HOST", "CHROMA_PORT"];

/**
 * Verifies that every variable in REQUIRED_ENV_VARS has a non-empty value.
 * Logs the missing names and throws when at least one is absent.
 *
 * @throws Error listing the missing environment variables.
 */
export function checkRequiredEnvVars() {
  const missing: string[] = [];
  for (const name of REQUIRED_ENV_VARS) {
    if (!process.env[name]) {
      missing.push(name);
    }
  }
  if (missing.length === 0) {
    return;
  }
  const missingList = missing.join(", ");
  console.log(
    `The following environment variables are required but missing: ${missingList}`,
  );
  throw new Error(`Missing environment variables: ${missingList}`);
}
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
MilvusVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
@@ -1,4 +1,5 @@
import { MilvusVectorStore, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
export async function getDataSource() {
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
MongoDBAtlasVectorSearch,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
import { MongoClient } from "mongodb";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
@@ -12,7 +9,7 @@ import { checkRequiredEnvVars } from "./shared";
dotenv.config();
const mongoUri = process.env.MONGO_URI!;
const mongoUri = process.env.MONGODB_URI!;
const databaseName = process.env.MONGODB_DATABASE!;
const vectorCollectionName = process.env.MONGODB_VECTORS!;
const indexName = process.env.MONGODB_VECTOR_INDEX;
@@ -1,5 +1,6 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { MongoDBAtlasVectorSearch, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
import { MongoClient } from "mongodb";
import { checkRequiredEnvVars } from "./shared";
@@ -1,5 +1,5 @@
const REQUIRED_ENV_VARS = [
"MONGO_URI",
"MONGODB_URI",
"MONGODB_DATABASE",
"MONGODB_VECTORS",
"MONGODB_VECTOR_INDEX",
@@ -1,4 +1,5 @@
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import * as dotenv from "dotenv";
@@ -1,8 +1,5 @@
import {
SimpleDocumentStore,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import { STORAGE_CACHE_DIR } from "./shared";
export async function getDataSource() {
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
PGVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import {
@@ -1,5 +1,6 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { PGVectorStore, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
import {
PGVECTOR_SCHEMA,
PGVECTOR_TABLE,
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
PineconeVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars } from "./shared";
@@ -1,5 +1,6 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import { PineconeVectorStore, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
import { checkRequiredEnvVars } from "./shared";
export async function getDataSource() {
@@ -1,10 +1,7 @@
/* eslint-disable turbo/no-undeclared-env-vars */
import * as dotenv from "dotenv";
import {
QdrantVectorStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
import { getDocuments } from "./loader";
import { initSettings } from "./settings";
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
@@ -18,7 +15,10 @@ async function loadAndIndex() {
const documents = await getDocuments();
// Connect to Qdrant
const vectorStore = new QdrantVectorStore(collectionName, getQdrantClient());
const vectorStore = new QdrantVectorStore({
collectionName,
client: getQdrantClient(),
});
const storageContext = await storageContextFromDefaults({ vectorStore });
await VectorStoreIndex.fromDocuments(documents, {
@@ -1,5 +1,6 @@
import * as dotenv from "dotenv";
import { QdrantVectorStore, VectorStoreIndex } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
dotenv.config();
@@ -7,7 +8,10 @@ dotenv.config();
export async function getDataSource() {
checkRequiredEnvVars();
const collectionName = process.env.QDRANT_COLLECTION;
const store = new QdrantVectorStore(collectionName, getQdrantClient());
const store = new QdrantVectorStore({
collectionName,
client: getQdrantClient(),
});
return await VectorStoreIndex.fromVectorStore(store);
}
@@ -31,6 +31,7 @@ if (isDevelopment) {
console.warn("Production CORS origin not set, defaulting to no CORS.");
}
app.use("/api/data", express.static("data"));
app.use(express.text());
app.get("/", (req: Request, res: Response) => {
@@ -1,20 +1,22 @@
{
"name": "llama-index-express-streaming",
"version": "1.0.0",
"main": "dist/index.mjs",
"main": "dist/index.js",
"scripts": {
"format": "prettier --ignore-unknown --cache --check .",
"format:write": "prettier --ignore-unknown --write .",
"build": "tsup index.ts --format esm --dts",
"start": "node dist/index.mjs",
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
"build": "tsup index.ts --format cjs --dts",
"start": "node dist/index.js",
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
},
"dependencies": {
"ai": "^3.0.21",
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "0.2.10"
"llamaindex": "0.3.13",
"pdf2json": "3.0.5",
"ajv": "^8.12.0"
},
"devDependencies": {
"@types/cors": "^2.8.16",
@@ -3,7 +3,7 @@ import { Request, Response } from "express";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
import { createCallbackManager } from "./stream-helper";
const convertMessageContent = (
textMessage: string,
@@ -45,46 +45,27 @@ export const chat = async (req: Request, res: Response) => {
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
});
// Return a stream, which can be consumed by the Vercel/AI client
const { stream } = LlamaIndexStream(response, vercelStreamData, {
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
// Pipe LlamaIndexStream to response
const processedStream = stream.pipeThrough(vercelStreamData.stream);
return streamToResponse(processedStream, res, {
headers: {
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
// so that the client uses the correct parsing logic, see
// https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server
"X-Experimental-Stream-Data": "true",
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
},
});
return streamToResponse(stream, res, {}, vercelStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
@@ -1,10 +1,17 @@
import {
Ollama,
OllamaEmbedding,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -38,15 +56,38 @@ function initOpenAI() {
}
function initOllama() {
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
/**
 * Configures Anthropic as the LLM (model name taken from MODEL) and a
 * HuggingFace embedding model selected through EMBEDDING_MODEL.
 */
function initAnthropic() {
  // Short embedding model names mapped to the model ids passed as `modelType`.
  const embedModelMap: Record<string, string> = {
    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
  };
  const model = process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS;
  Settings.llm = new Anthropic({ model });

  const modelType = embedModelMap[process.env.EMBEDDING_MODEL!];
  Settings.embedModel = new HuggingFaceEmbedding({ modelType });
}
/**
 * Configures Gemini as both the LLM and the embedding model, using the
 * model names given in MODEL and EMBEDDING_MODEL.
 */
function initGemini() {
  const model = process.env.MODEL as GEMINI_MODEL;
  Settings.llm = new Gemini({ model });

  const embeddingModel = process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL;
  Settings.embedModel = new GeminiEmbedding({ model: embeddingModel });
}
@@ -9,16 +9,22 @@ import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
ToolCallLLMMessageOptions,
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<Response>,
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
@@ -33,17 +39,27 @@ function createParser(
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
controller.close();
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
}
const text = trimStartOfStream(value.response ?? "");
const text = trimStartOfStream(delta ?? "");
if (text) {
controller.enqueue(text);
}
@@ -52,21 +68,14 @@ function createParser(
}
export function LlamaIndexStream(
response: StreamingAgentChatResponse | AsyncIterable<Response>,
response: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -1,5 +1,11 @@
import { StreamData } from "ai";
import { Metadata, NodeWithScore } from "llamaindex";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
/**
 * Appends a "tools" message annotation describing one tool invocation.
 *
 * @param data - Stream data object the annotation is appended to.
 * @param toolCall - The tool call; its id, name and input are forwarded.
 * @param toolOutput - The tool result; its output and isError flag are forwarded.
 */
export function appendToolData(
  data: StreamData,
  toolCall: ToolCall,
  toolOutput: ToolOutput,
) {
  // Forward only the selected fields, not the whole objects.
  const { id, name, input } = toolCall;
  const { output, isError } = toolOutput;
  const annotation = {
    type: "tools",
    data: {
      toolCall: { id, name, input },
      toolOutput: { output, isError },
    },
  };
  data.appendMessageAnnotation(annotation);
}
/**
 * Creates a callback manager that mirrors retrieval and tool events into the
 * given stream data.
 *
 * - "retrieve": appends the query and the number of retrieved sources as events.
 * - "llm-tool-call": appends an event naming the tool and its inputs.
 * - "llm-tool-result": appends the tool call together with its result.
 *
 * @param stream - Stream data object that receives the events/annotations.
 * @returns The configured callback manager.
 */
export function createCallbackManager(stream: StreamData) {
  const callbackManager = new CallbackManager();

  callbackManager.on("retrieve", (data) => {
    const { nodes, query } = data.detail;
    appendEventData(stream, `Retrieving context for query: '${query}'`);
    appendEventData(
      stream,
      `Retrieved ${nodes.length} sources to use as context for the query`,
    );
  });

  callbackManager.on("llm-tool-call", (event) => {
    const { name, input } = event.detail.payload.toolCall;
    // Nested objects/arrays are serialized as JSON; plain interpolation would
    // render them as "[object Object]". A missing input yields an empty list.
    const inputString = Object.entries(input ?? {})
      .map(([key, value]) => {
        const rendered =
          typeof value === "object" && value !== null
            ? JSON.stringify(value)
            : String(value);
        return `${key}: ${rendered}`;
      })
      .join(", ");
    appendEventData(
      stream,
      `Using tool: '${name}' with inputs: '${inputString}'`,
    );
  });

  callbackManager.on("llm-tool-result", (event) => {
    const { toolCall, toolResult } = event.detail.payload;
    appendToolData(stream, toolCall, toolResult);
  });

  return callbackManager;
}
@@ -1,10 +1,7 @@
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
@@ -109,12 +106,9 @@ async def chat(
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
yield VercelStreamResponse.convert_data(
{
"type": "events",
"data": {"title": event.get_title()},
}
)
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_text_generator(), _event_generator())
async with combine.stream() as streamer:
@@ -1,8 +1,9 @@
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel
@@ -11,19 +12,73 @@ class CallbackEvent(BaseModel):
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_title(self) -> str | None:
# Return as None for the unhandled event types
# to avoid showing them in the UI
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
else:
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
match self.event_type:
case "retrieve":
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
return f"Retrieved {len(nodes)} sources to use as context for the query"
else:
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
else:
return None
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
@@ -54,7 +109,7 @@ class EventCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> str:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
@@ -65,7 +120,7 @@ class EventCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> None:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
@@ -1 +0,0 @@
STORAGE_DIR = "storage" # directory to save the stores to (document store and if used, the `SimpleVectorStore`)
@@ -7,11 +7,8 @@ import logging
from llama_index.core.settings import Settings
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import SimpleVectorStore
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage import StorageContext
from llama_index.core import VectorStoreIndex
from app.constants import STORAGE_DIR
from app.settings import init_settings
from app.engine.loaders import get_documents
from app.engine.vectordb import get_vector_store
@@ -20,18 +17,21 @@ from app.engine.vectordb import get_vector_store
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
def get_doc_store():
if not os.path.exists(STORAGE_DIR):
docstore = SimpleDocumentStore()
return docstore
else:
# If the storage directory is there, load the document store from it.
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
if os.path.exists(STORAGE_DIR):
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
else:
return SimpleDocumentStore()
def run_ingestion_pipeline(docstore, vector_store, documents):
# Create ingestion pipeline
ingestion_pipeline = IngestionPipeline(
def run_pipeline(docstore, vector_store, documents):
pipeline = IngestionPipeline(
transformations=[
SentenceSplitter(
chunk_size=Settings.chunk_size,
@@ -41,32 +41,20 @@ def run_ingestion_pipeline(docstore, vector_store, documents):
],
docstore=docstore,
docstore_strategy="upserts_and_delete",
vector_store=vector_store,
)
# llama_index having an typing issue when passing vector_store to IngestionPipeline
# so we need to set it manually after initialization
ingestion_pipeline.vector_store = vector_store
# Run the ingestion pipeline and store the results
nodes = ingestion_pipeline.run(show_progress=True, documents=documents)
nodes = pipeline.run(show_progress=True, documents=documents)
return nodes
def persist_storage(docstore, vector_store, nodes):
def persist_storage(docstore, vector_store):
storage_context = StorageContext.from_defaults(
docstore=docstore,
vector_store=vector_store,
)
# SimpleVectorStore does not include index by default
# so we need to create the index manually
# can be removed if using other vector store
if isinstance(vector_store, SimpleVectorStore):
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
store_nodes_override=True, # Need enable this to store the nodes and index's id
)
storage_context.persist(STORAGE_DIR)
@@ -80,14 +68,10 @@ def generate_datasource():
vector_store = get_vector_store()
# Run the ingestion pipeline
nodes = run_ingestion_pipeline(
docstore=docstore,
vector_store=vector_store,
documents=documents,
)
_ = run_pipeline(docstore, vector_store, documents)
# Build the index and persist storage
persist_storage(docstore, vector_store, nodes)
persist_storage(docstore, vector_store)
logger.info("Finished generating the index")
@@ -1,27 +1,17 @@
import logging
from llama_index.core import load_index_from_storage
from llama_index.core.storage import StorageContext
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.vector_stores.simple import SimpleVectorStore
from app.constants import STORAGE_DIR
from llama_index.core.indices import VectorStoreIndex
from app.engine.vectordb import get_vector_store
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Loading the index...")
logger.info("Connecting vector store...")
store = get_vector_store()
# If the store is a SimpleVectorStore, we need to load the index from the storage
if isinstance(store, SimpleVectorStore):
index = load_index_from_storage(
StorageContext.from_defaults(
vector_store=store,
persist_dir=STORAGE_DIR,
)
)
else:
index = VectorStoreIndex.from_vector_store(store)
logger.info("Loaded index successfully.")
# Load the index from the vector store
# If you are using a vector store that doesn't store text,
# you must load the index from both the vector store and the document store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished load index from vector store.")
return index
@@ -9,6 +9,10 @@ def init_settings():
init_openai()
elif model_provider == "ollama":
init_ollama()
elif model_provider == "anthropic":
init_anthropic()
elif model_provider == "gemini":
init_gemini()
else:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
@@ -19,8 +23,12 @@ def init_ollama():
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
Settings.llm = Ollama(model=os.getenv("MODEL"))
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
Settings.embed_model = OllamaEmbedding(
base_url=base_url,
model_name=os.getenv("EMBEDDING_MODEL"),
)
Settings.llm = Ollama(base_url=base_url, model=os.getenv("MODEL"))
def init_openai():
@@ -42,3 +50,47 @@ def init_openai():
"dimensions": int(dimensions) if dimensions is not None else None,
}
Settings.embed_model = OpenAIEmbedding(**config)
def init_anthropic():
    """Configure Anthropic as the LLM and a HuggingFace model for embeddings.

    The MODEL and EMBEDDING_MODEL environment variables are short names that
    are translated through the lookup tables below; a value missing from a
    table raises KeyError.
    """
    from llama_index.llms.anthropic import Anthropic
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    # Short model name -> full Anthropic model identifier
    llm_names: Dict[str, str] = {
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
        "claude-2.1": "claude-2.1",
        "claude-instant-1.2": "claude-instant-1.2",
    }

    # Short embedding name -> HuggingFace repository id
    embedding_names: Dict[str, str] = {
        "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
        "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
    }

    llm_name = llm_names[os.getenv("MODEL")]
    Settings.llm = Anthropic(model=llm_name)

    embedding_name = embedding_names[os.getenv("EMBEDDING_MODEL")]
    Settings.embed_model = HuggingFaceEmbedding(model_name=embedding_name)
def init_gemini():
    """Configure Gemini as both the LLM and the embedding model.

    The MODEL and EMBEDDING_MODEL environment variables are short names that
    are translated through the lookup tables below; a value missing from a
    table raises KeyError.
    """
    from llama_index.llms.gemini import Gemini
    from llama_index.embeddings.gemini import GeminiEmbedding

    # Short model name -> "models/..." identifier
    llm_names: Dict[str, str] = {
        "gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest",
        "gemini-pro": "models/gemini-pro",
        "gemini-pro-vision": "models/gemini-pro-vision",
    }

    # Short embedding name -> "models/..." identifier
    embedding_names: Dict[str, str] = {
        "embedding-001": "models/embedding-001",
        "text-embedding-004": "models/text-embedding-004",
    }

    llm_name = llm_names[os.getenv("MODEL")]
    Settings.llm = Gemini(model=llm_name)

    embedding_name = embedding_names[os.getenv("EMBEDDING_MODEL")]
    Settings.embed_model = GeminiEmbedding(model_name=embedding_name)
+3 -1
View File
@@ -11,6 +11,7 @@ from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
app = FastAPI()
@@ -20,7 +21,6 @@ init_observability()
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
if environment == "dev":
logger = logging.getLogger("uvicorn")
logger.warning("Running in development mode - allowing CORS for all origins")
@@ -38,6 +38,8 @@ if environment == "dev":
return RedirectResponse(url="/docs")
if os.path.exists("data"):
app.mount("/api/data", StaticFiles(directory="data"), name="static")
app.include_router(chat_router, prefix="/api/chat")
@@ -16,6 +16,7 @@ python-dotenv = "^1.0.0"
aiostream = "^0.5.2"
llama-index = "0.10.28"
llama-index-core = "0.10.28"
cachetools = "^5.3.3"
[build-system]
requires = ["poetry-core"]
@@ -1,10 +1,17 @@
import {
Ollama,
OllamaEmbedding,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -38,15 +56,37 @@ function initOpenAI() {
}
/**
 * Configures Ollama as both the LLM and the embedding model.
 *
 * Requires the MODEL and EMBEDDING_MODEL env variables and throws when either
 * is unset or empty. The server address comes from OLLAMA_BASE_URL, falling
 * back to the local default when that variable is not defined.
 */
function initOllama() {
  const { MODEL, EMBEDDING_MODEL, OLLAMA_BASE_URL } = process.env;

  if (!(MODEL && EMBEDDING_MODEL)) {
    throw new Error(
      "Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
    );
  }

  // The same connection settings are shared by the LLM and the embedder.
  const config = { host: OLLAMA_BASE_URL ?? "http://127.0.0.1:11434" };

  Settings.llm = new Ollama({ model: MODEL, config });
  Settings.embedModel = new OllamaEmbedding({ model: EMBEDDING_MODEL, config });
}
/**
 * Configures Anthropic as the LLM and a HuggingFace model for embeddings.
 *
 * MODEL is passed to Anthropic as-is; EMBEDDING_MODEL is a short name that is
 * translated to a model id through the table below.
 */
function initAnthropic() {
  // Short embedding name -> HuggingFace model id
  const huggingFaceModels: Record<string, string> = {
    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
  };

  const llmModel = process.env
    .MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS;
  Settings.llm = new Anthropic({ model: llmModel });

  const modelType = huggingFaceModels[process.env.EMBEDDING_MODEL!];
  Settings.embedModel = new HuggingFaceEmbedding({ modelType });
}
/**
 * Configures Gemini as both the LLM and the embedding model, taking the
 * model names directly from the MODEL and EMBEDDING_MODEL env variables.
 */
function initGemini() {
  const llmModel = process.env.MODEL as GEMINI_MODEL;
  const embeddingModel = process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL;

  Settings.llm = new Gemini({ model: llmModel });
  Settings.embedModel = new GeminiEmbedding({ model: embeddingModel });
}
@@ -9,16 +9,22 @@ import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
ToolCallLLMMessageOptions,
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<Response>,
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
@@ -33,17 +39,27 @@ function createParser(
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
controller.close();
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
}
const text = trimStartOfStream(value.response ?? "");
const text = trimStartOfStream(delta ?? "");
if (text) {
controller.enqueue(text);
}
@@ -52,21 +68,14 @@ function createParser(
}
export function LlamaIndexStream(
response: StreamingAgentChatResponse | AsyncIterable<Response>,
response: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -5,7 +5,7 @@ import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "./engine/chat";
import { initSettings } from "./engine/settings";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
import { createCallbackManager } from "./stream-helper";
initObservability();
initSettings();
@@ -57,29 +57,21 @@ export async function POST(request: NextRequest) {
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
});
// Transform LlamaIndex stream to Vercel/AI format
const { stream } = LlamaIndexStream(response, vercelStreamData, {
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
@@ -1,5 +1,11 @@
import { StreamData } from "ai";
import { Metadata, NodeWithScore } from "llamaindex";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
/**
 * Appends a "tools" message annotation describing one tool invocation.
 *
 * Only the id, name and input of the call and the output and error flag of
 * the result are copied into the annotation.
 *
 * @param data - stream data object that receives the annotation
 * @param toolCall - the tool call that was made
 * @param toolOutput - the result produced by the tool
 */
export function appendToolData(
  data: StreamData,
  toolCall: ToolCall,
  toolOutput: ToolOutput,
) {
  const { id, name, input } = toolCall;
  const { output, isError } = toolOutput;

  data.appendMessageAnnotation({
    type: "tools",
    data: {
      toolCall: { id, name, input },
      toolOutput: { output, isError },
    },
  });
}
/**
 * Creates a CallbackManager that forwards engine activity to the given stream.
 *
 * Registered handlers:
 * - "retrieve": appends two event messages (the query, then the number of
 *   retrieved nodes); both are appended when the event fires.
 * - "llm-tool-call": appends one event message naming the tool and its inputs.
 * - "llm-tool-result": appends a "tools" annotation via appendToolData.
 *
 * @param stream - stream data object that receives the annotations
 * @returns the configured callback manager
 */
export function createCallbackManager(stream: StreamData) {
  const callbackManager = new CallbackManager();

  callbackManager.on("retrieve", (data) => {
    const { nodes, query } = data.detail;
    appendEventData(stream, `Retrieving context for query: '${query}'`);
    appendEventData(
      stream,
      `Retrieved ${nodes.length} sources to use as context for the query`,
    );
  });

  callbackManager.on("llm-tool-call", (event) => {
    const { name, input } = event.detail.payload.toolCall;
    const inputString = Object.entries(input)
      .map(([key, value]) => {
        // Plain interpolation turns objects into "[object Object]";
        // serialize structured values so the message stays readable.
        const rendered =
          typeof value === "object" && value !== null
            ? JSON.stringify(value)
            : String(value);
        return `${key}: ${rendered}`;
      })
      .join(", ");
    appendEventData(
      stream,
      `Using tool: '${name}' with inputs: '${inputString}'`,
    );
  });

  callbackManager.on("llm-tool-result", (event) => {
    const { toolCall, toolResult } = event.detail.payload;
    appendToolData(stream, toolCall, toolResult);
  });

  return callbackManager;
}
@@ -0,0 +1,38 @@
import { readFile } from "fs/promises";
import { NextRequest, NextResponse } from "next/server";
import path from "path";
/**
 * Serves a file from the local ./data folder, like a static file server.
 *
 * The `path` route parameter is the file slug, resolved relative to ./data.
 * Responses:
 * - 400 when the slug is missing, contains "..", or is an absolute path
 * - 404 when reading the file fails (the error is logged)
 * - 200 with the raw file bytes and a Content-Length header otherwise
 */
export async function GET(
  _request: NextRequest,
  { params }: { params: { path: string } },
) {
  const { path: slug } = params;

  if (!slug) {
    return NextResponse.json({ detail: "Missing file slug" }, { status: 400 });
  }

  // Reject anything that could point outside of the data folder.
  const escapesDataDir = slug.includes("..") || path.isAbsolute(slug);
  if (escapesDataDir) {
    return NextResponse.json({ detail: "Invalid file path" }, { status: 400 });
  }

  try {
    const dataDir = path.join(process.cwd(), "data");
    const contents = await readFile(path.join(dataDir, slug));
    const headers = { "Content-Length": contents.byteLength.toString() };
    return new NextResponse(contents, {
      status: 200,
      statusText: "OK",
      headers,
    });
  } catch (error) {
    console.error(error);
    return NextResponse.json({ detail: "File not found" }, { status: 404 });
  }
}
@@ -17,7 +17,8 @@ export default function ChatSection() {
headers: {
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
},
onError: (error) => {
onError: (error: unknown) => {
if (!(error instanceof Error)) throw error;
const message = JSON.parse(error.message);
alert(message.detail);
},
@@ -7,6 +7,7 @@ import ChatAvatar from "./chat-avatar";
import { ChatEvents } from "./chat-events";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import ChatTools from "./chat-tools";
import {
AnnotationData,
EventData,
@@ -14,6 +15,7 @@ import {
MessageAnnotation,
MessageAnnotationType,
SourceData,
ToolData,
} from "./index";
import Markdown from "./markdown";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
@@ -52,19 +54,27 @@ function ChatMessageContent({
annotations,
MessageAnnotationType.SOURCES,
);
const toolData = getAnnotationData<ToolData>(
annotations,
MessageAnnotationType.TOOLS,
);
const contents: ContentDisplayConfig[] = [
{
order: -2,
order: -3,
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
},
{
order: -1,
order: -2,
component:
eventData.length > 0 ? (
<ChatEvents isLoading={isLoading} data={eventData} />
) : null,
},
{
order: -1,
component: toolData[0] ? <ChatTools data={toolData[0]} /> : null,
},
{
order: 0,
component: <Markdown content={message.content} />,
@@ -40,9 +40,16 @@ export default function ChatMessages(
className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4"
ref={scrollableChatContainerRef}
>
{props.messages.map((m) => (
<ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} />
))}
{props.messages.map((m, i) => {
const isLoadingMessage = i === messageLength - 1 && props.isLoading;
return (
<ChatMessage
key={m.id}
chatMessage={m}
isLoading={isLoadingMessage}
/>
);
})}
{isPending && (
<div className="flex justify-center items-center pt-10">
<Loader2 className="h-4 w-4 animate-spin" />
@@ -1,20 +1,78 @@
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
import { Check, Copy } from "lucide-react";
import { useMemo } from "react";
import { Button } from "../button";
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
import { getStaticFileDataUrl } from "../lib/url";
import { SourceData, SourceNode } from "./index";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
import PdfDialog from "./widgets/PdfDialog";
const SCORE_THRESHOLD = 0.5;
const SCORE_THRESHOLD = 0.3;
/** Small round badge showing the 1-based position of a source. */
function SourceNumberButton({ index }: { index: number }) {
  const label = index + 1;
  const badgeClasses =
    "text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer";

  return <div className={badgeClasses}>{label}</div>;
}
/**
 * Kind of source node, decided by getNodeInfo from the node's metadata:
 * URL when metadata has a string "URL", FILE when it has a string
 * "file_path", UNKNOWN otherwise.
 */
enum NODE_TYPE {
URL,
FILE,
UNKNOWN,
}
/**
 * Display information derived from a source node.
 * `path` and `url` are only set for URL and FILE nodes.
 */
type NodeInfo = {
// id of the source node this info was derived from
id: string;
type: NODE_TYPE;
// the external URL (URL nodes) or the "file_path" metadata value (FILE nodes)
path?: string;
// link target: the external URL, or the static file URL built from "file_name"
url?: string;
};
/**
 * Derives display information for a source node from its metadata.
 *
 * - A string "URL" entry yields a URL node whose path and url are that URL.
 * - A string "file_path" entry yields a FILE node; its url is built from the
 *   "file_name" entry, falling back to the last segment of the file path
 *   when "file_name" is missing or not a string.
 * - Anything else yields an UNKNOWN node carrying only the id.
 */
function getNodeInfo(node: SourceNode): NodeInfo {
  const { id, metadata } = node;

  const url = metadata["URL"];
  if (typeof url === "string") {
    return { id, type: NODE_TYPE.URL, path: url, url };
  }

  const filePath = metadata["file_path"];
  if (typeof filePath === "string") {
    // "file_name" is not guaranteed by the check above, so validate it
    // instead of casting; otherwise the URL would be built from undefined.
    const rawName = metadata["file_name"];
    const fileName =
      typeof rawName === "string" && rawName.length > 0
        ? rawName
        : filePath.split(/[\\/]/).pop() || filePath;
    return {
      id,
      type: NODE_TYPE.FILE,
      path: filePath,
      url: getStaticFileDataUrl(fileName),
    };
  }

  return { id, type: NODE_TYPE.UNKNOWN };
}
export function ChatSources({ data }: { data: SourceData }) {
const sources = useMemo(() => {
return (
data.nodes
?.filter((node) => Object.keys(node.metadata).length > 0)
?.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) || []
);
const sources: NodeInfo[] = useMemo(() => {
// aggregate nodes by url or file_path (get the highest one by score)
const nodesByPath: { [path: string]: NodeInfo } = {};
data.nodes
.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1))
.forEach((node) => {
const nodeInfo = getNodeInfo(node);
const key = nodeInfo.path ?? nodeInfo.id; // use id as key for UNKNOWN type
if (!nodesByPath[key]) {
nodesByPath[key] = nodeInfo;
}
});
return Object.values(nodesByPath);
}, [data.nodes]);
if (sources.length === 0) return null;
@@ -23,55 +81,52 @@ export function ChatSources({ data }: { data: SourceData }) {
<div className="space-x-2 text-sm">
<span className="font-semibold">Sources:</span>
<div className="inline-flex gap-1 items-center">
{sources.map((node: SourceNode, index: number) => (
<div key={node.id}>
<HoverCard>
<HoverCardTrigger>
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
{index + 1}
</div>
</HoverCardTrigger>
<HoverCardContent>
<NodeInfo node={node} />
</HoverCardContent>
</HoverCard>
</div>
))}
{sources.map((nodeInfo: NodeInfo, index: number) => {
if (nodeInfo.path?.endsWith(".pdf")) {
return (
<PdfDialog
key={nodeInfo.id}
documentId={nodeInfo.id}
url={nodeInfo.url!}
path={nodeInfo.path}
trigger={<SourceNumberButton index={index} />}
/>
);
}
return (
<div key={nodeInfo.id}>
<HoverCard>
<HoverCardTrigger>
<SourceNumberButton index={index} />
</HoverCardTrigger>
<HoverCardContent className="w-[320px]">
<NodeInfo nodeInfo={nodeInfo} />
</HoverCardContent>
</HoverCard>
</div>
);
})}
</div>
</div>
);
}
function NodeInfo({ node }: { node: SourceNode }) {
function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
if (typeof node.metadata["URL"] === "string") {
// this is a node generated by the web loader, it contains an external URL
// add a link to view this URL
if (nodeInfo.type !== NODE_TYPE.UNKNOWN) {
// this is a node generated by the web loader or file loader,
// add a link to view its URL and a button to copy the URL to the clipboard
return (
<a
className="space-x-2 flex items-center my-2 hover:text-blue-900"
href={node.metadata["URL"]}
target="_blank"
>
<span>{node.metadata["URL"]}</span>
<ArrowUpRightSquare className="w-4 h-4" />
</a>
);
}
if (typeof node.metadata["file_path"] === "string") {
// this is a node generated by the file loader, it contains file path
// add a button to copy the path to the clipboard
const filePath = node.metadata["file_path"];
return (
<div className="flex items-center px-2 py-1 justify-between my-2">
<span>{filePath}</span>
<div className="flex items-center my-2">
<a className="hover:text-blue-900" href={nodeInfo.url} target="_blank">
<span>{nodeInfo.path}</span>
</a>
<Button
onClick={() => copyToClipboard(filePath)}
onClick={() => copyToClipboard(nodeInfo.path!)}
size="icon"
variant="ghost"
className="h-12 w-12"
className="h-12 w-12 shrink-0"
>
{isCopied ? (
<Check className="h-4 w-4" />
@@ -84,7 +139,6 @@ function NodeInfo({ node }: { node: SourceNode }) {
}
// node generated by unknown loader, implement renderer by analyzing logged out metadata
console.log("Node metadata", node.metadata);
return (
<p>
Sorry, unknown node type. Please add a new renderer in the NodeInfo
@@ -0,0 +1,26 @@
import { ToolData } from "./index";
import { WeatherCard, WeatherData } from "./widgets/WeatherCard";
/**
 * Renders the output of a tool call.
 *
 * Shows an error notice when the tool reported an error, a WeatherCard for
 * the "get_weather_information" tool, and nothing for any other tool.
 */
// TODO: If needed, render the outputs of more tools here
export default function ChatTools({ data }: { data: ToolData }) {
  if (!data) return null;

  const { toolCall, toolOutput } = data;
  const toolName = toolCall.name;

  if (toolOutput.isError) {
    return (
      <div className="border-l-2 border-red-400 pl-2">
        There was an error when calling the tool {toolName} with input:{" "}
        <br />
        {JSON.stringify(toolCall.input)}
      </div>
    );
  }

  if (toolName === "get_weather_information") {
    return <WeatherCard data={toolOutput.output as unknown as WeatherData} />;
  }

  return null;
}
@@ -1,3 +1,4 @@
import { JSONValue } from "ai";
import ChatInput from "./chat-input";
import ChatMessages from "./chat-messages";
@@ -8,6 +9,7 @@ export enum MessageAnnotationType {
IMAGE = "image",
SOURCES = "sources",
EVENTS = "events",
TOOLS = "tools",
}
export type ImageData = {
@@ -30,7 +32,21 @@ export type EventData = {
isCollapsed: boolean;
};
export type AnnotationData = ImageData | SourceData | EventData;
/**
 * Payload of a "tools" message annotation: one tool call and its result.
 */
export type ToolData = {
// The invocation: which tool was called and with which arguments.
toolCall: {
// NOTE(review): the Python backend's get_agent_tool_response sends `"id": None`,
// so this may arrive as null at runtime — confirm before relying on it.
id: string;
name: string;
input: {
[key: string]: JSONValue;
};
};
// The result returned by the tool; `isError` marks a failed call.
toolOutput: {
output: JSONValue;
isError: boolean;
};
};
export type AnnotationData = ImageData | SourceData | EventData | ToolData;
export type MessageAnnotation = {
type: MessageAnnotationType;
@@ -1,5 +1,7 @@
import "katex/dist/katex.min.css";
import { FC, memo } from "react";
import ReactMarkdown, { Options } from "react-markdown";
import rehypeKatex from "rehype-katex";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
@@ -12,11 +14,27 @@ const MemoizedReactMarkdown: FC<Options> = memo(
prevProps.className === nextProps.className,
);
/**
 * Rewrites backslash-style LaTeX delimiters into dollar-sign delimiters:
 * block equations `\[ ... \]` become `$$ ... $$` and inline equations
 * `\( ... \)` become `$ ... $`. Equations may span multiple lines.
 */
const preprocessLaTeX = (content: string) => {
  // Builds a replacer that surrounds the captured equation with `delimiter`.
  const wrapWith =
    (delimiter: string) => (_match: string, equation: string) =>
      delimiter + equation + delimiter;

  return content
    .replace(/\\\[(.*?)\\\]/gs, wrapWith("$$"))
    .replace(/\\\((.*?)\\\)/gs, wrapWith("$"));
};
export default function Markdown({ content }: { content: string }) {
const processedContent = preprocessLaTeX(content);
return (
<MemoizedReactMarkdown
className="prose dark:prose-invert prose-p:leading-relaxed prose-pre:p-0 break-words custom-markdown"
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeKatex as any]}
components={{
p({ children }) {
return <p className="mb-2 last:mb-0">{children}</p>;
@@ -53,7 +71,7 @@ export default function Markdown({ content }: { content: string }) {
},
}}
>
{content}
{processedContent}
</MemoizedReactMarkdown>
);
}
@@ -0,0 +1,56 @@
import { PDFViewer, PdfFocusProvider } from "@llamaindex/pdf-viewer";
import { Button } from "../../button";
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "../../drawer";
/** Props for the PdfDialog drawer. */
export interface PdfDialogProps {
// passed to the PDF viewer as the file id
documentId: string;
// shown as the link text in the drawer header
path: string;
// location the PDF is loaded from; also the header link target
url: string;
// element rendered inside the drawer trigger
trigger: React.ReactNode;
}
/**
 * Drawer that opens from the left and displays a PDF document.
 *
 * The header shows the file path as a link to the PDF url plus a close
 * button; the body renders the document with the PDF viewer.
 */
export default function PdfDialog(props: PdfDialogProps) {
  const { documentId, path, url, trigger } = props;

  return (
    <Drawer direction="left">
      <DrawerTrigger>{trigger}</DrawerTrigger>
      <DrawerContent className="w-3/5 mt-24 h-full max-h-[96%] ">
        <DrawerHeader className="flex justify-between">
          <div className="space-y-2">
            <DrawerTitle>PDF Content</DrawerTitle>
            <DrawerDescription>
              File path:{" "}
              {/* rel keeps the opened tab from reaching back via window.opener */}
              <a
                className="hover:text-blue-900"
                href={url}
                target="_blank"
                rel="noopener noreferrer"
              >
                {path}
              </a>
            </DrawerDescription>
          </div>
          <DrawerClose asChild>
            <Button variant="outline">Close</Button>
          </DrawerClose>
        </DrawerHeader>
        <div className="m-4">
          <PdfFocusProvider>
            <PDFViewer
              file={{
                id: documentId,
                url,
              }}
            />
          </PdfFocusProvider>
        </div>
      </DrawerContent>
    </Drawer>
  );
}
@@ -0,0 +1,213 @@
/**
 * Weather forecast payload rendered by WeatherCard.
 * NOTE(review): the field names look like a weather API forecast response —
 * confirm against the tool that produces this data.
 */
export interface WeatherData {
latitude: number;
longitude: number;
generationtime_ms: number;
utc_offset_seconds: number;
timezone: string;
timezone_abbreviation: string;
elevation: number;
// Unit labels for the values in `current`
current_units: {
time: string;
interval: string;
temperature_2m: string;
weather_code: string;
};
// Current conditions; `weather_code` is looked up in weatherCodeDisplayMap
current: {
time: string;
interval: number;
temperature_2m: number;
weather_code: number;
};
// Unit labels for the values in `hourly`
hourly_units: {
time: string;
temperature_2m: string;
weather_code: string;
};
hourly: {
time: string[];
temperature_2m: number[];
weather_code: number[];
};
// Unit labels for the values in `daily`
daily_units: {
time: string;
weather_code: string;
};
// Parallel arrays: WeatherCard reads weather_code[i] for the day at time[i]
// and skips index 0 as the current day
daily: {
time: string[];
weather_code: number[];
};
}
// Maps WMO weather interpretation codes (WW) to a display icon and a status
// label. Keys are the numeric codes as strings. Every entry carries a visible
// icon so that no condition renders as an empty span.
const weatherCodeDisplayMap: Record<
  string,
  {
    icon: JSX.Element;
    status: string;
  }
> = {
  // Clear and cloudy
  "0": { icon: <span>☀️</span>, status: "Clear sky" },
  "1": { icon: <span>🌤</span>, status: "Mainly clear" },
  "2": { icon: <span>⛅</span>, status: "Partly cloudy" },
  "3": { icon: <span>☁️</span>, status: "Overcast" },
  // Fog
  "45": { icon: <span>🌫</span>, status: "Fog" },
  "48": { icon: <span>🌫</span>, status: "Depositing rime fog" },
  // Drizzle
  "51": { icon: <span>🌧</span>, status: "Drizzle" },
  "53": { icon: <span>🌧</span>, status: "Drizzle" },
  "55": { icon: <span>🌧</span>, status: "Drizzle" },
  "56": { icon: <span>🌧</span>, status: "Freezing Drizzle" },
  "57": { icon: <span>🌧</span>, status: "Freezing Drizzle" },
  // Rain
  "61": { icon: <span>🌧</span>, status: "Rain" },
  "63": { icon: <span>🌧</span>, status: "Rain" },
  "65": { icon: <span>🌧</span>, status: "Rain" },
  "66": { icon: <span>🌧</span>, status: "Freezing Rain" },
  "67": { icon: <span>🌧</span>, status: "Freezing Rain" },
  // Snow
  "71": { icon: <span>❄️</span>, status: "Snow fall" },
  "73": { icon: <span>❄️</span>, status: "Snow fall" },
  "75": { icon: <span>❄️</span>, status: "Snow fall" },
  "77": { icon: <span>❄️</span>, status: "Snow grains" },
  // Showers
  "80": { icon: <span>🌧</span>, status: "Rain showers" },
  "81": { icon: <span>🌧</span>, status: "Rain showers" },
  "82": { icon: <span>🌧</span>, status: "Rain showers" },
  "85": { icon: <span>❄️</span>, status: "Snow showers" },
  "86": { icon: <span>❄️</span>, status: "Snow showers" },
  // Thunderstorm
  "95": { icon: <span>⛈️</span>, status: "Thunderstorm" },
  "96": { icon: <span>⛈️</span>, status: "Thunderstorm" },
  "99": { icon: <span>⛈️</span>, status: "Thunderstorm" },
};
/**
 * Formats a date string as its full English weekday name (e.g. "Wednesday").
 *
 * Date-only ISO strings ("YYYY-MM-DD") are parsed by `Date` as UTC midnight.
 * Formatting that instant in the viewer's local time zone would report the
 * previous day in zones behind UTC, so such strings are formatted in UTC to
 * keep the calendar day that the string names. Strings carrying a time
 * component keep the original local-time formatting.
 *
 * @param time - A date string accepted by the `Date` constructor.
 * @returns The weekday name in the "en-US" locale.
 */
const displayDay = (time: string) => {
  const isDateOnly = /^\d{4}-\d{2}-\d{2}$/.test(time);
  return new Date(time).toLocaleDateString("en-US", {
    weekday: "long",
    // Only pin the zone for date-only input; see the doc comment above.
    ...(isDateOnly ? { timeZone: "UTC" } : {}),
  });
};
/**
 * Card showing the current weather and a row of upcoming days.
 *
 * The header shows the current date, temperature (with its unit) and the
 * icon/status looked up from `weatherCodeDisplayMap`. The grid below renders
 * one column per entry of `data.daily.time`, skipping index 0 (the current
 * day).
 *
 * Weather codes without an entry in `weatherCodeDisplayMap` render an empty
 * icon/status instead of throwing while rendering.
 *
 * @param data - Weather payload providing `current`, `current_units` and `daily`.
 */
export function WeatherCard({ data }: { data: WeatherData }) {
  const currentDayString = new Date(data.current.time).toLocaleDateString(
    "en-US",
    {
      weekday: "long",
      month: "long",
      day: "numeric",
    },
  );
  // May be undefined at runtime when the code is not in the map.
  const currentDisplay = weatherCodeDisplayMap[data.current.weather_code];

  return (
    <div className="bg-[#61B9F2] rounded-2xl shadow-xl p-5 space-y-4 text-white w-fit">
      <div className="flex justify-between">
        <div className="space-y-2">
          <div className="text-xl">{currentDayString}</div>
          <div className="text-5xl font-semibold flex gap-4">
            <span>
              {data.current.temperature_2m} {data.current_units.temperature_2m}
            </span>
            {currentDisplay?.icon}
          </div>
        </div>
        <span className="text-xl">{currentDisplay?.status}</span>
      </div>
      <div className="gap-2 grid grid-cols-6">
        {data.daily.time.map((time, index) => {
          if (index === 0) return null; // skip the current day
          // Guards both unknown codes and a weather_code array shorter than time.
          const dailyDisplay =
            weatherCodeDisplayMap[data.daily.weather_code[index]];
          return (
            <div key={time} className="flex flex-col items-center gap-4">
              <span>{displayDay(time)}</span>
              <div className="text-4xl">{dailyDisplay?.icon}</div>
              <span className="text-sm">{dailyDisplay?.status}</span>
            </div>
          );
        })}
      </div>
    </div>
  );
}
@@ -0,0 +1,118 @@
"use client";
import * as React from "react";
import { Drawer as DrawerPrimitive } from "vaul";
import { cn } from "./lib/utils";
/**
 * Root drawer component: forwards every prop to `DrawerPrimitive.Root`,
 * defaulting `shouldScaleBackground` to `true` when the caller omits it.
 */
const Drawer = (props: React.ComponentProps<typeof DrawerPrimitive.Root>) => {
  const { shouldScaleBackground = true, ...rootProps } = props;
  return (
    <DrawerPrimitive.Root
      shouldScaleBackground={shouldScaleBackground}
      {...rootProps}
    />
  );
};
Drawer.displayName = "Drawer";
// Direct re-exports of the primitive's parts under Drawer-prefixed names.
const {
  Trigger: DrawerTrigger,
  Portal: DrawerPortal,
  Close: DrawerClose,
} = DrawerPrimitive;
/**
 * Overlay behind the drawer. Merges the default overlay classes with any
 * caller-supplied `className` and forwards the ref and remaining props to
 * `DrawerPrimitive.Overlay`.
 */
const DrawerOverlay = React.forwardRef<
  React.ElementRef<typeof DrawerPrimitive.Overlay>,
  React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Overlay>
>((props, ref) => {
  const { className, ...overlayProps } = props;
  const overlayClassName = cn("fixed inset-0 z-50 bg-black/80", className);
  return (
    <DrawerPrimitive.Overlay
      ref={ref}
      className={overlayClassName}
      {...overlayProps}
    />
  );
});
DrawerOverlay.displayName = DrawerPrimitive.Overlay.displayName;
/**
 * Drawer body. Renders inside `DrawerPortal` together with `DrawerOverlay`,
 * and places a small rounded bar above the caller's children. The ref and
 * remaining props are forwarded to `DrawerPrimitive.Content`.
 */
const DrawerContent = React.forwardRef<
  React.ElementRef<typeof DrawerPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Content>
>((props, ref) => {
  const { className, children, ...contentProps } = props;
  const contentClassName = cn(
    "fixed inset-x-0 bottom-0 z-50 mt-24 flex h-auto flex-col rounded-t-[10px] border bg-background",
    className,
  );
  return (
    <DrawerPortal>
      <DrawerOverlay />
      <DrawerPrimitive.Content
        ref={ref}
        className={contentClassName}
        {...contentProps}
      >
        <div className="mx-auto mt-4 h-2 w-[100px] rounded-full bg-muted" />
        {children}
      </DrawerPrimitive.Content>
    </DrawerPortal>
  );
});
DrawerContent.displayName = "DrawerContent";
/**
 * Header section of the drawer: a plain `div` with the default header
 * classes merged with the caller's `className`; other props pass through.
 */
const DrawerHeader = (props: React.HTMLAttributes<HTMLDivElement>) => {
  const { className, ...divProps } = props;
  const headerClassName = cn(
    "grid gap-1.5 p-4 text-center sm:text-left",
    className,
  );
  return <div className={headerClassName} {...divProps} />;
};
DrawerHeader.displayName = "DrawerHeader";
/**
 * Footer section of the drawer: a plain `div` with the default footer
 * classes merged with the caller's `className`; other props pass through.
 */
const DrawerFooter = (props: React.HTMLAttributes<HTMLDivElement>) => {
  const { className, ...divProps } = props;
  const footerClassName = cn("mt-auto flex flex-col gap-2 p-4", className);
  return <div className={footerClassName} {...divProps} />;
};
DrawerFooter.displayName = "DrawerFooter";
/**
 * Drawer title. Applies the default title typography classes, merged with
 * the caller's `className`, and forwards the ref and remaining props to
 * `DrawerPrimitive.Title`.
 */
const DrawerTitle = React.forwardRef<
  React.ElementRef<typeof DrawerPrimitive.Title>,
  React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Title>
>((props, ref) => {
  const { className, ...titleProps } = props;
  const titleClassName = cn(
    "text-lg font-semibold leading-none tracking-tight",
    className,
  );
  return (
    <DrawerPrimitive.Title
      ref={ref}
      className={titleClassName}
      {...titleProps}
    />
  );
});
DrawerTitle.displayName = DrawerPrimitive.Title.displayName;
/**
 * Drawer description text. Applies the default small, muted text classes,
 * merged with the caller's `className`, and forwards the ref and remaining
 * props to `DrawerPrimitive.Description`.
 */
const DrawerDescription = React.forwardRef<
  React.ElementRef<typeof DrawerPrimitive.Description>,
  React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Description>
>((props, ref) => {
  const { className, ...descriptionProps } = props;
  const descriptionClassName = cn("text-sm text-muted-foreground", className);
  return (
    <DrawerPrimitive.Description
      ref={ref}
      className={descriptionClassName}
      {...descriptionProps}
    />
  );
});
DrawerDescription.displayName = DrawerPrimitive.Description.displayName;
// Public surface of this module: the root drawer plus each of its parts.
export {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerFooter,
DrawerHeader,
DrawerOverlay,
DrawerPortal,
DrawerTitle,
DrawerTrigger,
};
@@ -0,0 +1,11 @@
// Folder segment under `/api/` from which stored files are served.
const STORAGE_FOLDER = "data";

/**
 * Builds the URL under which a stored file can be fetched.
 *
 * When `NEXT_PUBLIC_CHAT_API` is set, the file is served by that backend, so
 * the path is prefixed with the origin (scheme + host + port) of that URL.
 * Otherwise a path relative to the current site is returned.
 *
 * @param filename - Name of the file inside the storage folder; it is
 *   inserted as-is, without URL encoding.
 * @returns An absolute URL on the backend origin, or a root-relative path.
 * @throws TypeError if `NEXT_PUBLIC_CHAT_API` is set but is not a valid URL.
 */
export const getStaticFileDataUrl = (filename: string) => {
  const chatApi = process.env.NEXT_PUBLIC_CHAT_API;
  const fileUrl = `/api/${STORAGE_FOLDER}/${filename}`;
  if (!chatApi) {
    return fileUrl;
  }
  const backendOrigin = new URL(chatApi).origin;
  // fileUrl already starts with "/", so join without adding another slash.
  return `${backendOrigin}${fileUrl}`;
};
@@ -14,12 +14,14 @@
"@radix-ui/react-hover-card": "^1.0.7",
"@radix-ui/react-slot": "^1.0.2",
"ai": "^3.0.21",
"ajv": "^8.12.0",
"class-variance-authority": "^0.7.0",
"clsx": "^1.2.1",
"clsx": "^2.1.1",
"dotenv": "^16.3.1",
"llamaindex": "0.2.10",
"llamaindex": "0.3.13",
"lucide-react": "^0.294.0",
"next": "^14.0.3",
"pdf2json": "3.0.5",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.7",
@@ -28,8 +30,11 @@
"remark-code-import": "^1.2.0",
"remark-gfm": "^3.0.1",
"remark-math": "^5.1.1",
"rehype-katex": "^7.0.0",
"supports-color": "^8.1.1",
"tailwind-merge": "^2.1.0"
"tailwind-merge": "^2.1.0",
"vaul": "^0.9.1",
"@llamaindex/pdf-viewer": "^1.1.1"
},
"devDependencies": {
"@types/node": "^20.10.3",
+4 -2
View File
@@ -11,14 +11,16 @@
"forceConsistentCasingInFileNames": true,
"incremental": true,
"outDir": "./lib",
"tsBuildInfoFile": "./lib/.tsbuildinfo"
"tsBuildInfoFile": "./lib/.tsbuildinfo",
"typeRoots": ["./types", "./node_modules/@types"]
},
"include": [
"create-app.ts",
"index.ts",
"./helpers",
"questions.ts",
"package.json"
"package.json",
"types/**/*"
],
"exclude": ["dist"]
}
+1
View File
@@ -0,0 +1 @@
// Shorthand ambient declaration so imports of "global-agent/bootstrap"
// type-check; with no body, everything imported from it is typed as `any`.
declare module "global-agent/bootstrap";