mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-02 19:14:28 -04:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 05748bdf10 | |||
| d60b3c5a96 | |||
| c3e9ed3df4 | |||
| 1fde1dc585 | |||
| cd50a33d43 | |||
| ed114856d9 | |||
| 69c2e16c82 | |||
| f5da6623cf | |||
| 0950cb90f2 | |||
| bb53425b4b | |||
| bbd5b8ddd6 | |||
| 260d37a3f1 | |||
| 7873bfb030 | |||
| 0c7c41ee3b | |||
| 56537a1473 | |||
| d8dfc29edd | |||
| 84db798353 | |||
| 67a062af14 | |||
| 0bc8e75c64 | |||
| 6bd5e7b77a | |||
| 38bc1d1350 | |||
| cb1001de95 | |||
| 78776ac51e | |||
| 416073db1d | |||
| 84929de8b2 | |||
| 6fe240b854 | |||
| 8bb1024d0f | |||
| 988bfc2a60 | |||
| 056e376ee0 | |||
| 819cccb11a | |||
| 8a5ece10c2 |
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Add support E2B code interpreter tool for FastAPI
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
|
||||
@@ -1,5 +1,52 @@
|
||||
# create-llama
|
||||
|
||||
## 0.1.8
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- cd50a33: Add interpreter tool for TS using e2b.dev
|
||||
|
||||
## 0.1.7
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 260d37a: Add system prompt env variable for TS
|
||||
- bbd5b8d: Fix postgres connection leaking issue
|
||||
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
|
||||
- 69c2e16: Fix streaming for Express
|
||||
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
|
||||
|
||||
## 0.1.6
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 56537a1: Display PDF files in source nodes
|
||||
|
||||
## 0.1.5
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 84db798: feat: support display latex in chat markdown
|
||||
|
||||
## 0.1.4
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 0bc8e75: Use ingestion pipeline for dedicated vector stores (Python only)
|
||||
- cb1001d: Add ChromaDB vector store
|
||||
|
||||
## 0.1.3
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 416073d: Directly import vector stores to work with NextJS
|
||||
|
||||
## 0.1.2
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 056e376: Add support for displaying tool outputs (including weather widget as example)
|
||||
|
||||
## 0.1.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
+119
-33
@@ -1,5 +1,6 @@
|
||||
import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools";
|
||||
import {
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
@@ -7,7 +8,7 @@ import {
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
|
||||
type EnvVar = {
|
||||
export type EnvVar = {
|
||||
name?: string;
|
||||
description?: string;
|
||||
value?: string;
|
||||
@@ -29,17 +30,20 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
|
||||
);
|
||||
};
|
||||
|
||||
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
if (!vectorDb) {
|
||||
const getVectorDBEnvs = (
|
||||
vectorDb?: TemplateVectorDB,
|
||||
framework?: TemplateFramework,
|
||||
): EnvVar[] => {
|
||||
if (!vectorDb || !framework) {
|
||||
return [];
|
||||
}
|
||||
switch (vectorDb) {
|
||||
case "mongo":
|
||||
return [
|
||||
{
|
||||
name: "MONGO_URI",
|
||||
name: "MONGODB_URI",
|
||||
description:
|
||||
"For generating a connection URI, see https://docs.timescale.com/use-timescale/latest/services/create-a-service\nThe MongoDB connection URI.",
|
||||
"For generating a connection URI, see https://www.mongodb.com/docs/manual/reference/connection-string/ \nThe MongoDB connection URI.",
|
||||
},
|
||||
{
|
||||
name: "MONGODB_DATABASE",
|
||||
@@ -129,6 +133,31 @@ const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
"Optional API key for authenticating requests to Qdrant.",
|
||||
},
|
||||
];
|
||||
case "chroma":
|
||||
const envs = [
|
||||
{
|
||||
name: "CHROMA_COLLECTION",
|
||||
description: "The name of the collection in your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_HOST",
|
||||
description: "The API endpoint for your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_PORT",
|
||||
description: "The port for your Chroma database",
|
||||
},
|
||||
];
|
||||
// TS Version doesn't support config local storage path
|
||||
if (framework === "fastapi") {
|
||||
envs.push({
|
||||
name: "CHROMA_PATH",
|
||||
description: `The local path to the Chroma database.
|
||||
Specify this if you are using a local Chroma database.
|
||||
Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
|
||||
});
|
||||
}
|
||||
return envs;
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
@@ -191,41 +220,52 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "ollama"
|
||||
? [
|
||||
{
|
||||
name: "OLLAMA_BASE_URL",
|
||||
description:
|
||||
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
};
|
||||
|
||||
const getFrameworkEnvs = (
|
||||
framework?: TemplateFramework,
|
||||
framework: TemplateFramework,
|
||||
port?: number,
|
||||
): EnvVar[] => {
|
||||
if (framework !== "fastapi") {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
const sPort = port?.toString() || "8000";
|
||||
const result: EnvVar[] = [
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: port?.toString() || "8000",
|
||||
},
|
||||
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: `Custom system prompt.
|
||||
Example:
|
||||
SYSTEM_PROMPT="
|
||||
We have provided context information below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given this information, please answer the question: {query_str}
|
||||
"`,
|
||||
name: "FILESERVER_URL_PREFIX",
|
||||
description:
|
||||
"FILESERVER_URL_PREFIX is the URL prefix of the server storing the images generated by the interpreter.",
|
||||
value:
|
||||
framework === "nextjs"
|
||||
? // FIXME: if we are using nextjs, port should be 3000
|
||||
"http://localhost:3000/api/files"
|
||||
: `http://localhost:${sPort}/api/files`,
|
||||
},
|
||||
];
|
||||
if (framework === "fastapi") {
|
||||
result.push(
|
||||
...[
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: sPort,
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
const getEngineEnvs = (): EnvVar[] => {
|
||||
@@ -239,15 +279,59 @@ const getEngineEnvs = (): EnvVar[] => {
|
||||
];
|
||||
};
|
||||
|
||||
const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
|
||||
if (!tools?.length) return [];
|
||||
const toolEnvs: EnvVar[] = [];
|
||||
tools.forEach((tool) => {
|
||||
if (tool.envVars?.length) {
|
||||
toolEnvs.push(
|
||||
// Don't include the system prompt env var here
|
||||
// It should be handled separately by merging with the default system prompt
|
||||
...tool.envVars.filter(
|
||||
(env) => env.name !== TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
return toolEnvs;
|
||||
};
|
||||
|
||||
const getSystemPromptEnv = (tools?: Tool[]): EnvVar => {
|
||||
const defaultSystemPrompt =
|
||||
"You are a helpful assistant who helps users with their questions.";
|
||||
|
||||
// build tool system prompt by merging all tool system prompts
|
||||
let toolSystemPrompt = "";
|
||||
tools?.forEach((tool) => {
|
||||
const toolSystemPromptEnv = tool.envVars?.find(
|
||||
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
);
|
||||
if (toolSystemPromptEnv) {
|
||||
toolSystemPrompt += toolSystemPromptEnv.value + "\n";
|
||||
}
|
||||
});
|
||||
|
||||
const systemPrompt = toolSystemPrompt
|
||||
? `\"${toolSystemPrompt}\"`
|
||||
: defaultSystemPrompt;
|
||||
|
||||
return {
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: "The system prompt for the AI model.",
|
||||
value: systemPrompt,
|
||||
};
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
llamaCloudKey?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
modelConfig: ModelConfig;
|
||||
framework?: TemplateFramework;
|
||||
framework: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
port?: number;
|
||||
tools?: Tool[];
|
||||
},
|
||||
) => {
|
||||
// Init env values
|
||||
@@ -263,8 +347,10 @@ export const createBackendEnvFile = async (
|
||||
// Add engine environment variables
|
||||
...getEngineEnvs(),
|
||||
// Add vector database environment variables
|
||||
...getVectorDBEnvs(opts.vectorDb),
|
||||
...getVectorDBEnvs(opts.vectorDb, opts.framework),
|
||||
...getFrameworkEnvs(opts.framework, opts.port),
|
||||
...getToolEnvs(opts.tools),
|
||||
getSystemPromptEnv(opts.tools),
|
||||
];
|
||||
// Render and write env file
|
||||
const content = renderEnvVar(envVars);
|
||||
|
||||
@@ -148,6 +148,7 @@ export const installTemplate = async (
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
tools: props.tools,
|
||||
});
|
||||
|
||||
if (props.dataSources.length > 0) {
|
||||
@@ -170,6 +171,11 @@ export const installTemplate = async (
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool-output directory
|
||||
if (props.tools && props.tools.length > 0) {
|
||||
await fsExtra.mkdir(path.join(props.root, "tool-output"));
|
||||
}
|
||||
} else {
|
||||
// this is a frontend for a full-stack app, create .env file with model information
|
||||
await createFrontendEnvFile(props.root, {
|
||||
|
||||
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4-turbo";
|
||||
const DEFAULT_MODEL = "gpt-3.5-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askOpenAIQuestions({
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
/* Function to conditionally load the global-agent/bootstrap module */
|
||||
export async function initializeGlobalAgent() {
|
||||
if (process.env.GLOBAL_AGENT_HTTP_PROXY) {
|
||||
/* Dynamically import global-agent/bootstrap */
|
||||
await import("global-agent/bootstrap");
|
||||
console.log("Proxy enabled via global-agent.");
|
||||
}
|
||||
}
|
||||
+58
-33
@@ -24,7 +24,7 @@ interface Dependency {
|
||||
const getAdditionalDependencies = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
dataSource?: TemplateDataSource,
|
||||
dataSources?: TemplateDataSource[],
|
||||
tools?: Tool[],
|
||||
) => {
|
||||
const dependencies: Dependency[] = [];
|
||||
@@ -43,6 +43,7 @@ const getAdditionalDependencies = (
|
||||
name: "llama-index-vector-stores-postgres",
|
||||
version: "^0.1.1",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "pinecone": {
|
||||
dependencies.push({
|
||||
@@ -69,41 +70,60 @@ const getAdditionalDependencies = (
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "qdrant": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-qdrant",
|
||||
version: "^0.2.8",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "chroma": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-chroma",
|
||||
version: "^0.1.8",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add data source dependencies
|
||||
const dataSourceType = dataSource?.type;
|
||||
switch (dataSourceType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
if (dataSources) {
|
||||
for (const ds of dataSources) {
|
||||
const dsType = ds?.type;
|
||||
switch (dsType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tools dependencies
|
||||
console.log("Adding tools dependencies");
|
||||
tools?.forEach((tool) => {
|
||||
tool.dependencies?.forEach((dep) => {
|
||||
dependencies.push(dep);
|
||||
@@ -298,9 +318,14 @@ export const installPythonTemplate = async ({
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
});
|
||||
|
||||
const addOnDependencies = dataSources
|
||||
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
|
||||
.flat();
|
||||
console.log("Adding additional dependencies");
|
||||
|
||||
const addOnDependencies = getAdditionalDependencies(
|
||||
modelConfig,
|
||||
vectorDb,
|
||||
dataSources,
|
||||
tools,
|
||||
);
|
||||
|
||||
if (observability === "opentelemetry") {
|
||||
addOnDependencies.push({
|
||||
|
||||
+82
-2
@@ -2,15 +2,25 @@ import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { red } from "picocolors";
|
||||
import yaml from "yaml";
|
||||
import { EnvVar } from "./env-variables";
|
||||
import { makeDir } from "./make-dir";
|
||||
import { TemplateFramework } from "./types";
|
||||
|
||||
export const TOOL_SYSTEM_PROMPT_ENV_VAR = "TOOL_SYSTEM_PROMPT";
|
||||
|
||||
export enum ToolType {
|
||||
LLAMAHUB = "llamahub",
|
||||
LOCAL = "local",
|
||||
}
|
||||
|
||||
export type Tool = {
|
||||
display: string;
|
||||
name: string;
|
||||
config?: Record<string, any>;
|
||||
dependencies?: ToolDependencies[];
|
||||
supportedFrameworks?: Array<TemplateFramework>;
|
||||
type: ToolType;
|
||||
envVars?: EnvVar[];
|
||||
};
|
||||
|
||||
export type ToolDependencies = {
|
||||
@@ -35,6 +45,14 @@ export const supportedTools: Tool[] = [
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for google search tool.",
|
||||
value: `You are a Google search agent. You help users to get information from Google search.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Wikipedia",
|
||||
@@ -46,6 +64,58 @@ export const supportedTools: Tool[] = [
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for wiki tool.",
|
||||
value: `You are a Wikipedia agent. You help users to get information from Wikipedia.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Weather",
|
||||
name: "weather",
|
||||
dependencies: [],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for weather tool.",
|
||||
value: `You are a weather forecast agent. You help users to get the weather forecast for a given location.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Code Interpreter",
|
||||
name: "interpreter",
|
||||
dependencies: [
|
||||
{
|
||||
name: "e2b_code_interpreter",
|
||||
version: "0.0.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "E2B_API_KEY",
|
||||
description:
|
||||
"E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for code interpreter tool.",
|
||||
value: `You are a Python interpreter.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- The python code runs in a Jupyter notebook. Every time you call \`interpreter\` tool, the python code is executed in a separate cell. It's okay to make multiple calls to \`interpreter\`.
|
||||
- Display visualizations using matplotlib or any other visualization library directly in the notebook. Shouldn't save the visualizations to a file, just return the base64 encoded data.
|
||||
- You can install any pip package (if it exists) if you need to but the usual packages for data analysis are already preinstalled.
|
||||
- You can run any python code you want in a secure environment.
|
||||
- Use absolute url from result to display images or any other media.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -90,9 +160,19 @@ export const writeToolsConfig = async (
|
||||
type: ConfigFileType = ConfigFileType.YAML,
|
||||
) => {
|
||||
if (tools.length === 0) return; // no tools selected, no config need
|
||||
const configContent: Record<string, any> = {};
|
||||
const configContent: {
|
||||
[key in ToolType]: Record<string, any>;
|
||||
} = {
|
||||
local: {},
|
||||
llamahub: {},
|
||||
};
|
||||
tools.forEach((tool) => {
|
||||
configContent[tool.name] = tool.config ?? {};
|
||||
if (tool.type === ToolType.LLAMAHUB) {
|
||||
configContent.llamahub[tool.name] = tool.config ?? {};
|
||||
}
|
||||
if (tool.type === ToolType.LOCAL) {
|
||||
configContent.local[tool.name] = tool.config ?? {};
|
||||
}
|
||||
});
|
||||
const configPath = path.join(root, "config");
|
||||
await makeDir(configPath);
|
||||
|
||||
+2
-1
@@ -20,7 +20,8 @@ export type TemplateVectorDB =
|
||||
| "pinecone"
|
||||
| "milvus"
|
||||
| "astra"
|
||||
| "qdrant";
|
||||
| "qdrant"
|
||||
| "chroma";
|
||||
export type TemplatePostInstallAction =
|
||||
| "none"
|
||||
| "VSCode"
|
||||
|
||||
@@ -105,7 +105,7 @@ export const installTSTemplate = async ({
|
||||
const enginePath = path.join(root, relativeEngineDestPath, "engine");
|
||||
|
||||
// copy vector db component
|
||||
console.log("\nUsing vector DB:", vectorDb, "\n");
|
||||
console.log("\nUsing vector DB:", vectorDb ?? "none", "\n");
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
|
||||
|
||||
@@ -12,12 +12,16 @@ import { createApp } from "./create-app";
|
||||
import { getDataSources } from "./helpers/datasources";
|
||||
import { getPkgManager } from "./helpers/get-pkg-manager";
|
||||
import { isFolderEmpty } from "./helpers/is-folder-empty";
|
||||
import { initializeGlobalAgent } from "./helpers/proxy";
|
||||
import { runApp } from "./helpers/run-app";
|
||||
import { getTools } from "./helpers/tools";
|
||||
import { validateNpmName } from "./helpers/validate-pkg";
|
||||
import packageJson from "./package.json";
|
||||
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
|
||||
|
||||
// Run the initialization function
|
||||
initializeGlobalAgent();
|
||||
|
||||
let projectPath: string = "";
|
||||
|
||||
const handleSigTerm = () => process.exit(0);
|
||||
|
||||
+2
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "create-llama",
|
||||
"version": "0.1.1",
|
||||
"version": "0.1.8",
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"keywords": [
|
||||
"rag",
|
||||
@@ -52,6 +52,7 @@
|
||||
"cross-spawn": "7.0.3",
|
||||
"fast-glob": "3.3.1",
|
||||
"fs-extra": "11.2.0",
|
||||
"global-agent": "^3.0.0",
|
||||
"got": "10.7.0",
|
||||
"ollama": "^0.5.0",
|
||||
"ora": "^8.0.1",
|
||||
|
||||
Generated
+265
-145
File diff suppressed because it is too large
Load Diff
@@ -97,6 +97,7 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
|
||||
{ title: "Milvus", value: "milvus" },
|
||||
{ title: "Astra", value: "astra" },
|
||||
{ title: "Qdrant", value: "qdrant" },
|
||||
{ title: "ChromaDB", value: "chroma" },
|
||||
];
|
||||
|
||||
const vectordbLang = framework === "fastapi" ? "python" : "typescript";
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
|
||||
try:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"llama_index.tools.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**kwargs)
|
||||
return tool_spec.to_tool_list()
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Unsupported tool: {tool_name}") from e
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
f"Could not create tool: {tool_name}. With config: {kwargs}"
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for name, config in tool_configs.items():
|
||||
tools += ToolFactory.create_tool(name, **config)
|
||||
return tools
|
||||
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolType:
|
||||
LLAMAHUB = "llamahub"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
TOOL_SOURCE_PACKAGE_MAP = {
|
||||
ToolType.LLAMAHUB: "llama_index.tools",
|
||||
ToolType.LOCAL: "app.engine.tools",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
|
||||
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
|
||||
try:
|
||||
if "ToolSpec" in tool_name:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"{source_package}.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**config)
|
||||
return tool_spec.to_tool_list()
|
||||
else:
|
||||
module = importlib.import_module(f"{source_package}.{tool_name}")
|
||||
tools = getattr(module, "tools")
|
||||
if not all(isinstance(tool, FunctionTool) for tool in tools):
|
||||
raise ValueError(
|
||||
f"The module {module} does not contain valid tools"
|
||||
)
|
||||
return tools
|
||||
except ImportError as e:
|
||||
raise ValueError(f"Failed to import tool {tool_name}: {e}")
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Failed to load tool {tool_name}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
return tools
|
||||
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Tuple, Dict
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
|
||||
type: str
|
||||
filename: str
|
||||
url: str
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
|
||||
is_error: bool
|
||||
logs: Logs
|
||||
results: List[InterpreterExtraResult] = []
|
||||
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
|
||||
output_dir = "tool-output"
|
||||
|
||||
def __init__(self, api_key: str, filesever_url_prefix: str):
|
||||
self.api_key = api_key
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
return os.path.join(self.output_dir, filename)
|
||||
|
||||
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
|
||||
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
|
||||
buffer = base64.b64decode(base64_data)
|
||||
output_path = self.get_output_path(filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(buffer)
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write to file {output_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info(f"Saved file to {output_path}")
|
||||
|
||||
return {
|
||||
"outputPath": output_path,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
def get_file_url(self, filename: str) -> str:
|
||||
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
|
||||
|
||||
def parse_result(self, result) -> List[InterpreterExtraResult]:
|
||||
"""
|
||||
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
|
||||
We save each result to disk and return saved file metadata (extension, filename, url)
|
||||
"""
|
||||
if not result:
|
||||
return []
|
||||
|
||||
output = []
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
base64_data_arr = [result[format] for format in formats]
|
||||
|
||||
for ext, base64_data in zip(formats, base64_data_arr):
|
||||
if ext and base64_data:
|
||||
result = self.save_to_disk(base64_data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext, filename=filename, url=self.get_file_url(filename)
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error("Error when saving data to disk", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
with CodeInterpreter(api_key=self.api_key) as interpreter:
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
output = E2BToolOutput(is_error=True, logs=[exec.error])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(
|
||||
is_error=False, logs=exec.logs, results=results
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def code_interpret(code: str) -> Dict:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.
|
||||
"""
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
interpreter = E2BCodeInterpreter(
|
||||
api_key=api_key, filesever_url_prefix=filesever_url_prefix
|
||||
)
|
||||
output = interpreter.interpret(code)
|
||||
return output.dict()
|
||||
|
||||
|
||||
# Specify as functions tools to be loaded by the ToolFactory
|
||||
tools = [FunctionTool.from_defaults(code_interpret)]
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Open Meteo weather map tool spec."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import pytz
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenMeteoWeather:
|
||||
geo_api = "https://geocoding-api.open-meteo.com/v1"
|
||||
weather_api = "https://api.open-meteo.com/v1"
|
||||
|
||||
@classmethod
|
||||
def _get_geo_location(cls, location: str) -> dict:
|
||||
"""Get geo location from location name."""
|
||||
params = {"name": location, "count": 10, "language": "en", "format": "json"}
|
||||
response = requests.get(f"{cls.geo_api}/search", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch geo location: {response.status_code}")
|
||||
else:
|
||||
data = response.json()
|
||||
result = data["results"][0]
|
||||
geo_location = {
|
||||
"id": result["id"],
|
||||
"name": result["name"],
|
||||
"latitude": result["latitude"],
|
||||
"longitude": result["longitude"],
|
||||
}
|
||||
return geo_location
|
||||
|
||||
@classmethod
|
||||
def get_weather_information(cls, location: str) -> dict:
|
||||
"""Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
"""
|
||||
logger.info(
|
||||
f"Calling open-meteo api to get weather information of location: {location}"
|
||||
)
|
||||
geo_location = cls._get_geo_location(location)
|
||||
timezone = pytz.timezone("UTC").zone
|
||||
params = {
|
||||
"latitude": geo_location["latitude"],
|
||||
"longitude": geo_location["longitude"],
|
||||
"current": "temperature_2m,weather_code",
|
||||
"hourly": "temperature_2m,weather_code",
|
||||
"daily": "weather_code",
|
||||
"timezone": timezone,
|
||||
}
|
||||
response = requests.get(f"{cls.weather_api}/forecast", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch weather information: {response.status_code}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
@@ -1,12 +1,12 @@
|
||||
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { getDataSource } from "./index";
|
||||
import { STORAGE_CACHE_DIR } from "./shared";
|
||||
import { createTools } from "./tools";
|
||||
|
||||
export async function createChatEngine() {
|
||||
let tools: BaseToolWithCall[] = [];
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
// Add a query engine tool if we have a data source
|
||||
// Delete this code if you don't have a data source
|
||||
@@ -23,15 +23,20 @@ export async function createChatEngine() {
|
||||
);
|
||||
}
|
||||
|
||||
const configFile = path.join("config", "tools.json");
|
||||
let toolConfig: any;
|
||||
try {
|
||||
// add tools from config file if it exists
|
||||
const config = JSON.parse(
|
||||
await fs.readFile(path.join("config", "tools.json"), "utf8"),
|
||||
);
|
||||
tools = tools.concat(await ToolsFactory.createTools(config));
|
||||
} catch {}
|
||||
toolConfig = JSON.parse(await fs.readFile(configFile, "utf8"));
|
||||
} catch (e) {
|
||||
console.info(`Could not read ${configFile} file. Using no tools.`);
|
||||
}
|
||||
if (toolConfig) {
|
||||
tools.push(...(await createTools(toolConfig)));
|
||||
}
|
||||
|
||||
return new OpenAIAgent({
|
||||
tools,
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import { BaseToolWithCall } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
|
||||
import { WeatherTool, WeatherToolParams } from "./weather";
|
||||
|
||||
type ToolCreator = (config: unknown) => BaseToolWithCall;
|
||||
|
||||
export async function createTools(toolConfig: {
|
||||
local: Record<string, unknown>;
|
||||
llamahub: any;
|
||||
}): Promise<BaseToolWithCall[]> {
|
||||
// add local tools from the 'tools' folder (if configured)
|
||||
const tools = createLocalTools(toolConfig.local);
|
||||
// add tools from LlamaIndexTS (if configured)
|
||||
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
|
||||
return tools;
|
||||
}
|
||||
|
||||
const toolFactory: Record<string, ToolCreator> = {
|
||||
weather: (config: unknown) => {
|
||||
return new WeatherTool(config as WeatherToolParams);
|
||||
},
|
||||
interpreter: (config: unknown) => {
|
||||
return new InterpreterTool(config as InterpreterToolParams);
|
||||
},
|
||||
};
|
||||
|
||||
function createLocalTools(
|
||||
localConfig: Record<string, unknown>,
|
||||
): BaseToolWithCall[] {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
Object.keys(localConfig).forEach((key) => {
|
||||
if (key in toolFactory) {
|
||||
const toolConfig = localConfig[key];
|
||||
const tool = toolFactory[key](toolConfig);
|
||||
tools.push(tool);
|
||||
}
|
||||
});
|
||||
|
||||
return tools;
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
import { CodeInterpreter, Logs, Result } from "@e2b/code-interpreter";
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import fs from "fs";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import crypto from "node:crypto";
|
||||
import path from "node:path";
|
||||
|
||||
export type InterpreterParameter = {
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
apiKey?: string;
|
||||
fileServerURLPrefix?: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolOuput = {
|
||||
isError: boolean;
|
||||
logs: Logs;
|
||||
extraResult: InterpreterExtraResult[];
|
||||
};
|
||||
|
||||
type InterpreterExtraType =
|
||||
| "html"
|
||||
| "markdown"
|
||||
| "svg"
|
||||
| "png"
|
||||
| "jpeg"
|
||||
| "pdf"
|
||||
| "latex"
|
||||
| "json"
|
||||
| "javascript";
|
||||
|
||||
export type InterpreterExtraResult = {
|
||||
type: InterpreterExtraType;
|
||||
filename: string;
|
||||
url: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
|
||||
name: "interpreter",
|
||||
description:
|
||||
"Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
code: {
|
||||
type: "string",
|
||||
description: "The python code to execute in a single cell.",
|
||||
},
|
||||
},
|
||||
required: ["code"],
|
||||
},
|
||||
};
|
||||
|
||||
export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
private readonly outputDir = "tool-output";
|
||||
private apiKey?: string;
|
||||
private fileServerURLPrefix?: string;
|
||||
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
codeInterpreter?: CodeInterpreter;
|
||||
|
||||
constructor(params?: InterpreterToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
this.apiKey = params?.apiKey || process.env.E2B_API_KEY;
|
||||
this.fileServerURLPrefix =
|
||||
params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX;
|
||||
|
||||
if (!this.apiKey) {
|
||||
throw new Error(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
);
|
||||
}
|
||||
if (!this.fileServerURLPrefix) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public async initInterpreter() {
|
||||
if (!this.codeInterpreter) {
|
||||
this.codeInterpreter = await CodeInterpreter.create({
|
||||
apiKey: this.apiKey,
|
||||
});
|
||||
}
|
||||
return this.codeInterpreter;
|
||||
}
|
||||
|
||||
public async codeInterpret(code: string): Promise<InterpreterToolOuput> {
|
||||
console.log(
|
||||
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
|
||||
);
|
||||
const interpreter = await this.initInterpreter();
|
||||
const exec = await interpreter.notebook.execCell(code);
|
||||
if (exec.error) console.error("[Code Interpreter error]", exec.error);
|
||||
const extraResult = await this.getExtraResult(exec.results[0]);
|
||||
const result: InterpreterToolOuput = {
|
||||
isError: !!exec.error,
|
||||
logs: exec.logs,
|
||||
extraResult,
|
||||
};
|
||||
return result;
|
||||
}
|
||||
|
||||
async call(input: InterpreterParameter): Promise<InterpreterToolOuput> {
|
||||
const result = await this.codeInterpret(input.code);
|
||||
await this.codeInterpreter?.close();
|
||||
return result;
|
||||
}
|
||||
|
||||
private async getExtraResult(
|
||||
res?: Result,
|
||||
): Promise<InterpreterExtraResult[]> {
|
||||
if (!res) return [];
|
||||
const output: InterpreterExtraResult[] = [];
|
||||
|
||||
try {
|
||||
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
|
||||
const base64DataArr = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
|
||||
|
||||
// save base64 data to file and return the url
|
||||
for (let i = 0; i < formats.length; i++) {
|
||||
const ext = formats[i];
|
||||
const base64Data = base64DataArr[i];
|
||||
if (ext && base64Data) {
|
||||
const { filename } = this.saveToDisk(base64Data, ext);
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
filename,
|
||||
url: this.getFileUrl(filename),
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error when saving data to disk", error);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Consider saving to cloud storage instead but it may cost more for you
|
||||
// See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file
|
||||
private saveToDisk(
|
||||
base64Data: string,
|
||||
ext: string,
|
||||
): {
|
||||
outputPath: string;
|
||||
filename: string;
|
||||
} {
|
||||
const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename
|
||||
const buffer = Buffer.from(base64Data, "base64");
|
||||
const outputPath = this.getOutputPath(filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
console.log(`Saved file to ${outputPath}`);
|
||||
return {
|
||||
outputPath,
|
||||
filename,
|
||||
};
|
||||
}
|
||||
|
||||
private getOutputPath(filename: string): string {
|
||||
// if outputDir doesn't exist, create it
|
||||
if (!fs.existsSync(this.outputDir)) {
|
||||
fs.mkdirSync(this.outputDir, { recursive: true });
|
||||
}
|
||||
return path.join(this.outputDir, filename);
|
||||
}
|
||||
|
||||
private getFileUrl(filename: string): string {
|
||||
return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface GeoLocation {
|
||||
id: string;
|
||||
name: string;
|
||||
latitude: number;
|
||||
longitude: number;
|
||||
}
|
||||
|
||||
export type WeatherParameter = {
|
||||
location: string;
|
||||
};
|
||||
|
||||
export type WeatherToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
|
||||
name: "get_weather_information",
|
||||
description: `
|
||||
Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The location to get the weather information",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
};
|
||||
|
||||
export class WeatherTool implements BaseTool<WeatherParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
|
||||
private getGeoLocation = async (location: string): Promise<GeoLocation> => {
|
||||
const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
const { id, name, latitude, longitude } = data.results[0];
|
||||
return { id, name, latitude, longitude };
|
||||
};
|
||||
|
||||
private getWeatherByLocation = async (location: string) => {
|
||||
console.log(
|
||||
"Calling open-meteo api to get weather information of location:",
|
||||
location,
|
||||
);
|
||||
const { latitude, longitude } = await this.getGeoLocation(location);
|
||||
const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}¤t=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
constructor(params?: WeatherToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: WeatherParameter) {
|
||||
return await this.getWeatherByLocation(input.location);
|
||||
}
|
||||
}
|
||||
@@ -16,5 +16,6 @@ export async function createChatEngine() {
|
||||
return new ContextChatEngine({
|
||||
chatModel: Settings.llm,
|
||||
retriever,
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
|
||||
data_dir: str = "data"
|
||||
@@ -27,11 +30,28 @@ def llama_parse_parser():
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
try:
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
except ValueError as e:
|
||||
import sys, traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
_, _, exc_traceback = sys.exc_info()
|
||||
function_name = traceback.extract_tb(exc_traceback)[-1].name
|
||||
if function_name == "_add_files":
|
||||
logger.warning(
|
||||
f"Failed to load file documents, error message: {e} . Return as empty document list."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
# Raise the error if it is not the case of empty data dir
|
||||
raise e
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
documents = get_documents()
|
||||
store = AstraDBVectorStore(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(f"Successfully created embeddings in the AstraDB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,21 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from AstraDB...")
|
||||
store = AstraDBVectorStore(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from AstraDB.")
|
||||
return index
|
||||
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
endpoint = os.getenv("ASTRA_DB_ENDPOINT")
|
||||
token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
|
||||
collection = os.getenv("ASTRA_DB_COLLECTION")
|
||||
if not endpoint or not token or not collection:
|
||||
raise ValueError(
|
||||
"Please config ASTRA_DB_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_COLLECTION"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = AstraDBVectorStore(
|
||||
token=token,
|
||||
api_endpoint=endpoint,
|
||||
collection_name=collection,
|
||||
embedding_dimension=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("CHROMA_COLLECTION", "default")
|
||||
chroma_path = os.getenv("CHROMA_PATH")
|
||||
# if CHROMA_PATH is set, use a local ChromaVectorStore from the path
|
||||
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
|
||||
if chroma_path:
|
||||
store = ChromaVectorStore.from_params(
|
||||
persist_dir=chroma_path, collection_name=collection_name
|
||||
)
|
||||
else:
|
||||
if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
|
||||
raise ValueError(
|
||||
"Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
|
||||
)
|
||||
store = ChromaVectorStore.from_params(
|
||||
host=os.getenv("CHROMA_HOST"),
|
||||
port=int(os.getenv("CHROMA_PORT")),
|
||||
collection_name=collection_name,
|
||||
)
|
||||
return store
|
||||
@@ -1,39 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = MilvusVectorStore(
|
||||
uri=os.environ["MILVUS_ADDRESS"],
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(f"Successfully created embeddings in the Milvus")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,22 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from Milvus...")
|
||||
store = MilvusVectorStore(
|
||||
uri=os.getenv("MILVUS_ADDRESS"),
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from Milvus.")
|
||||
return index
|
||||
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
address = os.getenv("MILVUS_ADDRESS")
|
||||
collection = os.getenv("MILVUS_COLLECTION")
|
||||
if not address or not collection:
|
||||
raise ValueError(
|
||||
"Please set MILVUS_ADDRESS and MILVUS_COLLECTION to your environment variables"
|
||||
" or config them in the .env file"
|
||||
)
|
||||
store = MilvusVectorStore(
|
||||
uri=address,
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=collection,
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
@@ -1,43 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings in the MongoDB collection {os.environ['MONGODB_VECTORS']}"
|
||||
)
|
||||
logger.info(
|
||||
"""IMPORTANT: You can't query your index yet because you need to create a vector search index in MongoDB's UI now.
|
||||
See https://github.com/run-llama/mongodb-demo/tree/main?tab=readme-ov-file#create-a-vector-search-index"""
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from MongoDB...")
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from MongoDB.")
|
||||
return index
|
||||
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
db_uri = os.getenv("MONGODB_URI")
|
||||
db_name = os.getenv("MONGODB_DATABASE")
|
||||
collection_name = os.getenv("MONGODB_VECTORS")
|
||||
index_name = os.getenv("MONGODB_VECTOR_INDEX")
|
||||
if not db_uri or not db_name or not collection_name or not index_name:
|
||||
raise ValueError(
|
||||
"Please set MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, and MONGODB_VECTOR_INDEX"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=db_name,
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
)
|
||||
return store
|
||||
@@ -1 +0,0 @@
|
||||
STORAGE_DIR = "storage" # directory to cache the generated index
|
||||
@@ -2,11 +2,11 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.indices import (
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from app.engine.constants import STORAGE_DIR
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
|
||||
@@ -18,14 +18,15 @@ logger = logging.getLogger()
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
storage_dir = os.environ.get("STORAGE_DIR", "storage")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
)
|
||||
# store it for later
|
||||
index.storage_context.persist(STORAGE_DIR)
|
||||
logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
|
||||
index.storage_context.persist(storage_dir)
|
||||
logger.info(f"Finished creating new index. Stored in {storage_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
import logging
|
||||
import os
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from app.engine.constants import STORAGE_DIR
|
||||
from cachetools import cached, TTLCache
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import load_index_from_storage
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
@cached(
|
||||
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
|
||||
key=lambda *args, **kwargs: "global_storage_context",
|
||||
)
|
||||
def get_storage_context(persist_dir: str) -> StorageContext:
|
||||
return StorageContext.from_defaults(persist_dir=persist_dir)
|
||||
|
||||
|
||||
def get_index():
|
||||
storage_dir = os.getenv("STORAGE_DIR", "storage")
|
||||
# check if storage already exists
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
if not os.path.exists(storage_dir):
|
||||
return None
|
||||
# load the existing index
|
||||
logger.info(f"Loading index from {STORAGE_DIR}...")
|
||||
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
|
||||
logger.info(f"Loading index from {storage_dir}...")
|
||||
storage_context = get_storage_context(storage_dir)
|
||||
index = load_index_from_storage(storage_context)
|
||||
logger.info(f"Finished loading index from {STORAGE_DIR}")
|
||||
logger.info(f"Finished loading index from {storage_dir}")
|
||||
return index
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
@@ -1,35 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.core.storage import StorageContext
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from app.engine.utils import init_pg_vector_store_from_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = init_pg_vector_store_from_env()
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings in the PG vector store, schema={store.schema_name} table={store.table_name}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,13 +0,0 @@
|
||||
import logging
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from app.engine.utils import init_pg_vector_store_from_env
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from PGVector...")
|
||||
store = init_pg_vector_store_from_env()
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from PGVector.")
|
||||
return index
|
||||
@@ -1,27 +0,0 @@
|
||||
import os
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
from urllib.parse import urlparse
|
||||
from app.engine.constants import PGVECTOR_SCHEMA, PGVECTOR_TABLE
|
||||
|
||||
|
||||
def init_pg_vector_store_from_env():
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
return PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
from urllib.parse import urlparse
|
||||
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
|
||||
vector_store: PGVectorStore = None
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
global vector_store
|
||||
|
||||
if vector_store is None:
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
vector_store = PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
|
||||
)
|
||||
|
||||
return vector_store
|
||||
@@ -1,39 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings and save to your Pinecone index {os.environ['PINECONE_INDEX_NAME']}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from Pinecone...")
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from Pinecone.")
|
||||
return index
|
||||
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
api_key = os.getenv("PINECONE_API_KEY")
|
||||
index_name = os.getenv("PINECONE_INDEX_NAME")
|
||||
environment = os.getenv("PINECONE_ENVIRONMENT")
|
||||
if not api_key or not index_name or not environment:
|
||||
raise ValueError(
|
||||
"Please set PINECONE_API_KEY, PINECONE_INDEX_NAME, and PINECONE_ENVIRONMENT"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = PineconeVectorStore(
|
||||
api_key=api_key,
|
||||
index_name=index_name,
|
||||
environment=environment,
|
||||
)
|
||||
return store
|
||||
@@ -1,37 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index with Qdrant")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully uploaded documents to the {os.getenv('QDRANT_COLLECTION')} collection."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to Qdrant collection..")
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to Qdrant collection.")
|
||||
return index
|
||||
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("QDRANT_COLLECTION")
|
||||
url = os.getenv("QDRANT_URL")
|
||||
api_key = os.getenv("QDRANT_API_KEY")
|
||||
if not collection_name or not url:
|
||||
raise ValueError(
|
||||
"Please set QDRANT_COLLECTION, QDRANT_URL"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
)
|
||||
return store
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
AstraDBVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { AstraDBVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
async function loadAndIndex() {
|
||||
// load objects from storage and convert them into LlamaIndex Document objects
|
||||
const documents = await getDocuments();
|
||||
|
||||
// create vector store
|
||||
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
|
||||
|
||||
const vectorStore = new ChromaVectorStore({
|
||||
collectionName: process.env.CHROMA_COLLECTION,
|
||||
chromaClientParams: { path: chromaUri },
|
||||
});
|
||||
|
||||
// create index from all the Documentss and store them in Pinecone
|
||||
console.log("Start creating embeddings...");
|
||||
const storageContext = await storageContextFromDefaults({ vectorStore });
|
||||
await VectorStoreIndex.fromDocuments(documents, { storageContext });
|
||||
console.log(
|
||||
"Successfully created embeddings and save to your ChromaDB index.",
|
||||
);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
checkRequiredEnvVars();
|
||||
initSettings();
|
||||
await loadAndIndex();
|
||||
console.log("Finished generating storage.");
|
||||
})();
|
||||
@@ -0,0 +1,16 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
|
||||
|
||||
const store = new ChromaVectorStore({
|
||||
collectionName: process.env.CHROMA_COLLECTION,
|
||||
chromaClientParams: { path: chromaUri },
|
||||
});
|
||||
|
||||
return await VectorStoreIndex.fromVectorStore(store);
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
const REQUIRED_ENV_VARS = ["CHROMA_COLLECTION", "CHROMA_HOST", "CHROMA_PORT"];
|
||||
|
||||
export function checkRequiredEnvVars() {
|
||||
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
|
||||
return !process.env[envVar];
|
||||
});
|
||||
|
||||
if (missingEnvVars.length > 0) {
|
||||
console.log(
|
||||
`The following environment variables are required but missing: ${missingEnvVars.join(
|
||||
", ",
|
||||
)}`,
|
||||
);
|
||||
throw new Error(
|
||||
`Missing environment variables: ${missingEnvVars.join(", ")}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
MilvusVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { MilvusVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
|
||||
import { checkRequiredEnvVars, getMilvusClient } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
MongoDBAtlasVectorSearch,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
|
||||
import { MongoClient } from "mongodb";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
@@ -12,7 +9,7 @@ import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
const mongoUri = process.env.MONGO_URI!;
|
||||
const mongoUri = process.env.MONGODB_URI!;
|
||||
const databaseName = process.env.MONGODB_DATABASE!;
|
||||
const vectorCollectionName = process.env.MONGODB_VECTORS!;
|
||||
const indexName = process.env.MONGODB_VECTOR_INDEX;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { MongoDBAtlasVectorSearch, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
|
||||
import { MongoClient } from "mongodb";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const REQUIRED_ENV_VARS = [
|
||||
"MONGO_URI",
|
||||
"MONGODB_URI",
|
||||
"MONGODB_DATABASE",
|
||||
"MONGODB_VECTORS",
|
||||
"MONGODB_VECTOR_INDEX",
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
PGVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { PGVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { PGVectorStore } from "llamaindex/storage/vectorStore/PGVectorStore";
|
||||
import {
|
||||
PGVECTOR_SCHEMA,
|
||||
PGVECTOR_TABLE,
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
PineconeVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { PineconeVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import {
|
||||
QdrantVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
|
||||
@@ -18,7 +15,10 @@ async function loadAndIndex() {
|
||||
const documents = await getDocuments();
|
||||
|
||||
// Connect to Qdrant
|
||||
const vectorStore = new QdrantVectorStore(collectionName, getQdrantClient());
|
||||
const vectorStore = new QdrantVectorStore({
|
||||
collectionName,
|
||||
client: getQdrantClient(),
|
||||
});
|
||||
|
||||
const storageContext = await storageContextFromDefaults({ vectorStore });
|
||||
await VectorStoreIndex.fromDocuments(documents, {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import * as dotenv from "dotenv";
|
||||
import { QdrantVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { QdrantVectorStore } from "llamaindex/storage/vectorStore/QdrantVectorStore";
|
||||
import { checkRequiredEnvVars, getQdrantClient } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
@@ -7,7 +8,10 @@ dotenv.config();
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const collectionName = process.env.QDRANT_COLLECTION;
|
||||
const store = new QdrantVectorStore(collectionName, getQdrantClient());
|
||||
const store = new QdrantVectorStore({
|
||||
collectionName,
|
||||
client: getQdrantClient(),
|
||||
});
|
||||
|
||||
return await VectorStoreIndex.fromVectorStore(store);
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# local env files
|
||||
.env
|
||||
node_modules/
|
||||
node_modules/
|
||||
|
||||
tool-output/
|
||||
@@ -31,6 +31,8 @@ if (isDevelopment) {
|
||||
console.warn("Production CORS origin not set, defaulting to no CORS.");
|
||||
}
|
||||
|
||||
app.use("/api/files/data", express.static("data"));
|
||||
app.use("/api/files/tool-output", express.static("tool-output"));
|
||||
app.use(express.text());
|
||||
|
||||
app.get("/", (req: Request, res: Response) => {
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
{
|
||||
"name": "llama-index-express-streaming",
|
||||
"version": "1.0.0",
|
||||
"main": "dist/index.mjs",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write .",
|
||||
"build": "tsup index.ts --format esm --dts",
|
||||
"start": "node dist/index.mjs",
|
||||
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
|
||||
"build": "tsup index.ts --format cjs --dts",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
|
||||
},
|
||||
"dependencies": {
|
||||
"ai": "^3.0.21",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.3.1",
|
||||
"express": "^4.18.2",
|
||||
"llamaindex": "0.3.3",
|
||||
"pdf2json": "3.0.5"
|
||||
"llamaindex": "0.3.13",
|
||||
"pdf2json": "3.0.5",
|
||||
"ajv": "^8.12.0",
|
||||
"@e2b/code-interpreter": "^0.0.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/cors": "^2.8.16",
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import { Message, StreamData, streamToResponse } from "ai";
|
||||
import { Request, Response } from "express";
|
||||
import {
|
||||
CallbackManager,
|
||||
ChatMessage,
|
||||
MessageContent,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { appendEventData } from "./stream-helper";
|
||||
import { createCallbackManager } from "./stream-helper";
|
||||
|
||||
const convertMessageContent = (
|
||||
textMessage: string,
|
||||
@@ -52,18 +47,7 @@ export const chat = async (req: Request, res: Response) => {
|
||||
const vercelStreamData = new StreamData();
|
||||
|
||||
// Setup callbacks
|
||||
const callbackManager = new CallbackManager();
|
||||
callbackManager.on("retrieve", (data) => {
|
||||
const { nodes } = data.detail;
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieving context for query: '${userMessage.content}'`,
|
||||
);
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
const callbackManager = createCallbackManager(vercelStreamData);
|
||||
|
||||
// Calling LlamaIndex's ChatEngine to get a streamed response
|
||||
const response = await Settings.withCallbackManager(callbackManager, () => {
|
||||
@@ -80,9 +64,8 @@ export const chat = async (req: Request, res: Response) => {
|
||||
image_url: data?.imageUrl,
|
||||
},
|
||||
});
|
||||
const processedStream = stream.pipeThrough(vercelStreamData.stream);
|
||||
|
||||
return streamToResponse(processedStream, res);
|
||||
return streamToResponse(stream, res, {}, vercelStreamData);
|
||||
} catch (error) {
|
||||
console.error("[LlamaIndex]", error);
|
||||
return res.status(500).json({
|
||||
|
||||
@@ -56,11 +56,17 @@ function initOpenAI() {
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
const config = {
|
||||
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
|
||||
};
|
||||
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { StreamData } from "ai";
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
import {
|
||||
CallbackManager,
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve", (data) => {
|
||||
const { nodes, query } = data.detail;
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.payload.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail.payload;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
@@ -96,7 +93,13 @@ async def chat(
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
try:
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error in chat engine: {e}",
|
||||
)
|
||||
|
||||
async def content_generator():
|
||||
# Yield the text response
|
||||
@@ -109,12 +112,9 @@ async def chat(
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "events",
|
||||
"data": {"title": event.get_title()},
|
||||
}
|
||||
)
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_text_generator(), _event_generator())
|
||||
async with combine.stream() as streamer:
|
||||
|
||||
@@ -1,31 +1,94 @@
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from llama_index.core.tools.types import ToolOutput
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_title(self) -> str | None:
|
||||
# Return as None for the unhandled event types
|
||||
# to avoid showing them in the UI
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
return f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
def get_retrieval_message(self) -> dict | None:
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
else:
|
||||
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {"title": msg},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_tool_message(self) -> dict | None:
|
||||
func_call_args = self.payload.get("function_call")
|
||||
if func_call_args is not None and "tool" in self.payload:
|
||||
tool = self.payload.get("tool")
|
||||
return {
|
||||
"type": "events",
|
||||
"data": {
|
||||
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
|
||||
},
|
||||
}
|
||||
|
||||
def _is_output_serializable(self, output: Any) -> bool:
|
||||
try:
|
||||
json.dumps(output)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def get_agent_tool_response(self) -> dict | None:
|
||||
response = self.payload.get("response")
|
||||
if response is not None:
|
||||
sources = response.sources
|
||||
for source in sources:
|
||||
# Return the tool response here to include the toolCall information
|
||||
if isinstance(source, ToolOutput):
|
||||
if self._is_output_serializable(source.raw_output):
|
||||
output = source.raw_output
|
||||
else:
|
||||
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
else:
|
||||
output = source.content
|
||||
|
||||
return {
|
||||
"type": "tools",
|
||||
"data": {
|
||||
"toolOutput": {
|
||||
"output": output,
|
||||
"isError": source.is_error,
|
||||
},
|
||||
"toolCall": {
|
||||
"id": None, # There is no tool id in the ToolOutput
|
||||
"name": source.tool_name,
|
||||
"input": source.raw_input,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
try:
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
return self.get_retrieval_message()
|
||||
case "function_call":
|
||||
return self.get_tool_message()
|
||||
case "agent_step":
|
||||
return self.get_agent_tool_response()
|
||||
case _:
|
||||
return None
|
||||
case _:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in converting event to response: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
@@ -54,7 +117,7 @@ class EventCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
@@ -65,7 +128,7 @@ class EventCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
if event.to_response() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage import StorageContext
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
STORAGE_DIR = os.getenv("STORAGE_DIR", "storage")
|
||||
|
||||
|
||||
def get_doc_store():
|
||||
|
||||
# If the storage directory is there, load the document store from it.
|
||||
# If not, set up an in-memory document store since we can't load from a directory that doesn't exist.
|
||||
if os.path.exists(STORAGE_DIR):
|
||||
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
else:
|
||||
return SimpleDocumentStore()
|
||||
|
||||
|
||||
def run_pipeline(docstore, vector_store, documents):
|
||||
pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
SentenceSplitter(
|
||||
chunk_size=Settings.chunk_size,
|
||||
chunk_overlap=Settings.chunk_overlap,
|
||||
),
|
||||
Settings.embed_model,
|
||||
],
|
||||
docstore=docstore,
|
||||
docstore_strategy="upserts_and_delete",
|
||||
vector_store=vector_store,
|
||||
)
|
||||
|
||||
# Run the ingestion pipeline and store the results
|
||||
nodes = pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store):
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
vector_store=vector_store,
|
||||
)
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
# Get the stores and documents or create new ones
|
||||
documents = get_documents()
|
||||
docstore = get_doc_store()
|
||||
vector_store = get_vector_store()
|
||||
|
||||
# Run the ingestion pipeline
|
||||
_ = run_pipeline(docstore, vector_store, documents)
|
||||
|
||||
# Build the index and persist storage
|
||||
persist_storage(docstore, vector_store)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting vector store...")
|
||||
store = get_vector_store()
|
||||
# Load the index from the vector store
|
||||
# If you are using a vector store that doesn't store text,
|
||||
# you must load the index from both the vector store and the document store
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished load index from vector store.")
|
||||
return index
|
||||
@@ -23,8 +23,12 @@ def init_ollama():
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
Settings.llm = Ollama(model=os.getenv("MODEL"))
|
||||
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
Settings.embed_model = OllamaEmbedding(
|
||||
base_url=base_url,
|
||||
model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
)
|
||||
Settings.llm = Ollama(base_url=base_url, model=os.getenv("MODEL"))
|
||||
|
||||
|
||||
def init_openai():
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
__pycache__
|
||||
storage
|
||||
.env
|
||||
.env
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import RedirectResponse
|
||||
from app.api.routers.chat import chat_router
|
||||
from app.settings import init_settings
|
||||
from app.observability import init_observability
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@@ -20,7 +21,6 @@ init_observability()
|
||||
|
||||
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
|
||||
|
||||
|
||||
if environment == "dev":
|
||||
logger = logging.getLogger("uvicorn")
|
||||
logger.warning("Running in development mode - allowing CORS for all origins")
|
||||
@@ -38,6 +38,16 @@ if environment == "dev":
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
def mount_static_files(directory, path):
|
||||
if os.path.exists(directory):
|
||||
app.mount(path, StaticFiles(directory=directory), name=f"{directory}-static")
|
||||
|
||||
|
||||
# Mount the data files to serve the file viewer
|
||||
mount_static_files("data", "/api/files/data")
|
||||
# Mount the output files from tools
|
||||
mount_static_files("tool-output", "/api/files/tool-output")
|
||||
|
||||
app.include_router(chat_router, prefix="/api/chat")
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ python-dotenv = "^1.0.0"
|
||||
aiostream = "^0.5.2"
|
||||
llama-index = "0.10.28"
|
||||
llama-index-core = "0.10.28"
|
||||
cachetools = "^5.3.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -56,11 +56,16 @@ function initOpenAI() {
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
const config = {
|
||||
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
|
||||
};
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
config,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import { initObservability } from "@/app/observability";
|
||||
import { Message, StreamData, StreamingTextResponse } from "ai";
|
||||
import {
|
||||
CallbackManager,
|
||||
ChatMessage,
|
||||
MessageContent,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { initSettings } from "./engine/settings";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { appendEventData } from "./stream-helper";
|
||||
import { createCallbackManager } from "./stream-helper";
|
||||
|
||||
initObservability();
|
||||
initSettings();
|
||||
@@ -64,18 +59,7 @@ export async function POST(request: NextRequest) {
|
||||
const vercelStreamData = new StreamData();
|
||||
|
||||
// Setup callbacks
|
||||
const callbackManager = new CallbackManager();
|
||||
callbackManager.on("retrieve", (data) => {
|
||||
const { nodes } = data.detail;
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieving context for query: '${userMessage.content}'`,
|
||||
);
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
const callbackManager = createCallbackManager(vercelStreamData);
|
||||
|
||||
// Calling LlamaIndex's ChatEngine to get a streamed response
|
||||
const response = await Settings.withCallbackManager(callbackManager, () => {
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { StreamData } from "ai";
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
import {
|
||||
CallbackManager,
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve", (data) => {
|
||||
const { nodes, query } = data.detail;
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.payload.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail.payload;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import { readFile } from "fs/promises";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
import path from "path";
|
||||
|
||||
/**
|
||||
* This API is to get file data from allowed folders
|
||||
* It receives path slug and response file data like serve static file
|
||||
*/
|
||||
export async function GET(
|
||||
_request: NextRequest,
|
||||
{ params }: { params: { slug: string[] } },
|
||||
) {
|
||||
const slug = params.slug;
|
||||
|
||||
if (!slug) {
|
||||
return NextResponse.json({ detail: "Missing file slug" }, { status: 400 });
|
||||
}
|
||||
|
||||
if (slug.includes("..") || path.isAbsolute(path.join(...slug))) {
|
||||
return NextResponse.json({ detail: "Invalid file path" }, { status: 400 });
|
||||
}
|
||||
|
||||
const [folder, ...pathTofile] = params.slug; // data, file.pdf
|
||||
const allowedFolders = ["data", "tool-output"];
|
||||
|
||||
if (!allowedFolders.includes(folder)) {
|
||||
return NextResponse.json({ detail: "No permission" }, { status: 400 });
|
||||
}
|
||||
|
||||
try {
|
||||
const filePath = path.join(process.cwd(), folder, path.join(...pathTofile));
|
||||
const blob = await readFile(filePath);
|
||||
|
||||
return new NextResponse(blob, {
|
||||
status: 200,
|
||||
statusText: "OK",
|
||||
headers: {
|
||||
"Content-Length": blob.byteLength.toString(),
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return NextResponse.json({ detail: "File not found" }, { status: 404 });
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,8 @@ export default function ChatSection() {
|
||||
headers: {
|
||||
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
|
||||
},
|
||||
onError: (error) => {
|
||||
onError: (error: unknown) => {
|
||||
if (!(error instanceof Error)) throw error;
|
||||
const message = JSON.parse(error.message);
|
||||
alert(message.detail);
|
||||
},
|
||||
|
||||
@@ -38,7 +38,9 @@ export function ChatEvents({
|
||||
<CollapsibleContent asChild>
|
||||
<div className="mt-4 text-sm space-y-2">
|
||||
{data.map((eventItem, index) => (
|
||||
<div key={index}>{eventItem.title}</div>
|
||||
<div className="whitespace-break-spaces" key={index}>
|
||||
{eventItem.title}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</CollapsibleContent>
|
||||
|
||||
@@ -7,6 +7,7 @@ import ChatAvatar from "./chat-avatar";
|
||||
import { ChatEvents } from "./chat-events";
|
||||
import { ChatImage } from "./chat-image";
|
||||
import { ChatSources } from "./chat-sources";
|
||||
import ChatTools from "./chat-tools";
|
||||
import {
|
||||
AnnotationData,
|
||||
EventData,
|
||||
@@ -14,6 +15,7 @@ import {
|
||||
MessageAnnotation,
|
||||
MessageAnnotationType,
|
||||
SourceData,
|
||||
ToolData,
|
||||
} from "./index";
|
||||
import Markdown from "./markdown";
|
||||
import { useCopyToClipboard } from "./use-copy-to-clipboard";
|
||||
@@ -52,19 +54,27 @@ function ChatMessageContent({
|
||||
annotations,
|
||||
MessageAnnotationType.SOURCES,
|
||||
);
|
||||
const toolData = getAnnotationData<ToolData>(
|
||||
annotations,
|
||||
MessageAnnotationType.TOOLS,
|
||||
);
|
||||
|
||||
const contents: ContentDisplayConfig[] = [
|
||||
{
|
||||
order: -2,
|
||||
order: -3,
|
||||
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
|
||||
},
|
||||
{
|
||||
order: -1,
|
||||
order: -2,
|
||||
component:
|
||||
eventData.length > 0 ? (
|
||||
<ChatEvents isLoading={isLoading} data={eventData} />
|
||||
) : null,
|
||||
},
|
||||
{
|
||||
order: -1,
|
||||
component: toolData[0] ? <ChatTools data={toolData[0]} /> : null,
|
||||
},
|
||||
{
|
||||
order: 0,
|
||||
component: <Markdown content={message.content} />,
|
||||
|
||||
@@ -40,9 +40,16 @@ export default function ChatMessages(
|
||||
className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
{props.messages.map((m) => (
|
||||
<ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} />
|
||||
))}
|
||||
{props.messages.map((m, i) => {
|
||||
const isLoadingMessage = i === messageLength - 1 && props.isLoading;
|
||||
return (
|
||||
<ChatMessage
|
||||
key={m.id}
|
||||
chatMessage={m}
|
||||
isLoading={isLoadingMessage}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{isPending && (
|
||||
<div className="flex justify-center items-center pt-10">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
|
||||
@@ -1,20 +1,80 @@
|
||||
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
|
||||
import { Check, Copy } from "lucide-react";
|
||||
import { useMemo } from "react";
|
||||
import { Button } from "../button";
|
||||
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
|
||||
import { getStaticFileDataUrl } from "../lib/url";
|
||||
import { SourceData, SourceNode } from "./index";
|
||||
import { useCopyToClipboard } from "./use-copy-to-clipboard";
|
||||
import PdfDialog from "./widgets/PdfDialog";
|
||||
|
||||
const SCORE_THRESHOLD = 0.5;
|
||||
const DATA_SOURCE_FOLDER = "data";
|
||||
const SCORE_THRESHOLD = 0.3;
|
||||
|
||||
function SourceNumberButton({ index }: { index: number }) {
|
||||
return (
|
||||
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
|
||||
{index + 1}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
enum NODE_TYPE {
|
||||
URL,
|
||||
FILE,
|
||||
UNKNOWN,
|
||||
}
|
||||
|
||||
type NodeInfo = {
|
||||
id: string;
|
||||
type: NODE_TYPE;
|
||||
path?: string;
|
||||
url?: string;
|
||||
};
|
||||
|
||||
function getNodeInfo(node: SourceNode): NodeInfo {
|
||||
if (typeof node.metadata["URL"] === "string") {
|
||||
const url = node.metadata["URL"];
|
||||
return {
|
||||
id: node.id,
|
||||
type: NODE_TYPE.URL,
|
||||
path: url,
|
||||
url,
|
||||
};
|
||||
}
|
||||
if (typeof node.metadata["file_path"] === "string") {
|
||||
const fileName = node.metadata["file_name"] as string;
|
||||
const filePath = `${DATA_SOURCE_FOLDER}/${fileName}`;
|
||||
return {
|
||||
id: node.id,
|
||||
type: NODE_TYPE.FILE,
|
||||
path: node.metadata["file_path"],
|
||||
url: getStaticFileDataUrl(filePath),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: node.id,
|
||||
type: NODE_TYPE.UNKNOWN,
|
||||
};
|
||||
}
|
||||
|
||||
export function ChatSources({ data }: { data: SourceData }) {
|
||||
const sources = useMemo(() => {
|
||||
return (
|
||||
data.nodes
|
||||
?.filter((node) => Object.keys(node.metadata).length > 0)
|
||||
?.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
|
||||
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) || []
|
||||
);
|
||||
const sources: NodeInfo[] = useMemo(() => {
|
||||
// aggregate nodes by url or file_path (get the highest one by score)
|
||||
const nodesByPath: { [path: string]: NodeInfo } = {};
|
||||
|
||||
data.nodes
|
||||
.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
|
||||
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1))
|
||||
.forEach((node) => {
|
||||
const nodeInfo = getNodeInfo(node);
|
||||
const key = nodeInfo.path ?? nodeInfo.id; // use id as key for UNKNOWN type
|
||||
if (!nodesByPath[key]) {
|
||||
nodesByPath[key] = nodeInfo;
|
||||
}
|
||||
});
|
||||
|
||||
return Object.values(nodesByPath);
|
||||
}, [data.nodes]);
|
||||
|
||||
if (sources.length === 0) return null;
|
||||
@@ -23,55 +83,52 @@ export function ChatSources({ data }: { data: SourceData }) {
|
||||
<div className="space-x-2 text-sm">
|
||||
<span className="font-semibold">Sources:</span>
|
||||
<div className="inline-flex gap-1 items-center">
|
||||
{sources.map((node: SourceNode, index: number) => (
|
||||
<div key={node.id}>
|
||||
<HoverCard>
|
||||
<HoverCardTrigger>
|
||||
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
|
||||
{index + 1}
|
||||
</div>
|
||||
</HoverCardTrigger>
|
||||
<HoverCardContent>
|
||||
<NodeInfo node={node} />
|
||||
</HoverCardContent>
|
||||
</HoverCard>
|
||||
</div>
|
||||
))}
|
||||
{sources.map((nodeInfo: NodeInfo, index: number) => {
|
||||
if (nodeInfo.path?.endsWith(".pdf")) {
|
||||
return (
|
||||
<PdfDialog
|
||||
key={nodeInfo.id}
|
||||
documentId={nodeInfo.id}
|
||||
url={nodeInfo.url!}
|
||||
path={nodeInfo.path}
|
||||
trigger={<SourceNumberButton index={index} />}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div key={nodeInfo.id}>
|
||||
<HoverCard>
|
||||
<HoverCardTrigger>
|
||||
<SourceNumberButton index={index} />
|
||||
</HoverCardTrigger>
|
||||
<HoverCardContent className="w-[320px]">
|
||||
<NodeInfo nodeInfo={nodeInfo} />
|
||||
</HoverCardContent>
|
||||
</HoverCard>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function NodeInfo({ node }: { node: SourceNode }) {
|
||||
function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
|
||||
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
|
||||
|
||||
if (typeof node.metadata["URL"] === "string") {
|
||||
// this is a node generated by the web loader, it contains an external URL
|
||||
// add a link to view this URL
|
||||
if (nodeInfo.type !== NODE_TYPE.UNKNOWN) {
|
||||
// this is a node generated by the web loader or file loader,
|
||||
// add a link to view its URL and a button to copy the URL to the clipboard
|
||||
return (
|
||||
<a
|
||||
className="space-x-2 flex items-center my-2 hover:text-blue-900"
|
||||
href={node.metadata["URL"]}
|
||||
target="_blank"
|
||||
>
|
||||
<span>{node.metadata["URL"]}</span>
|
||||
<ArrowUpRightSquare className="w-4 h-4" />
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof node.metadata["file_path"] === "string") {
|
||||
// this is a node generated by the file loader, it contains file path
|
||||
// add a button to copy the path to the clipboard
|
||||
const filePath = node.metadata["file_path"];
|
||||
return (
|
||||
<div className="flex items-center px-2 py-1 justify-between my-2">
|
||||
<span>{filePath}</span>
|
||||
<div className="flex items-center my-2">
|
||||
<a className="hover:text-blue-900" href={nodeInfo.url} target="_blank">
|
||||
<span>{nodeInfo.path}</span>
|
||||
</a>
|
||||
<Button
|
||||
onClick={() => copyToClipboard(filePath)}
|
||||
onClick={() => copyToClipboard(nodeInfo.path!)}
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
className="h-12 w-12"
|
||||
className="h-12 w-12 shrink-0"
|
||||
>
|
||||
{isCopied ? (
|
||||
<Check className="h-4 w-4" />
|
||||
@@ -84,7 +141,6 @@ function NodeInfo({ node }: { node: SourceNode }) {
|
||||
}
|
||||
|
||||
// node generated by unknown loader, implement renderer by analyzing logged out metadata
|
||||
console.log("Node metadata", node.metadata);
|
||||
return (
|
||||
<p>
|
||||
Sorry, unknown node type. Please add a new renderer in the NodeInfo
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { ToolData } from "./index";
|
||||
import { WeatherCard, WeatherData } from "./widgets/WeatherCard";
|
||||
|
||||
// TODO: If needed, add displaying more tool outputs here
|
||||
export default function ChatTools({ data }: { data: ToolData }) {
|
||||
if (!data) return null;
|
||||
const { toolCall, toolOutput } = data;
|
||||
|
||||
if (toolOutput.isError) {
|
||||
return (
|
||||
<div className="border-l-2 border-red-400 pl-2">
|
||||
There was an error when calling the tool {toolCall.name} with input:{" "}
|
||||
<br />
|
||||
{JSON.stringify(toolCall.input)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
switch (toolCall.name) {
|
||||
case "get_weather_information":
|
||||
const weatherData = toolOutput.output as unknown as WeatherData;
|
||||
return <WeatherCard data={weatherData} />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import { JSONValue } from "ai";
|
||||
import ChatInput from "./chat-input";
|
||||
import ChatMessages from "./chat-messages";
|
||||
|
||||
@@ -8,6 +9,7 @@ export enum MessageAnnotationType {
|
||||
IMAGE = "image",
|
||||
SOURCES = "sources",
|
||||
EVENTS = "events",
|
||||
TOOLS = "tools",
|
||||
}
|
||||
|
||||
export type ImageData = {
|
||||
@@ -30,7 +32,21 @@ export type EventData = {
|
||||
isCollapsed: boolean;
|
||||
};
|
||||
|
||||
export type AnnotationData = ImageData | SourceData | EventData;
|
||||
export type ToolData = {
|
||||
toolCall: {
|
||||
id: string;
|
||||
name: string;
|
||||
input: {
|
||||
[key: string]: JSONValue;
|
||||
};
|
||||
};
|
||||
toolOutput: {
|
||||
output: JSONValue;
|
||||
isError: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
export type AnnotationData = ImageData | SourceData | EventData | ToolData;
|
||||
|
||||
export type MessageAnnotation = {
|
||||
type: MessageAnnotationType;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import "katex/dist/katex.min.css";
|
||||
import { FC, memo } from "react";
|
||||
import ReactMarkdown, { Options } from "react-markdown";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
|
||||
@@ -12,11 +14,27 @@ const MemoizedReactMarkdown: FC<Options> = memo(
|
||||
prevProps.className === nextProps.className,
|
||||
);
|
||||
|
||||
const preprocessLaTeX = (content: string) => {
|
||||
// Replace block-level LaTeX delimiters \[ \] with $$ $$
|
||||
const blockProcessedContent = content.replace(
|
||||
/\\\[(.*?)\\\]/gs,
|
||||
(_, equation) => `$$${equation}$$`,
|
||||
);
|
||||
// Replace inline LaTeX delimiters \( \) with $ $
|
||||
const inlineProcessedContent = blockProcessedContent.replace(
|
||||
/\\\((.*?)\\\)/gs,
|
||||
(_, equation) => `$${equation}$`,
|
||||
);
|
||||
return inlineProcessedContent;
|
||||
};
|
||||
|
||||
export default function Markdown({ content }: { content: string }) {
|
||||
const processedContent = preprocessLaTeX(content);
|
||||
return (
|
||||
<MemoizedReactMarkdown
|
||||
className="prose dark:prose-invert prose-p:leading-relaxed prose-pre:p-0 break-words custom-markdown"
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[rehypeKatex as any]}
|
||||
components={{
|
||||
p({ children }) {
|
||||
return <p className="mb-2 last:mb-0">{children}</p>;
|
||||
@@ -53,7 +71,7 @@ export default function Markdown({ content }: { content: string }) {
|
||||
},
|
||||
}}
|
||||
>
|
||||
{content}
|
||||
{processedContent}
|
||||
</MemoizedReactMarkdown>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import { PDFViewer, PdfFocusProvider } from "@llamaindex/pdf-viewer";
|
||||
import { Button } from "../../button";
|
||||
import {
|
||||
Drawer,
|
||||
DrawerClose,
|
||||
DrawerContent,
|
||||
DrawerDescription,
|
||||
DrawerHeader,
|
||||
DrawerTitle,
|
||||
DrawerTrigger,
|
||||
} from "../../drawer";
|
||||
|
||||
export interface PdfDialogProps {
|
||||
documentId: string;
|
||||
path: string;
|
||||
url: string;
|
||||
trigger: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function PdfDialog(props: PdfDialogProps) {
|
||||
return (
|
||||
<Drawer direction="left">
|
||||
<DrawerTrigger>{props.trigger}</DrawerTrigger>
|
||||
<DrawerContent className="w-3/5 mt-24 h-full max-h-[96%] ">
|
||||
<DrawerHeader className="flex justify-between">
|
||||
<div className="space-y-2">
|
||||
<DrawerTitle>PDF Content</DrawerTitle>
|
||||
<DrawerDescription>
|
||||
File path:{" "}
|
||||
<a
|
||||
className="hover:text-blue-900"
|
||||
href={props.url}
|
||||
target="_blank"
|
||||
>
|
||||
{props.path}
|
||||
</a>
|
||||
</DrawerDescription>
|
||||
</div>
|
||||
<DrawerClose asChild>
|
||||
<Button variant="outline">Close</Button>
|
||||
</DrawerClose>
|
||||
</DrawerHeader>
|
||||
<div className="m-4">
|
||||
<PdfFocusProvider>
|
||||
<PDFViewer
|
||||
file={{
|
||||
id: props.documentId,
|
||||
url: props.url,
|
||||
}}
|
||||
/>
|
||||
</PdfFocusProvider>
|
||||
</div>
|
||||
</DrawerContent>
|
||||
</Drawer>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
export interface WeatherData {
|
||||
latitude: number;
|
||||
longitude: number;
|
||||
generationtime_ms: number;
|
||||
utc_offset_seconds: number;
|
||||
timezone: string;
|
||||
timezone_abbreviation: string;
|
||||
elevation: number;
|
||||
current_units: {
|
||||
time: string;
|
||||
interval: string;
|
||||
temperature_2m: string;
|
||||
weather_code: string;
|
||||
};
|
||||
current: {
|
||||
time: string;
|
||||
interval: number;
|
||||
temperature_2m: number;
|
||||
weather_code: number;
|
||||
};
|
||||
hourly_units: {
|
||||
time: string;
|
||||
temperature_2m: string;
|
||||
weather_code: string;
|
||||
};
|
||||
hourly: {
|
||||
time: string[];
|
||||
temperature_2m: number[];
|
||||
weather_code: number[];
|
||||
};
|
||||
daily_units: {
|
||||
time: string;
|
||||
weather_code: string;
|
||||
};
|
||||
daily: {
|
||||
time: string[];
|
||||
weather_code: number[];
|
||||
};
|
||||
}
|
||||
|
||||
// Follow WMO Weather interpretation codes (WW)
|
||||
const weatherCodeDisplayMap: Record<
|
||||
string,
|
||||
{
|
||||
icon: JSX.Element;
|
||||
status: string;
|
||||
}
|
||||
> = {
|
||||
"0": {
|
||||
icon: <span>☀️</span>,
|
||||
status: "Clear sky",
|
||||
},
|
||||
"1": {
|
||||
icon: <span>🌤️</span>,
|
||||
status: "Mainly clear",
|
||||
},
|
||||
"2": {
|
||||
icon: <span>☁️</span>,
|
||||
status: "Partly cloudy",
|
||||
},
|
||||
"3": {
|
||||
icon: <span>☁️</span>,
|
||||
status: "Overcast",
|
||||
},
|
||||
"45": {
|
||||
icon: <span>🌫️</span>,
|
||||
status: "Fog",
|
||||
},
|
||||
"48": {
|
||||
icon: <span>🌫️</span>,
|
||||
status: "Depositing rime fog",
|
||||
},
|
||||
"51": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Drizzle",
|
||||
},
|
||||
"53": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Drizzle",
|
||||
},
|
||||
"55": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Drizzle",
|
||||
},
|
||||
"56": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Freezing Drizzle",
|
||||
},
|
||||
"57": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Freezing Drizzle",
|
||||
},
|
||||
"61": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain",
|
||||
},
|
||||
"63": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain",
|
||||
},
|
||||
"65": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain",
|
||||
},
|
||||
"66": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Freezing Rain",
|
||||
},
|
||||
"67": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Freezing Rain",
|
||||
},
|
||||
"71": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow fall",
|
||||
},
|
||||
"73": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow fall",
|
||||
},
|
||||
"75": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow fall",
|
||||
},
|
||||
"77": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow grains",
|
||||
},
|
||||
"80": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain showers",
|
||||
},
|
||||
"81": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain showers",
|
||||
},
|
||||
"82": {
|
||||
icon: <span>🌧️</span>,
|
||||
status: "Rain showers",
|
||||
},
|
||||
"85": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow showers",
|
||||
},
|
||||
"86": {
|
||||
icon: <span>❄️</span>,
|
||||
status: "Snow showers",
|
||||
},
|
||||
"95": {
|
||||
icon: <span>⛈️</span>,
|
||||
status: "Thunderstorm",
|
||||
},
|
||||
"96": {
|
||||
icon: <span>⛈️</span>,
|
||||
status: "Thunderstorm",
|
||||
},
|
||||
"99": {
|
||||
icon: <span>⛈️</span>,
|
||||
status: "Thunderstorm",
|
||||
},
|
||||
};
|
||||
|
||||
const displayDay = (time: string) => {
|
||||
return new Date(time).toLocaleDateString("en-US", {
|
||||
weekday: "long",
|
||||
});
|
||||
};
|
||||
|
||||
export function WeatherCard({ data }: { data: WeatherData }) {
|
||||
const currentDayString = new Date(data.current.time).toLocaleDateString(
|
||||
"en-US",
|
||||
{
|
||||
weekday: "long",
|
||||
month: "long",
|
||||
day: "numeric",
|
||||
},
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="bg-[#61B9F2] rounded-2xl shadow-xl p-5 space-y-4 text-white w-fit">
|
||||
<div className="flex justify-between">
|
||||
<div className="space-y-2">
|
||||
<div className="text-xl">{currentDayString}</div>
|
||||
<div className="text-5xl font-semibold flex gap-4">
|
||||
<span>
|
||||
{data.current.temperature_2m} {data.current_units.temperature_2m}
|
||||
</span>
|
||||
{weatherCodeDisplayMap[data.current.weather_code].icon}
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-xl">
|
||||
{weatherCodeDisplayMap[data.current.weather_code].status}
|
||||
</span>
|
||||
</div>
|
||||
<div className="gap-2 grid grid-cols-6">
|
||||
{data.daily.time.map((time, index) => {
|
||||
if (index === 0) return null; // skip the current day
|
||||
return (
|
||||
<div key={time} className="flex flex-col items-center gap-4">
|
||||
<span>{displayDay(time)}</span>
|
||||
<div className="text-4xl">
|
||||
{weatherCodeDisplayMap[data.daily.weather_code[index]].icon}
|
||||
</div>
|
||||
<span className="text-sm">
|
||||
{weatherCodeDisplayMap[data.daily.weather_code[index]].status}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { Drawer as DrawerPrimitive } from "vaul";
|
||||
|
||||
import { cn } from "./lib/utils";
|
||||
|
||||
const Drawer = ({
|
||||
shouldScaleBackground = true,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Root>) => (
|
||||
<DrawerPrimitive.Root
|
||||
shouldScaleBackground={shouldScaleBackground}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
Drawer.displayName = "Drawer";
|
||||
|
||||
const DrawerTrigger = DrawerPrimitive.Trigger;
|
||||
|
||||
const DrawerPortal = DrawerPrimitive.Portal;
|
||||
|
||||
const DrawerClose = DrawerPrimitive.Close;
|
||||
|
||||
const DrawerOverlay = React.forwardRef<
|
||||
React.ElementRef<typeof DrawerPrimitive.Overlay>,
|
||||
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Overlay>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<DrawerPrimitive.Overlay
|
||||
ref={ref}
|
||||
className={cn("fixed inset-0 z-50 bg-black/80", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DrawerOverlay.displayName = DrawerPrimitive.Overlay.displayName;
|
||||
|
||||
const DrawerContent = React.forwardRef<
|
||||
React.ElementRef<typeof DrawerPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Content>
|
||||
>(({ className, children, ...props }, ref) => (
|
||||
<DrawerPortal>
|
||||
<DrawerOverlay />
|
||||
<DrawerPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"fixed inset-x-0 bottom-0 z-50 mt-24 flex h-auto flex-col rounded-t-[10px] border bg-background",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<div className="mx-auto mt-4 h-2 w-[100px] rounded-full bg-muted" />
|
||||
{children}
|
||||
</DrawerPrimitive.Content>
|
||||
</DrawerPortal>
|
||||
));
|
||||
DrawerContent.displayName = "DrawerContent";
|
||||
|
||||
const DrawerHeader = ({
|
||||
className,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLDivElement>) => (
|
||||
<div
|
||||
className={cn("grid gap-1.5 p-4 text-center sm:text-left", className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
DrawerHeader.displayName = "DrawerHeader";
|
||||
|
||||
const DrawerFooter = ({
|
||||
className,
|
||||
...props
|
||||
}: React.HTMLAttributes<HTMLDivElement>) => (
|
||||
<div
|
||||
className={cn("mt-auto flex flex-col gap-2 p-4", className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
DrawerFooter.displayName = "DrawerFooter";
|
||||
|
||||
const DrawerTitle = React.forwardRef<
|
||||
React.ElementRef<typeof DrawerPrimitive.Title>,
|
||||
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Title>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<DrawerPrimitive.Title
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"text-lg font-semibold leading-none tracking-tight",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DrawerTitle.displayName = DrawerPrimitive.Title.displayName;
|
||||
|
||||
const DrawerDescription = React.forwardRef<
|
||||
React.ElementRef<typeof DrawerPrimitive.Description>,
|
||||
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Description>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<DrawerPrimitive.Description
|
||||
ref={ref}
|
||||
className={cn("text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
DrawerDescription.displayName = DrawerPrimitive.Description.displayName;
|
||||
|
||||
export {
|
||||
Drawer,
|
||||
DrawerClose,
|
||||
DrawerContent,
|
||||
DrawerDescription,
|
||||
DrawerFooter,
|
||||
DrawerHeader,
|
||||
DrawerOverlay,
|
||||
DrawerPortal,
|
||||
DrawerTitle,
|
||||
DrawerTrigger,
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user