Compare commits

...

17 Commits

Author SHA1 Message Date
leehuwuj 05748bdf10 refactor code 2024-05-27 14:53:01 +07:00
leehuwuj d60b3c5a96 refactor code and add changeset 2024-05-27 13:09:59 +07:00
leehuwuj c3e9ed3df4 feat: add support for FastAPI in code interpreter tool 2024-05-27 12:37:49 +07:00
github-actions[bot] 1fde1dc585 Release 0.1.8 (#97)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-23 22:15:51 +07:00
Thuc Pham cd50a33d43 feat: implement interpreter tool (#94)
---------
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-23 21:49:18 +07:00
github-actions[bot] ed114856d9 Release 0.1.7 (#93)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-22 18:30:49 +07:00
Marcus Schiesser 69c2e16c82 fix: streaming for express 2024-05-22 13:04:35 +02:00
Marcus Schiesser f5da6623cf fix: update llamaindex, use 127.0.0.1 for ollama as default 2024-05-22 12:42:34 +02:00
Marcus Schiesser 0950cb90f2 fix: global-agent types 2024-05-22 11:50:34 +02:00
Mohammad Amir bb53425b4b Proxy support added via global agent (#76) 2024-05-22 16:35:03 +07:00
Huu Le (Lee) bbd5b8ddd6 fix: Reuse PG vector store to avoid recreating sqlalchemy engine (#95) 2024-05-22 16:12:44 +07:00
Thuc Pham 260d37a3f1 feat(ts): add system prompt for chat engine (#92) 2024-05-20 16:12:19 +07:00
Huu Le (Lee) 7873bfb030 chore: Add Ollama API base URL environment variable (#91) 2024-05-17 17:01:06 +07:00
github-actions[bot] 0c7c41ee3b Release 0.1.6 (#90)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-16 19:08:40 +07:00
Thuc Pham 56537a1473 feat: host local files and add viewer for PDFs (#85) 2024-05-16 18:06:26 +07:00
github-actions[bot] d8dfc29edd Release 0.1.5 (#89)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-16 16:12:40 +07:00
Thuc Pham 84db798353 feat: support display latex in chat markdown (#88) 2024-05-16 15:25:53 +07:00
40 changed files with 1267 additions and 278 deletions
+5
View File
@@ -0,0 +1,5 @@
---
"create-llama": patch
---
Add support E2B code interpreter tool for FastAPI
+1 -1
View File
@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
+28
View File
@@ -1,5 +1,33 @@
# create-llama
## 0.1.8
### Patch Changes
- cd50a33: Add interpreter tool for TS using e2b.dev
## 0.1.7
### Patch Changes
- 260d37a: Add system prompt env variable for TS
- bbd5b8d: Fix postgres connection leaking issue
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
- 69c2e16: Fix streaming for Express
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
## 0.1.6
### Patch Changes
- 56537a1: Display PDF files in source nodes
## 0.1.5
### Patch Changes
- 84db798: feat: support display latex in chat markdown
## 0.1.4
### Patch Changes
+86 -22
View File
@@ -1,5 +1,6 @@
import fs from "fs/promises";
import path from "path";
import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools";
import {
ModelConfig,
TemplateDataSource,
@@ -7,7 +8,7 @@ import {
TemplateVectorDB,
} from "./types";
type EnvVar = {
export type EnvVar = {
name?: string;
description?: string;
value?: string;
@@ -219,35 +220,52 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "ollama"
? [
{
name: "OLLAMA_BASE_URL",
description:
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
},
]
: []),
];
};
const getFrameworkEnvs = (
framework?: TemplateFramework,
framework: TemplateFramework,
port?: number,
): EnvVar[] => {
if (framework !== "fastapi") {
return [];
}
return [
const sPort = port?.toString() || "8000";
const result: EnvVar[] = [
{
name: "APP_HOST",
description: "The address to start the backend app.",
value: "0.0.0.0",
},
{
name: "APP_PORT",
description: "The port to start the backend app.",
value: port?.toString() || "8000",
},
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
{
name: "SYSTEM_PROMPT",
description: `Custom system prompt.
Example:
SYSTEM_PROMPT="You are a helpful assistant who helps users with their questions."`,
name: "FILESERVER_URL_PREFIX",
description:
"FILESERVER_URL_PREFIX is the URL prefix of the server storing the images generated by the interpreter.",
value:
framework === "nextjs"
? // FIXME: if we are using nextjs, port should be 3000
"http://localhost:3000/api/files"
: `http://localhost:${sPort}/api/files`,
},
];
if (framework === "fastapi") {
result.push(
...[
{
name: "APP_HOST",
description: "The address to start the backend app.",
value: "0.0.0.0",
},
{
name: "APP_PORT",
description: "The port to start the backend app.",
value: sPort,
},
],
);
}
return result;
};
const getEngineEnvs = (): EnvVar[] => {
@@ -261,15 +279,59 @@ const getEngineEnvs = (): EnvVar[] => {
];
};
const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
if (!tools?.length) return [];
const toolEnvs: EnvVar[] = [];
tools.forEach((tool) => {
if (tool.envVars?.length) {
toolEnvs.push(
// Don't include the system prompt env var here
// It should be handled separately by merging with the default system prompt
...tool.envVars.filter(
(env) => env.name !== TOOL_SYSTEM_PROMPT_ENV_VAR,
),
);
}
});
return toolEnvs;
};
const getSystemPromptEnv = (tools?: Tool[]): EnvVar => {
const defaultSystemPrompt =
"You are a helpful assistant who helps users with their questions.";
// build tool system prompt by merging all tool system prompts
let toolSystemPrompt = "";
tools?.forEach((tool) => {
const toolSystemPromptEnv = tool.envVars?.find(
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
);
if (toolSystemPromptEnv) {
toolSystemPrompt += toolSystemPromptEnv.value + "\n";
}
});
const systemPrompt = toolSystemPrompt
? `\"${toolSystemPrompt}\"`
: defaultSystemPrompt;
return {
name: "SYSTEM_PROMPT",
description: "The system prompt for the AI model.",
value: systemPrompt,
};
};
export const createBackendEnvFile = async (
root: string,
opts: {
llamaCloudKey?: string;
vectorDb?: TemplateVectorDB;
modelConfig: ModelConfig;
framework?: TemplateFramework;
framework: TemplateFramework;
dataSources?: TemplateDataSource[];
port?: number;
tools?: Tool[];
},
) => {
// Init env values
@@ -287,6 +349,8 @@ export const createBackendEnvFile = async (
// Add vector database environment variables
...getVectorDBEnvs(opts.vectorDb, opts.framework),
...getFrameworkEnvs(opts.framework, opts.port),
...getToolEnvs(opts.tools),
getSystemPromptEnv(opts.tools),
];
// Render and write env file
const content = renderEnvVar(envVars);
+6
View File
@@ -148,6 +148,7 @@ export const installTemplate = async (
framework: props.framework,
dataSources: props.dataSources,
port: props.externalPort,
tools: props.tools,
});
if (props.dataSources.length > 0) {
@@ -170,6 +171,11 @@ export const installTemplate = async (
);
}
}
// Create tool-output directory
if (props.tools && props.tools.length > 0) {
await fsExtra.mkdir(path.join(props.root, "tool-output"));
}
} else {
// this is a frontend for a full-stack app, create .env file with model information
await createFrontendEnvFile(props.root, {
+8
View File
@@ -0,0 +1,8 @@
/* Function to conditionally load the global-agent/bootstrap module */
export async function initializeGlobalAgent() {
if (process.env.GLOBAL_AGENT_HTTP_PROXY) {
/* Dynamically import global-agent/bootstrap */
await import("global-agent/bootstrap");
console.log("Proxy enabled via global-agent.");
}
}
+55
View File
@@ -2,9 +2,12 @@ import fs from "fs/promises";
import path from "path";
import { red } from "picocolors";
import yaml from "yaml";
import { EnvVar } from "./env-variables";
import { makeDir } from "./make-dir";
import { TemplateFramework } from "./types";
export const TOOL_SYSTEM_PROMPT_ENV_VAR = "TOOL_SYSTEM_PROMPT";
export enum ToolType {
LLAMAHUB = "llamahub",
LOCAL = "local",
@@ -17,6 +20,7 @@ export type Tool = {
dependencies?: ToolDependencies[];
supportedFrameworks?: Array<TemplateFramework>;
type: ToolType;
envVars?: EnvVar[];
};
export type ToolDependencies = {
@@ -42,6 +46,13 @@ export const supportedTools: Tool[] = [
],
supportedFrameworks: ["fastapi"],
type: ToolType.LLAMAHUB,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for google search tool.",
value: `You are a Google search agent. You help users to get information from Google search.`,
},
],
},
{
display: "Wikipedia",
@@ -54,6 +65,13 @@ export const supportedTools: Tool[] = [
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LLAMAHUB,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for wiki tool.",
value: `You are a Wikipedia agent. You help users to get information from Wikipedia.`,
},
],
},
{
display: "Weather",
@@ -61,6 +79,43 @@ export const supportedTools: Tool[] = [
dependencies: [],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for weather tool.",
value: `You are a weather forecast agent. You help users to get the weather forecast for a given location.`,
},
],
},
{
display: "Code Interpreter",
name: "interpreter",
dependencies: [
{
name: "e2b_code_interpreter",
version: "0.0.7",
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: "E2B_API_KEY",
description:
"E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
},
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for code interpreter tool.",
value: `You are a Python interpreter.
- You are given tasks to complete and you run python code to solve them.
- The python code runs in a Jupyter notebook. Every time you call \`interpreter\` tool, the python code is executed in a separate cell. It's okay to make multiple calls to \`interpreter\`.
- Display visualizations using matplotlib or any other visualization library directly in the notebook. Shouldn't save the visualizations to a file, just return the base64 encoded data.
- You can install any pip package (if it exists) if you need to but the usual packages for data analysis are already preinstalled.
- You can run any python code you want in a secure environment.
- Use absolute url from result to display images or any other media.`,
},
],
},
];
+4
View File
@@ -12,12 +12,16 @@ import { createApp } from "./create-app";
import { getDataSources } from "./helpers/datasources";
import { getPkgManager } from "./helpers/get-pkg-manager";
import { isFolderEmpty } from "./helpers/is-folder-empty";
import { initializeGlobalAgent } from "./helpers/proxy";
import { runApp } from "./helpers/run-app";
import { getTools } from "./helpers/tools";
import { validateNpmName } from "./helpers/validate-pkg";
import packageJson from "./package.json";
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
// Run the initialization function
initializeGlobalAgent();
let projectPath: string = "";
const handleSigTerm = () => process.exit(0);
+2 -1
View File
@@ -1,6 +1,6 @@
{
"name": "create-llama",
"version": "0.1.4",
"version": "0.1.8",
"description": "Create LlamaIndex-powered apps with one command",
"keywords": [
"rag",
@@ -52,6 +52,7 @@
"cross-spawn": "7.0.3",
"fast-glob": "3.3.1",
"fs-extra": "11.2.0",
"global-agent": "^3.0.0",
"got": "10.7.0",
"ollama": "^0.5.0",
"ora": "^8.0.1",
+265 -145
View File
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,134 @@
import os
import logging
import base64
import uuid
from pydantic import BaseModel
from typing import List, Tuple, Dict
from llama_index.core.tools import FunctionTool
from e2b_code_interpreter import CodeInterpreter
from e2b_code_interpreter.models import Logs
logger = logging.getLogger(__name__)
class InterpreterExtraResult(BaseModel):
type: str
filename: str
url: str
class E2BToolOutput(BaseModel):
is_error: bool
logs: Logs
results: List[InterpreterExtraResult] = []
class E2BCodeInterpreter:
output_dir = "tool-output"
def __init__(self, api_key: str, filesever_url_prefix: str):
self.api_key = api_key
self.filesever_url_prefix = filesever_url_prefix
def get_output_path(self, filename: str) -> str:
# if output directory doesn't exist, create it
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir, exist_ok=True)
return os.path.join(self.output_dir, filename)
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
buffer = base64.b64decode(base64_data)
output_path = self.get_output_path(filename)
try:
with open(output_path, "wb") as file:
file.write(buffer)
except IOError as e:
logger.error(f"Failed to write to file {output_path}: {str(e)}")
raise e
logger.info(f"Saved file to {output_path}")
return {
"outputPath": output_path,
"filename": filename,
}
def get_file_url(self, filename: str) -> str:
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
def parse_result(self, result) -> List[InterpreterExtraResult]:
"""
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
We save each result to disk and return saved file metadata (extension, filename, url)
"""
if not result:
return []
output = []
try:
formats = result.formats()
base64_data_arr = [result[format] for format in formats]
for ext, base64_data in zip(formats, base64_data_arr):
if ext and base64_data:
result = self.save_to_disk(base64_data, ext)
filename = result["filename"]
output.append(
InterpreterExtraResult(
type=ext, filename=filename, url=self.get_file_url(filename)
)
)
except Exception as error:
logger.error("Error when saving data to disk", error)
return output
def interpret(self, code: str) -> E2BToolOutput:
with CodeInterpreter(api_key=self.api_key) as interpreter:
logger.info(
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
)
exec = interpreter.notebook.exec_cell(code)
if exec.error:
output = E2BToolOutput(is_error=True, logs=[exec.error])
else:
if len(exec.results) == 0:
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
else:
results = self.parse_result(exec.results[0])
output = E2BToolOutput(
is_error=False, logs=exec.logs, results=results
)
return output
def code_interpret(code: str) -> Dict:
"""
Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.
"""
api_key = os.getenv("E2B_API_KEY")
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not api_key:
raise ValueError(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
)
if not filesever_url_prefix:
raise ValueError(
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
)
interpreter = E2BCodeInterpreter(
api_key=api_key, filesever_url_prefix=filesever_url_prefix
)
output = interpreter.interpret(code)
return output.dict()
# Specify as functions tools to be loaded by the ToolFactory
tools = [FunctionTool.from_defaults(code_interpret)]
@@ -1,10 +1,9 @@
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import fs from "node:fs/promises";
import path from "node:path";
import { getDataSource } from "./index";
import { STORAGE_CACHE_DIR } from "./shared";
import { createLocalTools } from "./tools";
import { createTools } from "./tools";
export async function createChatEngine() {
const tools: BaseToolWithCall[] = [];
@@ -24,22 +23,20 @@ export async function createChatEngine() {
);
}
const configFile = path.join("config", "tools.json");
let toolConfig: any;
try {
// add tools from config file if it exists
const config = JSON.parse(
await fs.readFile(path.join("config", "tools.json"), "utf8"),
);
// add local tools from the 'tools' folder (if configured)
const localTools = createLocalTools(config.local);
tools.push(...localTools);
// add tools from LlamaIndexTS (if configured)
const llamaTools = await ToolsFactory.createTools(config.llamahub);
tools.push(...llamaTools);
} catch {}
toolConfig = JSON.parse(await fs.readFile(configFile, "utf8"));
} catch (e) {
console.info(`Could not read ${configFile} file. Using no tools.`);
}
if (toolConfig) {
tools.push(...(await createTools(toolConfig)));
}
return new OpenAIAgent({
tools,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
@@ -1,15 +1,31 @@
import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
import { WeatherTool, WeatherToolParams } from "./weather";
type ToolCreator = (config: unknown) => BaseToolWithCall;
export async function createTools(toolConfig: {
local: Record<string, unknown>;
llamahub: any;
}): Promise<BaseToolWithCall[]> {
// add local tools from the 'tools' folder (if configured)
const tools = createLocalTools(toolConfig.local);
// add tools from LlamaIndexTS (if configured)
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
return tools;
}
const toolFactory: Record<string, ToolCreator> = {
weather: (config: unknown) => {
return new WeatherTool(config as WeatherToolParams);
},
interpreter: (config: unknown) => {
return new InterpreterTool(config as InterpreterToolParams);
},
};
export function createLocalTools(
function createLocalTools(
localConfig: Record<string, unknown>,
): BaseToolWithCall[] {
const tools: BaseToolWithCall[] = [];
@@ -0,0 +1,174 @@
import { CodeInterpreter, Logs, Result } from "@e2b/code-interpreter";
import type { JSONSchemaType } from "ajv";
import fs from "fs";
import { BaseTool, ToolMetadata } from "llamaindex";
import crypto from "node:crypto";
import path from "node:path";
export type InterpreterParameter = {
code: string;
};
export type InterpreterToolParams = {
metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
apiKey?: string;
fileServerURLPrefix?: string;
};
export type InterpreterToolOuput = {
isError: boolean;
logs: Logs;
extraResult: InterpreterExtraResult[];
};
type InterpreterExtraType =
| "html"
| "markdown"
| "svg"
| "png"
| "jpeg"
| "pdf"
| "latex"
| "json"
| "javascript";
export type InterpreterExtraResult = {
type: InterpreterExtraType;
filename: string;
url: string;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
name: "interpreter",
description:
"Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.",
parameters: {
type: "object",
properties: {
code: {
type: "string",
description: "The python code to execute in a single cell.",
},
},
required: ["code"],
},
};
export class InterpreterTool implements BaseTool<InterpreterParameter> {
private readonly outputDir = "tool-output";
private apiKey?: string;
private fileServerURLPrefix?: string;
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
codeInterpreter?: CodeInterpreter;
constructor(params?: InterpreterToolParams) {
this.metadata = params?.metadata || DEFAULT_META_DATA;
this.apiKey = params?.apiKey || process.env.E2B_API_KEY;
this.fileServerURLPrefix =
params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX;
if (!this.apiKey) {
throw new Error(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key",
);
}
if (!this.fileServerURLPrefix) {
throw new Error(
"FILESERVER_URL_PREFIX is required to display file output from sandbox",
);
}
}
public async initInterpreter() {
if (!this.codeInterpreter) {
this.codeInterpreter = await CodeInterpreter.create({
apiKey: this.apiKey,
});
}
return this.codeInterpreter;
}
public async codeInterpret(code: string): Promise<InterpreterToolOuput> {
console.log(
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
);
const interpreter = await this.initInterpreter();
const exec = await interpreter.notebook.execCell(code);
if (exec.error) console.error("[Code Interpreter error]", exec.error);
const extraResult = await this.getExtraResult(exec.results[0]);
const result: InterpreterToolOuput = {
isError: !!exec.error,
logs: exec.logs,
extraResult,
};
return result;
}
async call(input: InterpreterParameter): Promise<InterpreterToolOuput> {
const result = await this.codeInterpret(input.code);
await this.codeInterpreter?.close();
return result;
}
private async getExtraResult(
res?: Result,
): Promise<InterpreterExtraResult[]> {
if (!res) return [];
const output: InterpreterExtraResult[] = [];
try {
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
const base64DataArr = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
// save base64 data to file and return the url
for (let i = 0; i < formats.length; i++) {
const ext = formats[i];
const base64Data = base64DataArr[i];
if (ext && base64Data) {
const { filename } = this.saveToDisk(base64Data, ext);
output.push({
type: ext as InterpreterExtraType,
filename,
url: this.getFileUrl(filename),
});
}
}
} catch (error) {
console.error("Error when saving data to disk", error);
}
return output;
}
// Consider saving to cloud storage instead but it may cost more for you
// See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file
private saveToDisk(
base64Data: string,
ext: string,
): {
outputPath: string;
filename: string;
} {
const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename
const buffer = Buffer.from(base64Data, "base64");
const outputPath = this.getOutputPath(filename);
fs.writeFileSync(outputPath, buffer);
console.log(`Saved file to ${outputPath}`);
return {
outputPath,
filename,
};
}
private getOutputPath(filename: string): string {
// if outputDir doesn't exist, create it
if (!fs.existsSync(this.outputDir)) {
fs.mkdirSync(this.outputDir, { recursive: true });
}
return path.join(this.outputDir, filename);
}
private getFileUrl(filename: string): string {
return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`;
}
}
@@ -16,5 +16,6 @@ export async function createChatEngine() {
return new ContextChatEngine({
chatModel: Settings.llm,
retriever,
systemPrompt: process.env.SYSTEM_PROMPT,
});
}
@@ -1,12 +1,22 @@
import logging
import os
import logging
from datetime import timedelta
from cachetools import cached, TTLCache
from llama_index.core.storage import StorageContext
from llama_index.core.indices import load_index_from_storage
logger = logging.getLogger("uvicorn")
@cached(
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
key=lambda *args, **kwargs: "global_storage_context",
)
def get_storage_context(persist_dir: str) -> StorageContext:
return StorageContext.from_defaults(persist_dir=persist_dir)
def get_index():
storage_dir = os.getenv("STORAGE_DIR", "storage")
# check if storage already exists
@@ -14,7 +24,7 @@ def get_index():
return None
# load the existing index
logger.info(f"Loading index from {storage_dir}...")
storage_context = StorageContext.from_defaults(persist_dir=storage_dir)
storage_context = get_storage_context(storage_dir)
index = load_index_from_storage(storage_context)
logger.info(f"Finished loading index from {storage_dir}")
return index
@@ -5,26 +5,33 @@ from urllib.parse import urlparse
PGVECTOR_SCHEMA = "public"
PGVECTOR_TABLE = "llamaindex_embedding"
vector_store: PGVectorStore = None
def get_vector_store():
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
global vector_store
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
if vector_store is None:
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
return PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
)
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
# Update the configured scheme with the psycopg2 and asyncpg schemes
original_scheme = urlparse(original_conn_string).scheme + "://"
conn_string = original_conn_string.replace(
original_scheme, "postgresql+psycopg2://"
)
async_conn_string = original_conn_string.replace(
original_scheme, "postgresql+asyncpg://"
)
vector_store = PGVectorStore(
connection_string=conn_string,
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
)
return vector_store
+3 -1
View File
@@ -1,3 +1,5 @@
# local env files
.env
node_modules/
node_modules/
tool-output/
@@ -31,6 +31,8 @@ if (isDevelopment) {
console.warn("Production CORS origin not set, defaulting to no CORS.");
}
app.use("/api/files/data", express.static("data"));
app.use("/api/files/tool-output", express.static("tool-output"));
app.use(express.text());
app.get("/", (req: Request, res: Response) => {
@@ -14,9 +14,10 @@
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "0.3.9",
"llamaindex": "0.3.13",
"pdf2json": "3.0.5",
"ajv": "^8.12.0"
"ajv": "^8.12.0",
"@e2b/code-interpreter": "^0.0.5"
},
"devDependencies": {
"@types/cors": "^2.8.16",
@@ -64,9 +64,8 @@ export const chat = async (req: Request, res: Response) => {
image_url: data?.imageUrl,
},
});
const processedStream = stream.pipeThrough(vercelStreamData.stream);
return streamToResponse(processedStream, res);
return streamToResponse(stream, res, {}, vercelStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
@@ -56,11 +56,17 @@ function initOpenAI() {
}
function initOllama() {
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
@@ -93,7 +93,13 @@ async def chat(
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
response = await chat_engine.astream_chat(last_message_content, messages)
try:
response = await chat_engine.astream_chat(last_message_content, messages)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
)
async def content_generator():
# Yield the text response
@@ -1,5 +1,6 @@
import json
import asyncio
import logging
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
@@ -7,6 +8,9 @@ from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class CallbackEvent(BaseModel):
event_type: CBEventType
payload: Optional[Dict[str, Any]] = None
@@ -72,15 +76,19 @@ class CallbackEvent(BaseModel):
}
def to_response(self):
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
try:
match self.event_type:
case "retrieve":
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
except Exception as e:
logger.error(f"Error in converting event to response: {e}")
return None
class EventCallbackHandler(BaseCallbackHandler):
@@ -23,8 +23,12 @@ def init_ollama():
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
Settings.llm = Ollama(model=os.getenv("MODEL"))
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
Settings.embed_model = OllamaEmbedding(
base_url=base_url,
model_name=os.getenv("EMBEDDING_MODEL"),
)
Settings.llm = Ollama(base_url=base_url, model=os.getenv("MODEL"))
def init_openai():
+1 -1
View File
@@ -1,3 +1,3 @@
__pycache__
storage
.env
.env
+11 -1
View File
@@ -11,6 +11,7 @@ from fastapi.responses import RedirectResponse
from app.api.routers.chat import chat_router
from app.settings import init_settings
from app.observability import init_observability
from fastapi.staticfiles import StaticFiles
app = FastAPI()
@@ -20,7 +21,6 @@ init_observability()
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
if environment == "dev":
logger = logging.getLogger("uvicorn")
logger.warning("Running in development mode - allowing CORS for all origins")
@@ -38,6 +38,16 @@ if environment == "dev":
return RedirectResponse(url="/docs")
def mount_static_files(directory, path):
if os.path.exists(directory):
app.mount(path, StaticFiles(directory=directory), name=f"{directory}-static")
# Mount the data files to serve the file viewer
mount_static_files("data", "/api/files/data")
# Mount the output files from tools
mount_static_files("tool-output", "/api/files/tool-output")
app.include_router(chat_router, prefix="/api/chat")
@@ -16,6 +16,7 @@ python-dotenv = "^1.0.0"
aiostream = "^0.5.2"
llama-index = "0.10.28"
llama-index-core = "0.10.28"
cachetools = "^5.3.3"
[build-system]
requires = ["poetry-core"]
@@ -56,11 +56,16 @@ function initOpenAI() {
}
function initOllama() {
const config = {
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
};
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
config,
});
Settings.embedModel = new OllamaEmbedding({
model: process.env.EMBEDDING_MODEL ?? "",
config,
});
}
@@ -0,0 +1,45 @@
import { readFile } from "fs/promises";
import { NextRequest, NextResponse } from "next/server";
import path from "path";
/**
* This API is to get file data from allowed folders
* It receives path slug and response file data like serve static file
*/
export async function GET(
_request: NextRequest,
{ params }: { params: { slug: string[] } },
) {
const slug = params.slug;
if (!slug) {
return NextResponse.json({ detail: "Missing file slug" }, { status: 400 });
}
if (slug.includes("..") || path.isAbsolute(path.join(...slug))) {
return NextResponse.json({ detail: "Invalid file path" }, { status: 400 });
}
const [folder, ...pathTofile] = params.slug; // data, file.pdf
const allowedFolders = ["data", "tool-output"];
if (!allowedFolders.includes(folder)) {
return NextResponse.json({ detail: "No permission" }, { status: 400 });
}
try {
const filePath = path.join(process.cwd(), folder, path.join(...pathTofile));
const blob = await readFile(filePath);
return new NextResponse(blob, {
status: 200,
statusText: "OK",
headers: {
"Content-Length": blob.byteLength.toString(),
},
});
} catch (error) {
console.error(error);
return NextResponse.json({ detail: "File not found" }, { status: 404 });
}
}
@@ -38,7 +38,9 @@ export function ChatEvents({
<CollapsibleContent asChild>
<div className="mt-4 text-sm space-y-2">
{data.map((eventItem, index) => (
<div key={index}>{eventItem.title}</div>
<div className="whitespace-break-spaces" key={index}>
{eventItem.title}
</div>
))}
</div>
</CollapsibleContent>
@@ -1,20 +1,80 @@
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
import { Check, Copy } from "lucide-react";
import { useMemo } from "react";
import { Button } from "../button";
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
import { getStaticFileDataUrl } from "../lib/url";
import { SourceData, SourceNode } from "./index";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
import PdfDialog from "./widgets/PdfDialog";
const SCORE_THRESHOLD = 0.5;
const DATA_SOURCE_FOLDER = "data";
const SCORE_THRESHOLD = 0.3;
function SourceNumberButton({ index }: { index: number }) {
return (
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
{index + 1}
</div>
);
}
enum NODE_TYPE {
URL,
FILE,
UNKNOWN,
}
type NodeInfo = {
id: string;
type: NODE_TYPE;
path?: string;
url?: string;
};
function getNodeInfo(node: SourceNode): NodeInfo {
if (typeof node.metadata["URL"] === "string") {
const url = node.metadata["URL"];
return {
id: node.id,
type: NODE_TYPE.URL,
path: url,
url,
};
}
if (typeof node.metadata["file_path"] === "string") {
const fileName = node.metadata["file_name"] as string;
const filePath = `${DATA_SOURCE_FOLDER}/${fileName}`;
return {
id: node.id,
type: NODE_TYPE.FILE,
path: node.metadata["file_path"],
url: getStaticFileDataUrl(filePath),
};
}
return {
id: node.id,
type: NODE_TYPE.UNKNOWN,
};
}
export function ChatSources({ data }: { data: SourceData }) {
const sources = useMemo(() => {
return (
data.nodes
?.filter((node) => Object.keys(node.metadata).length > 0)
?.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) || []
);
const sources: NodeInfo[] = useMemo(() => {
// aggregate nodes by url or file_path (get the highest one by score)
const nodesByPath: { [path: string]: NodeInfo } = {};
data.nodes
.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1))
.forEach((node) => {
const nodeInfo = getNodeInfo(node);
const key = nodeInfo.path ?? nodeInfo.id; // use id as key for UNKNOWN type
if (!nodesByPath[key]) {
nodesByPath[key] = nodeInfo;
}
});
return Object.values(nodesByPath);
}, [data.nodes]);
if (sources.length === 0) return null;
@@ -23,55 +83,52 @@ export function ChatSources({ data }: { data: SourceData }) {
<div className="space-x-2 text-sm">
<span className="font-semibold">Sources:</span>
<div className="inline-flex gap-1 items-center">
{sources.map((node: SourceNode, index: number) => (
<div key={node.id}>
<HoverCard>
<HoverCardTrigger>
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
{index + 1}
</div>
</HoverCardTrigger>
<HoverCardContent>
<NodeInfo node={node} />
</HoverCardContent>
</HoverCard>
</div>
))}
{sources.map((nodeInfo: NodeInfo, index: number) => {
if (nodeInfo.path?.endsWith(".pdf")) {
return (
<PdfDialog
key={nodeInfo.id}
documentId={nodeInfo.id}
url={nodeInfo.url!}
path={nodeInfo.path}
trigger={<SourceNumberButton index={index} />}
/>
);
}
return (
<div key={nodeInfo.id}>
<HoverCard>
<HoverCardTrigger>
<SourceNumberButton index={index} />
</HoverCardTrigger>
<HoverCardContent className="w-[320px]">
<NodeInfo nodeInfo={nodeInfo} />
</HoverCardContent>
</HoverCard>
</div>
);
})}
</div>
</div>
);
}
function NodeInfo({ node }: { node: SourceNode }) {
function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
if (typeof node.metadata["URL"] === "string") {
// this is a node generated by the web loader, it contains an external URL
// add a link to view this URL
if (nodeInfo.type !== NODE_TYPE.UNKNOWN) {
// this is a node generated by the web loader or file loader,
// add a link to view its URL and a button to copy the URL to the clipboard
return (
<a
className="space-x-2 flex items-center my-2 hover:text-blue-900"
href={node.metadata["URL"]}
target="_blank"
>
<span>{node.metadata["URL"]}</span>
<ArrowUpRightSquare className="w-4 h-4" />
</a>
);
}
if (typeof node.metadata["file_path"] === "string") {
// this is a node generated by the file loader, it contains file path
// add a button to copy the path to the clipboard
const filePath = node.metadata["file_path"];
return (
<div className="flex items-center px-2 py-1 justify-between my-2">
<span>{filePath}</span>
<div className="flex items-center my-2">
<a className="hover:text-blue-900" href={nodeInfo.url} target="_blank">
<span>{nodeInfo.path}</span>
</a>
<Button
onClick={() => copyToClipboard(filePath)}
onClick={() => copyToClipboard(nodeInfo.path!)}
size="icon"
variant="ghost"
className="h-12 w-12"
className="h-12 w-12 shrink-0"
>
{isCopied ? (
<Check className="h-4 w-4" />
@@ -84,7 +141,6 @@ function NodeInfo({ node }: { node: SourceNode }) {
}
// node generated by unknown loader, implement renderer by analyzing logged out metadata
console.log("Node metadata", node.metadata);
return (
<p>
Sorry, unknown node type. Please add a new renderer in the NodeInfo
@@ -1,5 +1,7 @@
import "katex/dist/katex.min.css";
import { FC, memo } from "react";
import ReactMarkdown, { Options } from "react-markdown";
import rehypeKatex from "rehype-katex";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
@@ -12,11 +14,27 @@ const MemoizedReactMarkdown: FC<Options> = memo(
prevProps.className === nextProps.className,
);
const preprocessLaTeX = (content: string) => {
// Replace block-level LaTeX delimiters \[ \] with $$ $$
const blockProcessedContent = content.replace(
/\\\[(.*?)\\\]/gs,
(_, equation) => `$$${equation}$$`,
);
// Replace inline LaTeX delimiters \( \) with $ $
const inlineProcessedContent = blockProcessedContent.replace(
/\\\((.*?)\\\)/gs,
(_, equation) => `$${equation}$`,
);
return inlineProcessedContent;
};
export default function Markdown({ content }: { content: string }) {
const processedContent = preprocessLaTeX(content);
return (
<MemoizedReactMarkdown
className="prose dark:prose-invert prose-p:leading-relaxed prose-pre:p-0 break-words custom-markdown"
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeKatex as any]}
components={{
p({ children }) {
return <p className="mb-2 last:mb-0">{children}</p>;
@@ -53,7 +71,7 @@ export default function Markdown({ content }: { content: string }) {
},
}}
>
{content}
{processedContent}
</MemoizedReactMarkdown>
);
}
@@ -0,0 +1,56 @@
import { PDFViewer, PdfFocusProvider } from "@llamaindex/pdf-viewer";
import { Button } from "../../button";
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "../../drawer";
export interface PdfDialogProps {
documentId: string;
path: string;
url: string;
trigger: React.ReactNode;
}
export default function PdfDialog(props: PdfDialogProps) {
return (
<Drawer direction="left">
<DrawerTrigger>{props.trigger}</DrawerTrigger>
<DrawerContent className="w-3/5 mt-24 h-full max-h-[96%] ">
<DrawerHeader className="flex justify-between">
<div className="space-y-2">
<DrawerTitle>PDF Content</DrawerTitle>
<DrawerDescription>
File path:{" "}
<a
className="hover:text-blue-900"
href={props.url}
target="_blank"
>
{props.path}
</a>
</DrawerDescription>
</div>
<DrawerClose asChild>
<Button variant="outline">Close</Button>
</DrawerClose>
</DrawerHeader>
<div className="m-4">
<PdfFocusProvider>
<PDFViewer
file={{
id: props.documentId,
url: props.url,
}}
/>
</PdfFocusProvider>
</div>
</DrawerContent>
</Drawer>
);
}
@@ -0,0 +1,118 @@
"use client";
import * as React from "react";
import { Drawer as DrawerPrimitive } from "vaul";
import { cn } from "./lib/utils";
const Drawer = ({
shouldScaleBackground = true,
...props
}: React.ComponentProps<typeof DrawerPrimitive.Root>) => (
<DrawerPrimitive.Root
shouldScaleBackground={shouldScaleBackground}
{...props}
/>
);
Drawer.displayName = "Drawer";
const DrawerTrigger = DrawerPrimitive.Trigger;
const DrawerPortal = DrawerPrimitive.Portal;
const DrawerClose = DrawerPrimitive.Close;
const DrawerOverlay = React.forwardRef<
React.ElementRef<typeof DrawerPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<DrawerPrimitive.Overlay
ref={ref}
className={cn("fixed inset-0 z-50 bg-black/80", className)}
{...props}
/>
));
DrawerOverlay.displayName = DrawerPrimitive.Overlay.displayName;
const DrawerContent = React.forwardRef<
React.ElementRef<typeof DrawerPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Content>
>(({ className, children, ...props }, ref) => (
<DrawerPortal>
<DrawerOverlay />
<DrawerPrimitive.Content
ref={ref}
className={cn(
"fixed inset-x-0 bottom-0 z-50 mt-24 flex h-auto flex-col rounded-t-[10px] border bg-background",
className,
)}
{...props}
>
<div className="mx-auto mt-4 h-2 w-[100px] rounded-full bg-muted" />
{children}
</DrawerPrimitive.Content>
</DrawerPortal>
));
DrawerContent.displayName = "DrawerContent";
const DrawerHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn("grid gap-1.5 p-4 text-center sm:text-left", className)}
{...props}
/>
);
DrawerHeader.displayName = "DrawerHeader";
const DrawerFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn("mt-auto flex flex-col gap-2 p-4", className)}
{...props}
/>
);
DrawerFooter.displayName = "DrawerFooter";
const DrawerTitle = React.forwardRef<
React.ElementRef<typeof DrawerPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Title>
>(({ className, ...props }, ref) => (
<DrawerPrimitive.Title
ref={ref}
className={cn(
"text-lg font-semibold leading-none tracking-tight",
className,
)}
{...props}
/>
));
DrawerTitle.displayName = DrawerPrimitive.Title.displayName;
const DrawerDescription = React.forwardRef<
React.ElementRef<typeof DrawerPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof DrawerPrimitive.Description>
>(({ className, ...props }, ref) => (
<DrawerPrimitive.Description
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
));
DrawerDescription.displayName = DrawerPrimitive.Description.displayName;
export {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerFooter,
DrawerHeader,
DrawerOverlay,
DrawerPortal,
DrawerTitle,
DrawerTrigger,
};
@@ -0,0 +1,11 @@
const staticFileAPI = "/api/files";
export const getStaticFileDataUrl = (filePath: string) => {
const isUsingBackend = !!process.env.NEXT_PUBLIC_CHAT_API;
const fileUrl = `${staticFileAPI}/${filePath}`;
if (isUsingBackend) {
const backendOrigin = new URL(process.env.NEXT_PUBLIC_CHAT_API!).origin;
return `${backendOrigin}${fileUrl}`;
}
return fileUrl;
};
@@ -33,3 +33,5 @@ yarn-error.log*
# typescript
*.tsbuildinfo
next-env.d.ts
tool-output/
@@ -18,7 +18,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"dotenv": "^16.3.1",
"llamaindex": "0.3.8",
"llamaindex": "0.3.13",
"lucide-react": "^0.294.0",
"next": "^14.0.3",
"pdf2json": "3.0.5",
@@ -30,8 +30,12 @@
"remark-code-import": "^1.2.0",
"remark-gfm": "^3.0.1",
"remark-math": "^5.1.1",
"rehype-katex": "^7.0.0",
"supports-color": "^8.1.1",
"tailwind-merge": "^2.1.0"
"tailwind-merge": "^2.1.0",
"vaul": "^0.9.1",
"@llamaindex/pdf-viewer": "^1.1.1",
"@e2b/code-interpreter": "^0.0.5"
},
"devDependencies": {
"@types/node": "^20.10.3",
+4 -2
View File
@@ -11,14 +11,16 @@
"forceConsistentCasingInFileNames": true,
"incremental": true,
"outDir": "./lib",
"tsBuildInfoFile": "./lib/.tsbuildinfo"
"tsBuildInfoFile": "./lib/.tsbuildinfo",
"typeRoots": ["./types", "./node_modules/@types"]
},
"include": [
"create-app.ts",
"index.ts",
"./helpers",
"questions.ts",
"package.json"
"package.json",
"types/**/*"
],
"exclude": ["dist"]
}
+1
View File
@@ -0,0 +1 @@
declare module "global-agent/bootstrap";