Compare commits

..

49 Commits

Author SHA1 Message Date
github-actions[bot] 293557cbb4 Release 0.1.11 (#129)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-06-19 20:36:58 +07:00
Marcus Schiesser b46d050fc3 fix: format 2024-06-19 15:08:42 +02:00
Jacopo Zacchigna 02ed277dd0 Starting to add Groq as a provider (#131)
---------
Co-authored-by: Marcus Schiesser <marcus.schiesser@googlemail.com>
2024-06-19 17:43:36 +07:00
Huu Le 48b96ff188 feat: add DuckDuckGo search tool (#133) 2024-06-19 16:29:16 +07:00
Huu Le 9c9decbb88 Reuse function tool instance and improve e2b interpreter tool (#127)
---------
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-06-14 16:04:05 +07:00
Huu Le 0748f2e8d7 remove gemini model map (#128) 2024-06-14 09:18:23 +02:00
github-actions[bot] 3079162806 Release 0.1.10 (#122)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-06-12 20:59:11 +07:00
Marcus Schiesser 48c19c6e62 fix: impove OpenAPI tool for TS 2024-06-12 15:28:59 +02:00
Thuc Pham d75c08e7d8 feat: make chat-session component independence from container (#124) 2024-06-12 19:02:58 +07:00
Huu Le 8f03f8d4bc chore: Improve fastapi (#123) 2024-06-12 16:50:20 +07:00
Marcus Schiesser 19c57d945a fix: reverse config hint 2024-06-12 10:46:50 +02:00
Thuc Pham 9112d0801e feat: implement openapi action tool for ts (#108)
---------
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-06-10 19:40:09 +07:00
Thuc Pham 93b797c162 refactor: structure fe components (#121) 2024-06-10 17:02:25 +07:00
github-actions[bot] d53b760fd0 Release 0.1.9 (#101)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-06-07 22:56:34 +07:00
Marcus Schiesser a880c7c016 chore: update llamaindex@0.3.16 2024-06-07 17:40:39 +02:00
Marcus Schiesser 7b116ce7f7 fix: allow subsequent tool calls 2024-06-07 17:35:23 +02:00
Marcus Schiesser d1232fb1d5 fix: log interpreter tool error 2024-06-07 16:10:33 +02:00
Marcus Schiesser bedf199236 fix: throw and show error if unsupported annotation (e.g. image) is uploaded 2024-06-07 15:30:31 +02:00
Marcus Schiesser c1510bd3fa fix: remove redundant config info 2024-06-07 14:37:08 +02:00
Huu Le 69b9ce76bf refactor code (#119) 2024-06-07 13:46:25 +02:00
Marcus Schiesser 9ced116e1a refactor: use message annotations instead of sending data (#116)
---------
Co-authored-by: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Co-authored-by: leehuwuj <leehuwuj@gmail.com>
2024-06-07 17:14:15 +07:00
Huu Le fae9bcd65a add raw text e2b tool output response (#115) 2024-06-06 13:23:31 +02:00
Thuc Pham 2091fea2b4 feat: display attachments in user messages (#114)
* use same csv card for message and upload box
* do not send csv and image data back to client
* fix: use LLM_MAX_TOKENS
---------

Co-authored-by: leehuwuj <leehuwuj@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-06-06 14:24:31 +07:00
Huu Le 563b51d76d Fix: Vercel streaming (python) does not stream data events instantly (#111) 2024-06-05 15:54:55 +07:00
Thuc Pham 88c88bf16d fix: logo overlay text input because of hegiht (#112) 2024-06-05 15:40:38 +07:00
Marcus Schiesser cd6ebf7295 dx: add hint if tool config is needed 2024-06-04 12:20:52 +02:00
Marcus Schiesser 50b2ddbbf5 docs: updated changeset 2024-06-04 11:15:47 +02:00
Huu Le 5fe2d519d2 chore: Add Azure OpenAI model provider python (#110) 2024-06-04 16:14:21 +07:00
Huu Le 09f1db3b5e feat: Support uploading CSV files for FastAPI app (#109) 2024-06-04 14:23:25 +07:00
Thuc Pham cb3be7d1d4 feat: display conversation starter from backend env (#104)
* feat: display conversation starter from frontend env

* use nextjs config api

* update to /api/chat/config

* add config api for express

* add api config for fast api

* Create ten-badgers-learn.md

* remove default conversation staters

* check empty string

* update pydantic docs

* refactor: move NEXT_PUBLIC_CHAT_API to use config

* use config to get chatAPI

* refactor: rename useClientConfig
2024-06-01 09:57:17 +07:00
Thuc Pham 5474a1f182 feat: enhance csv upload feature (#105)
* remove all multiModal props

* hide uploaded csv files if choose a new one

* feat: support multiple csv upload and reuse

* rename type and make it scrollable
2024-06-01 09:37:46 +07:00
Huu Le 1148ddba53 bump llama-index-agent-openai version to 0.2.6 (#107) 2024-05-31 13:46:35 +01:00
Huu Le 9e945ed355 bump llama_index and gemini version (#106) 2024-05-31 15:12:14 +07:00
Thuc Pham 6342163df2 Merge pull request #103 from run-llama/feat/add-openapi-tool
feat: Add OpenAPI Action tool
2024-05-30 15:33:36 +07:00
Thuc Pham a42fa53a6b feat: implement csv upload (#96)
* feat: implement interpreter tool

* build tool system prompt

* refactor: use local file system, use absolute resource url

* fix: typo

* feat: implement csv upload

* remove dead code

* fix lint

* update icon & fix code review

* fix lint

* Update .gitignore

* Update pre-commit

* add timeout for streaming

* Create bright-turkeys-melt.md

* remove multi modal prop

* suggest csv resources from frontend annotation data

* get resouces inside chat input

* resolve conflict

* update convert message content

* fix lint

* feat: limit display

---------

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-30 10:38:54 +07:00
leehuwuj 099f626586 use urlparse for file path 2024-05-30 10:05:00 +07:00
leehuwuj 956538eeb0 add changeset 2024-05-30 09:27:21 +07:00
leehuwuj 555f6b2905 refactor code 2024-05-30 09:25:56 +07:00
leehuwuj d8bc271a21 add local tool that combine openapi and request tool 2024-05-30 09:11:21 +07:00
leehuwuj f29561cde2 add cache to toolfactory load_tools 2024-05-29 10:40:40 +07:00
leehuwuj 442abae8ac add openapi tool and http request tool 2024-05-29 08:40:16 +07:00
Huu Le 0ad2207684 Merge pull request #98 from run-llama/feat/construct-resource-url-from-backend
feat: construct resource url from backend
2024-05-28 20:43:04 +07:00
Thuc Pham bfde30deed move logger to global scope 2024-05-28 18:42:46 +07:00
Thuc Pham 96fdb83abf use logger warning 2024-05-28 18:33:53 +07:00
Huu Le b7e0072c9c chore: always generate tools config if user selects agent mode (#102) 2024-05-28 14:35:36 +07:00
Thuc Pham 81bc340dda add warning when no file server url prefix 2024-05-27 18:21:32 +07:00
Thuc Pham ddf3aef7dc remove node path 2024-05-27 18:20:27 +07:00
Thuc Pham 1f5a26f3a8 Merge pull request #100 from run-llama/feat/code-interpreter-python
feat: add support for FastAPI in code interpreter tool
2024-05-27 16:58:32 +07:00
Thuc Pham 48188ca3f9 feat: construct resource url from backend 2024-05-24 14:40:44 +07:00
65 changed files with 1943 additions and 512 deletions
-5
View File
@@ -1,5 +0,0 @@
---
"create-llama": patch
---
Add support E2B code interpreter tool for FastAPI
+26
View File
@@ -1,5 +1,31 @@
# create-llama
## 0.1.11
### Patch Changes
- 48b96ff: Add DuckDuckGo search tool
- 9c9decb: Reuse function tool instances and improve e2b interpreter tool for Python
- 02ed277: Add Groq as a model provider
- 0748f2e: Remove hard-coded Gemini supported models
## 0.1.10
### Patch Changes
- 9112d08: Add OpenAPI tool for Typescript
- 8f03f8d: Add OLLAMA_REQUEST_TIMEOUT variable to config Ollama timeout (Python)
- 8f03f8d: Apply nest_asyncio for llama parse
## 0.1.9
### Patch Changes
- a42fa53: Add CSV upload
- 563b51d: Fix Vercel streaming (python) to stream data events instantly
- d60b3c5: Add E2B code interpreter tool for FastAPI
- 956538e: Add OpenAPI action tool for FastAPI
## 0.1.8
### Patch Changes
+19
View File
@@ -185,6 +185,10 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
description: "Dimension of the embedding model to use.",
value: modelConfig.dimensions.toString(),
},
{
name: "CONVERSATION_STARTERS",
description: "The questions to help users get started (multi-line).",
},
...(modelConfig.provider === "openai"
? [
{
@@ -211,6 +215,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "groq"
? [
{
name: "GROQ_API_KEY",
description: "The Groq API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "gemini"
? [
{
@@ -276,6 +289,12 @@ const getEngineEnvs = (): EnvVar[] => {
"The number of similar embeddings to return when retrieving documents.",
value: "3",
},
{
name: "STREAM_TIMEOUT",
description:
"The time in milliseconds to wait for the stream to return a response.",
value: "60000",
},
];
};
+99
View File
@@ -0,0 +1,99 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = ["llama3-8b", "llama3-70b", "mixtral-8x7b"];
const DEFAULT_MODEL = MODELS[0];
// Use huggingface embedding models for now as Groq doesn't support embedding models
enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
}
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
dimensions: 384,
},
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
dimensions: 768,
},
};
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type GroqQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askGroqQuestions({
askModels,
apiKey,
}: GroqQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["GROQ_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.GROQ_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions =
EMBEDDING_MODELS[
embeddingModel as HuggingFaceEmbeddingModelType
].dimensions;
}
return config;
}
+5
View File
@@ -4,6 +4,7 @@ import { questionHandlers } from "../../questions";
import { ModelConfig, ModelProvider } from "../types";
import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askGroqQuestions } from "./groq";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions } from "./openai";
@@ -32,6 +33,7 @@ export async function askModelConfig({
title: "OpenAI",
value: "openai",
},
{ title: "Groq", value: "groq" },
{ title: "Ollama", value: "ollama" },
{ title: "Anthropic", value: "anthropic" },
{ title: "Gemini", value: "gemini" },
@@ -48,6 +50,9 @@ export async function askModelConfig({
case "ollama":
modelConfig = await askOllamaQuestions({ askModels });
break;
case "groq":
modelConfig = await askGroqQuestions({ askModels });
break;
case "anthropic":
modelConfig = await askAnthropicQuestions({ askModels });
break;
+2 -2
View File
@@ -144,7 +144,7 @@ const getAdditionalDependencies = (
case "openai":
dependencies.push({
name: "llama-index-agent-openai",
version: "0.2.2",
version: "0.2.6",
});
break;
case "anthropic":
@@ -160,7 +160,7 @@ const getAdditionalDependencies = (
case "gemini":
dependencies.push({
name: "llama-index-llms-gemini",
version: "0.1.7",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-gemini",
+68 -10
View File
@@ -30,7 +30,7 @@ export type ToolDependencies = {
export const supportedTools: Tool[] = [
{
display: "Google Search (configuration required after installation)",
display: "Google Search",
name: "google.GoogleSearchToolSpec",
config: {
engine:
@@ -54,6 +54,29 @@ export const supportedTools: Tool[] = [
},
],
},
{
// For python app, we will use a local DuckDuckGo search tool (instead of DuckDuckGo search tool in LlamaHub)
// to get the same results as the TS app.
display: "DuckDuckGo Search",
name: "duckduckgo",
dependencies: [
{
name: "duckduckgo-search",
version: "6.1.7",
},
],
supportedFrameworks: ["fastapi", "nextjs", "express"],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for DuckDuckGo search tool.",
value: `You are a DuckDuckGo search agent.
You can use the duckduckgo search tool to get information from the web to answer user questions.
For better results, you can specify the region parameter to get results from a specific region but it's optional.`,
},
],
},
{
display: "Wikipedia",
name: "wikipedia.WikipediaToolSpec",
@@ -107,13 +130,43 @@ export const supportedTools: Tool[] = [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for code interpreter tool.",
value: `You are a Python interpreter.
- You are given tasks to complete and you run python code to solve them.
- The python code runs in a Jupyter notebook. Every time you call \`interpreter\` tool, the python code is executed in a separate cell. It's okay to make multiple calls to \`interpreter\`.
- Display visualizations using matplotlib or any other visualization library directly in the notebook. Shouldn't save the visualizations to a file, just return the base64 encoded data.
- You can install any pip package (if it exists) if you need to but the usual packages for data analysis are already preinstalled.
- You can run any python code you want in a secure environment.
- Use absolute url from result to display images or any other media.`,
value: `-You are a Python interpreter that can run any python code in a secure environment.
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
- You are given tasks to complete and you run python code to solve them.
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
- You can install any pip package (if it exists) by running a cell with pip install.`,
},
],
},
{
display: "OpenAPI action",
name: "openapi_action.OpenAPIActionToolSpec",
dependencies: [
{
name: "llama-index-tools-openapi",
version: "0.1.3",
},
{
name: "jsonschema",
version: "^4.22.0",
},
{
name: "llama-index-tools-requests",
version: "0.1.3",
},
],
config: {
openapi_uri: "The URL or file path of the OpenAPI schema",
},
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
envVars: [
{
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
description: "System prompt for openapi action tool.",
value:
"You are an OpenAPI action agent. You help users to make requests to the provided OpenAPI schema.",
},
],
},
@@ -142,9 +195,15 @@ export const getTools = (toolsName: string[]): Tool[] => {
return tools;
};
export const toolRequiresConfig = (tool: Tool): boolean => {
const hasConfig = Object.keys(tool.config || {}).length > 0;
const hasEmptyEnvVar = tool.envVars?.some((envVar) => !envVar.value) ?? false;
return hasConfig || hasEmptyEnvVar;
};
export const toolsRequireConfig = (tools?: Tool[]): boolean => {
if (tools) {
return tools?.some((tool) => Object.keys(tool.config || {}).length > 0);
return tools?.some(toolRequiresConfig);
}
return false;
};
@@ -159,7 +218,6 @@ export const writeToolsConfig = async (
tools: Tool[] = [],
type: ConfigFileType = ConfigFileType.YAML,
) => {
if (tools.length === 0) return; // no tools selected, no config need
const configContent: {
[key in ToolType]: Record<string, any>;
} = {
+6 -1
View File
@@ -1,7 +1,12 @@
import { PackageManager } from "../helpers/get-pkg-manager";
import { Tool } from "./tools";
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
export type ModelProvider =
| "openai"
| "groq"
| "ollama"
| "anthropic"
| "gemini";
export type ModelConfig = {
provider: ModelProvider;
apiKey?: string;
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "create-llama",
"version": "0.1.8",
"version": "0.1.11",
"description": "Create LlamaIndex-powered apps with one command",
"keywords": [
"rag",
+6 -2
View File
@@ -16,7 +16,11 @@ import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { askModelConfig } from "./helpers/providers";
import { getProjectOptions } from "./helpers/repo";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
import {
supportedTools,
toolRequiresConfig,
toolsRequireConfig,
} from "./helpers/tools";
export type QuestionArgs = Omit<
InstallAppArgs,
@@ -652,7 +656,7 @@ export const askQuestions = async (
t.supportedFrameworks?.includes(program.framework),
);
const toolChoices = options.map((tool) => ({
title: tool.display,
title: `${tool.display}${toolRequiresConfig(tool) ? " (needs configuration)" : ""}`,
value: tool.name,
}));
const { toolsName } = await prompts({
@@ -1,7 +1,8 @@
import os
import yaml
import json
import importlib
from cachetools import cached, LRUCache
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
@@ -18,7 +19,6 @@ class ToolFactory:
ToolType.LOCAL: "app.engine.tools",
}
@staticmethod
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
try:
@@ -31,7 +31,7 @@ class ToolFactory:
return tool_spec.to_tool_list()
else:
module = importlib.import_module(f"{source_package}.{tool_name}")
tools = getattr(module, "tools")
tools = module.get_tools()
if not all(isinstance(tool, FunctionTool) for tool in tools):
raise ValueError(
f"The module {module} does not contain valid tools"
@@ -0,0 +1,36 @@
from llama_index.core.tools.function_tool import FunctionTool
def duckduckgo_search(
query: str,
region: str = "wt-wt",
max_results: int = 10,
):
"""
Use this function to search for any query in DuckDuckGo.
Args:
query (str): The query to search in DuckDuckGo.
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
max_results Optional(int): The maximum number of results to be returned. Default is 10.
"""
try:
from duckduckgo_search import DDGS
except ImportError:
raise ImportError(
"duckduckgo_search package is required to use this function."
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
)
params = {
"keywords": query,
"region": region,
"max_results": max_results,
}
results = []
with DDGS() as ddg:
results = list(ddg.text(**params))
return results
def get_tools():
return [FunctionTool.from_defaults(duckduckgo_search)]
@@ -3,7 +3,7 @@ import logging
import base64
import uuid
from pydantic import BaseModel
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Optional
from llama_index.core.tools import FunctionTool
from e2b_code_interpreter import CodeInterpreter
from e2b_code_interpreter.models import Logs
@@ -14,8 +14,9 @@ logger = logging.getLogger(__name__)
class InterpreterExtraResult(BaseModel):
type: str
filename: str
url: str
content: Optional[str] = None
filename: Optional[str] = None
url: Optional[str] = None
class E2BToolOutput(BaseModel):
@@ -28,9 +29,23 @@ class E2BCodeInterpreter:
output_dir = "tool-output"
def __init__(self, api_key: str, filesever_url_prefix: str):
self.api_key = api_key
def __init__(self):
api_key = os.getenv("E2B_API_KEY")
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not api_key:
raise ValueError(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
)
if not filesever_url_prefix:
raise ValueError(
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
)
self.filesever_url_prefix = filesever_url_prefix
self.interpreter = CodeInterpreter(api_key=api_key)
def __del__(self):
self.interpreter.close()
def get_output_path(self, filename: str) -> str:
# if output directory doesn't exist, create it
@@ -72,63 +87,56 @@ class E2BCodeInterpreter:
try:
formats = result.formats()
base64_data_arr = [result[format] for format in formats]
results = [result[format] for format in formats]
for ext, base64_data in zip(formats, base64_data_arr):
if ext and base64_data:
result = self.save_to_disk(base64_data, ext)
filename = result["filename"]
output.append(
InterpreterExtraResult(
type=ext, filename=filename, url=self.get_file_url(filename)
for ext, data in zip(formats, results):
match ext:
case "png" | "svg" | "jpeg" | "pdf":
result = self.save_to_disk(data, ext)
filename = result["filename"]
output.append(
InterpreterExtraResult(
type=ext,
filename=filename,
url=self.get_file_url(filename),
)
)
case _:
output.append(
InterpreterExtraResult(
type=ext,
content=data,
)
)
)
except Exception as error:
logger.error("Error when saving data to disk", error)
logger.exception(error, exc_info=True)
logger.error("Error when parsing output from E2b interpreter tool", error)
return output
def interpret(self, code: str) -> E2BToolOutput:
with CodeInterpreter(api_key=self.api_key) as interpreter:
logger.info(
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
)
exec = interpreter.notebook.exec_cell(code)
"""
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
if exec.error:
output = E2BToolOutput(is_error=True, logs=[exec.error])
Parameters:
code (str): The python code to be executed in a single cell.
"""
logger.info(
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
)
exec = self.interpreter.notebook.exec_cell(code)
if exec.error:
logger.error("Error when executing code", exec.error)
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
else:
if len(exec.results) == 0:
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
else:
if len(exec.results) == 0:
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
else:
results = self.parse_result(exec.results[0])
output = E2BToolOutput(
is_error=False, logs=exec.logs, results=results
)
return output
results = self.parse_result(exec.results[0])
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
return output
def code_interpret(code: str) -> Dict:
"""
Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.
"""
api_key = os.getenv("E2B_API_KEY")
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not api_key:
raise ValueError(
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
)
if not filesever_url_prefix:
raise ValueError(
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
)
interpreter = E2BCodeInterpreter(
api_key=api_key, filesever_url_prefix=filesever_url_prefix
)
output = interpreter.interpret(code)
return output.dict()
# Specify as functions tools to be loaded by the ToolFactory
tools = [FunctionTool.from_defaults(code_interpret)]
def get_tools():
return [FunctionTool.from_defaults(E2BCodeInterpreter().interpret)]
@@ -0,0 +1,78 @@
from typing import Dict, List, Tuple
from llama_index.tools.openapi import OpenAPIToolSpec
from llama_index.tools.requests import RequestsToolSpec
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
"""
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
openapi_uri: str: The file path or URL to the OpenAPI spec.
domain_headers: dict: Whitelist domains and the headers to use.
"""
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
# Cached parsed specs by URI
_specs: Dict[str, Tuple[Dict, List[str]]] = {}
def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
if domain_headers is None:
domain_headers = {}
if openapi_uri not in self._specs:
openapi_spec, servers = self._load_openapi_spec(openapi_uri)
self._specs[openapi_uri] = (openapi_spec, servers)
else:
openapi_spec, servers = self._specs[openapi_uri]
# Add the servers to the domain headers if they are not already present
for server in servers:
if server not in domain_headers:
domain_headers[server] = {}
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
RequestsToolSpec.__init__(self, domain_headers)
@staticmethod
def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
"""
Load an OpenAPI spec from a URI.
Args:
uri (str): A file path or URL to the OpenAPI spec.
Returns:
List[Document]: A list of Document objects.
"""
import yaml
from urllib.parse import urlparse
if uri.startswith("http"):
import requests
response = requests.get(uri)
if response.status_code != 200:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: "
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
)
spec = yaml.safe_load(response.text)
elif uri.startswith("file"):
filepath = urlparse(uri).path
with open(filepath, "r") as file:
spec = yaml.safe_load(file)
else:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
"Only HTTP and file path are supported."
)
# Add the servers to the whitelist
try:
servers = [
urlparse(server["url"]).netloc for server in spec.get("servers", [])
]
except KeyError as e:
raise ValueError(
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
"Could not get `servers` from the spec."
) from e
return spec, servers
@@ -69,4 +69,5 @@ class OpenMeteoWeather:
return response.json()
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
def get_tools():
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
@@ -0,0 +1,61 @@
import { JSONSchemaType } from "ajv";
import { search } from "duck-duck-scrape";
import { BaseTool, ToolMetadata } from "llamaindex";
export type DuckDuckGoParameter = {
query: string;
region?: string;
};
export type DuckDuckGoToolParams = {
metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = {
name: "duckduckgo",
description: "Use this function to search for any query in DuckDuckGo.",
parameters: {
type: "object",
properties: {
query: {
type: "string",
description: "The query to search in DuckDuckGo.",
},
region: {
type: "string",
description:
"Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...",
nullable: true,
},
},
required: ["query"],
},
};
type DuckDuckGoSearchResult = {
title: string;
description: string;
url: string;
};
export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
constructor(params: DuckDuckGoToolParams) {
this.metadata = params.metadata ?? DEFAULT_META_DATA;
}
async call(input: DuckDuckGoParameter) {
const { query, region } = input;
const options = region ? { region } : {};
const searchResults = await search(query, options);
return searchResults.results.map((result) => {
return {
title: result.title,
description: result.description,
url: result.url,
} as DuckDuckGoSearchResult;
});
}
}
@@ -1,42 +1,57 @@
import { BaseToolWithCall } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
import { OpenAPIActionTool } from "./openapi-action";
import { WeatherTool, WeatherToolParams } from "./weather";
type ToolCreator = (config: unknown) => BaseToolWithCall;
type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
export async function createTools(toolConfig: {
local: Record<string, unknown>;
llamahub: any;
}): Promise<BaseToolWithCall[]> {
// add local tools from the 'tools' folder (if configured)
const tools = createLocalTools(toolConfig.local);
const tools = await createLocalTools(toolConfig.local);
// add tools from LlamaIndexTS (if configured)
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
return tools;
}
const toolFactory: Record<string, ToolCreator> = {
weather: (config: unknown) => {
return new WeatherTool(config as WeatherToolParams);
weather: async (config: unknown) => {
return [new WeatherTool(config as WeatherToolParams)];
},
interpreter: (config: unknown) => {
return new InterpreterTool(config as InterpreterToolParams);
interpreter: async (config: unknown) => {
return [new InterpreterTool(config as InterpreterToolParams)];
},
"openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
const { openapi_uri, domain_headers } = config as {
openapi_uri: string;
domain_headers: Record<string, Record<string, string>>;
};
const openAPIActionTool = new OpenAPIActionTool(
openapi_uri,
domain_headers,
);
return await openAPIActionTool.toToolFunctions();
},
duckduckgo: async (config: unknown) => {
return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)];
},
};
function createLocalTools(
async function createLocalTools(
localConfig: Record<string, unknown>,
): BaseToolWithCall[] {
): Promise<BaseToolWithCall[]> {
const tools: BaseToolWithCall[] = [];
Object.keys(localConfig).forEach((key) => {
for (const [key, toolConfig] of Object.entries(localConfig)) {
if (key in toolFactory) {
const toolConfig = localConfig[key];
const tool = toolFactory[key](toolConfig);
tools.push(tool);
const newTools = await toolFactory[key](toolConfig);
tools.push(...newTools);
}
});
}
return tools;
}
@@ -15,7 +15,7 @@ export type InterpreterToolParams = {
fileServerURLPrefix?: string;
};
export type InterpreterToolOuput = {
export type InterpreterToolOutput = {
isError: boolean;
logs: Logs;
extraResult: InterpreterExtraResult[];
@@ -34,8 +34,9 @@ type InterpreterExtraType =
export type InterpreterExtraResult = {
type: InterpreterExtraType;
filename: string;
url: string;
content?: string;
filename?: string;
url?: string;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
@@ -88,7 +89,7 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
return this.codeInterpreter;
}
public async codeInterpret(code: string): Promise<InterpreterToolOuput> {
public async codeInterpret(code: string): Promise<InterpreterToolOutput> {
console.log(
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
);
@@ -96,7 +97,7 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
const exec = await interpreter.notebook.execCell(code);
if (exec.error) console.error("[Code Interpreter error]", exec.error);
const extraResult = await this.getExtraResult(exec.results[0]);
const result: InterpreterToolOuput = {
const result: InterpreterToolOutput = {
isError: !!exec.error,
logs: exec.logs,
extraResult,
@@ -104,12 +105,15 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
return result;
}
async call(input: InterpreterParameter): Promise<InterpreterToolOuput> {
async call(input: InterpreterParameter): Promise<InterpreterToolOutput> {
const result = await this.codeInterpret(input.code);
await this.codeInterpreter?.close();
return result;
}
async close() {
await this.codeInterpreter?.close();
}
private async getExtraResult(
res?: Result,
): Promise<InterpreterExtraResult[]> {
@@ -118,23 +122,34 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
try {
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
const base64DataArr = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
// save base64 data to file and return the url
for (let i = 0; i < formats.length; i++) {
const ext = formats[i];
const base64Data = base64DataArr[i];
if (ext && base64Data) {
const { filename } = this.saveToDisk(base64Data, ext);
output.push({
type: ext as InterpreterExtraType,
filename,
url: this.getFileUrl(filename),
});
const data = results[i];
switch (ext) {
case "png":
case "jpeg":
case "svg":
case "pdf":
const { filename } = this.saveToDisk(data, ext);
output.push({
type: ext as InterpreterExtraType,
filename,
url: this.getFileUrl(filename),
});
break;
default:
output.push({
type: ext as InterpreterExtraType,
content: data,
});
break;
}
}
} catch (error) {
console.error("Error when saving data to disk", error);
console.error("Error when parsing e2b response", error);
}
return output;
@@ -0,0 +1,164 @@
import SwaggerParser from "@apidevtools/swagger-parser";
import { JSONSchemaType } from "ajv";
import got from "got";
import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
interface DomainHeaders {
[key: string]: { [header: string]: string };
}
type Input = {
url: string;
params: object;
};
type APIInfo = {
description: string;
title: string;
};
export class OpenAPIActionTool {
// cache the loaded specs by URL
private static specs: Record<string, any> = {};
private readonly INVALID_URL_PROMPT =
"This url did not include a hostname or scheme. Please determine the complete URL and try again.";
private createLoadSpecMetaData = (info: APIInfo) => {
return {
name: "load_openapi_spec",
description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
};
};
private readonly createMethodCallMetaData = (
method: "POST" | "PATCH" | "GET",
info: APIInfo,
) => {
return {
name: `${method.toLowerCase()}_request`,
description: `Use this to call the ${method} method on the API named ${info.title}`,
parameters: {
type: "object",
properties: {
url: {
type: "string",
description: `The url to make the ${method} request against`,
},
params: {
type: "object",
description:
method === "GET"
? "the URL parameters to provide with the get request"
: `the key-value pairs to provide with the ${method} request`,
},
},
required: ["url"],
},
} as ToolMetadata<JSONSchemaType<Input>>;
};
constructor(
public openapi_uri: string,
public domainHeaders: DomainHeaders = {},
) {}
async loadOpenapiSpec(url: string): Promise<any> {
const api = await SwaggerParser.validate(url);
return {
servers: "servers" in api ? api.servers : "",
info: { description: api.info.description, title: api.info.title },
endpoints: api.paths,
};
}
async getRequest(input: Input): Promise<JSONValue> {
if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT;
}
try {
const data = await got
.get(input.url, {
headers: this.getHeadersForUrl(input.url),
searchParams: input.params as URLSearchParams,
})
.json();
return data as JSONValue;
} catch (error) {
return error as JSONValue;
}
}
async postRequest(input: Input): Promise<JSONValue> {
if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT;
}
try {
const res = await got.post(input.url, {
headers: this.getHeadersForUrl(input.url),
json: input.params,
});
return res.body as JSONValue;
} catch (error) {
return error as JSONValue;
}
}
async patchRequest(input: Input): Promise<JSONValue> {
if (!this.validUrl(input.url)) {
return this.INVALID_URL_PROMPT;
}
try {
const res = await got.patch(input.url, {
headers: this.getHeadersForUrl(input.url),
json: input.params,
});
return res.body as JSONValue;
} catch (error) {
return error as JSONValue;
}
}
public async toToolFunctions() {
if (!OpenAPIActionTool.specs[this.openapi_uri]) {
console.log(`Loading spec for URL: ${this.openapi_uri}`);
const spec = await this.loadOpenapiSpec(this.openapi_uri);
OpenAPIActionTool.specs[this.openapi_uri] = spec;
}
const spec = OpenAPIActionTool.specs[this.openapi_uri];
// TODO: read endpoints with parameters from spec and create one tool for each endpoint
// For now, we just create a tool for each HTTP method which does not work well for passing parameters
return [
FunctionTool.from(() => {
return spec;
}, this.createLoadSpecMetaData(spec.info)),
FunctionTool.from(
this.getRequest.bind(this),
this.createMethodCallMetaData("GET", spec.info),
),
FunctionTool.from(
this.postRequest.bind(this),
this.createMethodCallMetaData("POST", spec.info),
),
FunctionTool.from(
this.patchRequest.bind(this),
this.createMethodCallMetaData("PATCH", spec.info),
),
];
}
private validUrl(url: string): boolean {
const parsed = new URL(url);
return !!parsed.protocol && !!parsed.hostname;
}
private getDomain(url: string): string {
const parsed = new URL(url);
return parsed.hostname;
}
private getHeadersForUrl(url: string): { [header: string]: string } {
const domain = this.getDomain(url);
return this.domainHeaders[domain] || {};
}
}
+14 -5
View File
@@ -23,7 +23,12 @@ def llama_parse_parser():
"LLAMA_CLOUD_API_KEY environment variable is not set. "
"Please set it in .env file or in your shell environment then run again!"
)
parser = LlamaParse(result_type="markdown", verbose=True, language="en")
parser = LlamaParse(
result_type="markdown",
verbose=True,
language="en",
ignore_errors=False,
)
return parser
@@ -32,15 +37,19 @@ def get_file_documents(config: FileLoaderConfig):
try:
reader = SimpleDirectoryReader(
config.data_dir,
recursive=True,
filename_as_id=True,
config.data_dir, recursive=True, filename_as_id=True, raise_on_error=True
)
if config.use_llama_parse:
# LlamaParse is async first,
# so we need to use nest_asyncio to run it in sync mode
import nest_asyncio
nest_asyncio.apply()
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
return reader.load_data()
except ValueError as e:
except Exception as e:
import sys, traceback
# Catch the error if the data dir is empty
@@ -1,5 +1,7 @@
"use client";
import { Message } from "./chat-messages";
export interface ChatInputProps {
/** The current value of the input */
input?: string;
@@ -12,7 +14,8 @@ export interface ChatInputProps {
/** Form submission handler to automatically reset input and append a user message */
handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
isLoading: boolean;
multiModal?: boolean;
messages: Message[];
setInput?: (input: string) => void;
}
export default function ChatInput(props: ChatInputProps) {
@@ -19,8 +19,12 @@ export default function ChatMessages({
isLoading?: boolean;
stop?: () => void;
reload?: () => void;
append?: (
message: Message | Omit<Message, "id">,
) => Promise<string | null | undefined>;
}) {
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
const lastMessage = messages[messages.length - 1];
const scrollToBottom = () => {
if (scrollableChatContainerRef.current) {
@@ -31,14 +35,14 @@ export default function ChatMessages({
useEffect(() => {
scrollToBottom();
}, [messages.length]);
}, [messages.length, lastMessage]);
return (
<div className="w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl">
<div
className="flex flex-col gap-5 divide-y h-[50vh] overflow-auto"
ref={scrollableChatContainerRef}
>
<div
className="flex-1 w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl overflow-auto"
ref={scrollableChatContainerRef}
>
<div className="flex flex-col gap-5 divide-y">
{messages.map((m: Message) => (
<ChatItem key={m.id} {...m} />
))}
@@ -0,0 +1,30 @@
"use client";
import { useEffect, useMemo, useState } from "react";
export interface ChatConfig {
chatAPI?: string;
starterQuestions?: string[];
}
export function useClientConfig() {
const API_ROUTE = "/api/chat/config";
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
const [config, setConfig] = useState<ChatConfig>({
chatAPI,
});
const configAPI = useMemo(() => {
const backendOrigin = chatAPI ? new URL(chatAPI).origin : "";
return `${backendOrigin}${API_ROUTE}`;
}, [chatAPI]);
useEffect(() => {
fetch(configAPI)
.then((response) => response.json())
.then((data) => setConfig({ ...data, chatAPI }))
.catch((error) => console.error("Error fetching config", error));
}, [chatAPI, configAPI]);
return config;
}
@@ -7,17 +7,20 @@
"format:write": "prettier --ignore-unknown --write .",
"build": "tsup index.ts --format cjs --dts",
"start": "node dist/index.js",
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon --watch dist/index.js\""
},
"dependencies": {
"ai": "^3.0.21",
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"duck-duck-scrape": "^2.2.5",
"express": "^4.18.2",
"llamaindex": "0.3.13",
"llamaindex": "0.3.16",
"pdf2json": "3.0.5",
"ajv": "^8.12.0",
"@e2b/code-interpreter": "^0.0.5"
"@e2b/code-interpreter": "^0.0.5",
"got": "10.7.0",
"@apidevtools/swagger-parser": "^10.1.0"
},
"devDependencies": {
"@types/cors": "^2.8.16",
@@ -0,0 +1,14 @@
import { Request, Response } from "express";
export const chatConfig = async (_req: Request, res: Response) => {
let starterQuestions = undefined;
if (
process.env.CONVERSATION_STARTERS &&
process.env.CONVERSATION_STARTERS.trim()
) {
starterQuestions = process.env.CONVERSATION_STARTERS.trim().split("\n");
}
return res.status(200).json({
starterQuestions,
});
};
@@ -1,32 +1,16 @@
import { Message, StreamData, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { ChatMessage, Settings } from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
import { createCallbackManager } from "./stream-helper";
const convertMessageContent = (
textMessage: string,
imageUrl: string | undefined,
): MessageContent => {
if (!imageUrl) return textMessage;
return [
{
type: "text",
text: textMessage,
},
{
type: "image_url",
image_url: {
url: imageUrl,
},
},
];
};
import { LlamaIndexStream, convertMessageContent } from "./llamaindex-stream";
import { createCallbackManager, createStreamTimeout } from "./stream-helper";
export const chat = async (req: Request, res: Response) => {
// Init Vercel AI StreamData and timeout
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
try {
const { messages, data }: { messages: Message[]; data: any } = req.body;
const { messages }: { messages: Message[] } = req.body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return res.status(400).json({
@@ -37,15 +21,25 @@ export const chat = async (req: Request, res: Response) => {
const chatEngine = await createChatEngine();
let annotations = userMessage.annotations;
if (!annotations) {
// the user didn't send any new annotations with the last message
// so use the annotations from the last user message that has annotations
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
annotations = messages
.slice()
.reverse()
.find(
(message) => message.role === "user" && message.annotations,
)?.annotations;
}
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
const userMessageContent = convertMessageContent(
userMessage.content,
data?.imageUrl,
annotations,
);
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
@@ -59,11 +53,7 @@ export const chat = async (req: Request, res: Response) => {
});
// Return a stream, which can be consumed by the Vercel/AI client
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
const stream = LlamaIndexStream(response, vercelStreamData);
return streamToResponse(stream, res, {}, vercelStreamData);
} catch (error) {
@@ -71,5 +61,7 @@ export const chat = async (req: Request, res: Response) => {
return res.status(500).json({
detail: (error as Error).message,
});
} finally {
clearTimeout(streamTimeout);
}
};
@@ -4,6 +4,7 @@ import {
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Groq,
OpenAI,
OpenAIEmbedding,
Settings,
@@ -28,6 +29,9 @@ export const initSettings = async () => {
case "ollama":
initOllama();
break;
case "groq":
initGroq();
break;
case "anthropic":
initAnthropic();
break;
@@ -45,7 +49,9 @@ export const initSettings = async () => {
function initOpenAI() {
Settings.llm = new OpenAI({
model: process.env.MODEL ?? "gpt-3.5-turbo",
maxTokens: 512,
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
@@ -83,6 +89,27 @@ function initAnthropic() {
});
}
function initGroq() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
const modelMap: Record<string, string> = {
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
};
Settings.llm = new Groq({
model: modelMap[process.env.MODEL!],
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
@@ -1,4 +1,5 @@
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
@@ -6,6 +7,8 @@ import {
type AIStreamCallbacksAndOptions,
} from "ai";
import {
MessageContent,
MessageContentDetail,
Metadata,
NodeWithScore,
Response,
@@ -13,29 +16,85 @@ import {
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
import { CsvFile, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
export const convertMessageContent = (
content: string,
annotations?: JSONValue[],
): MessageContent => {
if (!annotations) return content;
return [
{
type: "text",
text: content,
},
...convertAnnotations(annotations),
];
};
const convertAnnotations = (
annotations: JSONValue[],
): MessageContentDetail[] => {
const content: MessageContentDetail[] = [];
annotations.forEach((annotation: JSONValue) => {
// first skip invalid annotation
if (
!(
annotation &&
typeof annotation === "object" &&
"type" in annotation &&
"data" in annotation &&
annotation.data &&
typeof annotation.data === "object"
)
) {
console.log(
"Client sent invalid annotation. Missing data and type",
annotation,
);
return;
}
const { type, data } = annotation;
// convert image
if (type === "image" && "url" in data && typeof data.url === "string") {
content.push({
type: "image_url",
image_url: {
url: data.url,
},
});
}
// convert CSV files to text
if (type === "csv" && "csvFiles" in data && Array.isArray(data.csvFiles)) {
const rawContents = data.csvFiles.map((csv) => {
return "```csv\n" + (csv as CsvFile).content + "\n```";
});
const csvContent =
"Use data from following CSV raw contents:\n" +
rawContents.join("\n\n");
content.push({
type: "text",
text: csvContent,
});
}
});
return content;
};
function createParser(
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
@@ -72,10 +131,9 @@ export function LlamaIndexStream(
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
return createParser(response, data)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -7,14 +7,20 @@ import {
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
data.appendMessageAnnotation({
type: "image",
data: {
url: imageUrl,
},
});
function getNodeUrl(metadata: Metadata) {
const url = metadata["URL"];
if (url) return url;
const fileName = metadata["file_name"];
if (!process.env.FILESERVER_URL_PREFIX) {
console.warn(
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
);
return undefined;
}
if (fileName) {
return `${process.env.FILESERVER_URL_PREFIX}/data/${fileName}`;
}
return undefined;
}
export function appendSourceData(
@@ -29,6 +35,7 @@ export function appendSourceData(
...node.node.toMutableJSON(),
id: node.node.id_,
score: node.score ?? null,
url: getNodeUrl(node.node.metadata),
})),
},
});
@@ -65,6 +72,15 @@ export function appendToolData(
});
}
export function createStreamTimeout(stream: StreamData) {
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
const t = setTimeout(() => {
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
stream.close();
}, timeout);
return t;
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
@@ -95,3 +111,10 @@ export function createCallbackManager(stream: StreamData) {
return callbackManager;
}
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
@@ -1,4 +1,5 @@
import express, { Router } from "express";
import { chatConfig } from "../controllers/chat-config.controller";
import { chatRequest } from "../controllers/chat-request.controller";
import { chat } from "../controllers/chat.controller";
import { initSettings } from "../controllers/engine/settings";
@@ -8,5 +9,6 @@ const llmRouter: Router = express.Router();
initSettings();
llmRouter.route("/").post(chat);
llmRouter.route("/request").post(chatRequest);
llmRouter.route("/config").get(chatConfig);
export default llmRouter;
@@ -1,154 +1,114 @@
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
import os
import logging
from aiostream import stream
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.llms import MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.routers.messaging import EventCallbackHandler
from aiostream import stream
from app.api.routers.events import EventCallbackHandler
from app.api.routers.models import (
ChatData,
ChatConfig,
SourceNodes,
Result,
Message,
)
chat_router = r = APIRouter()
class _Message(BaseModel):
role: MessageRole
content: str
class _ChatData(BaseModel):
messages: List[_Message]
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
class _SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
return cls(
id=source_node.node.node_id,
metadata=source_node.node.metadata,
score=source_node.score,
text=source_node.node.text, # type: ignore
)
@classmethod
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
return [cls.from_source_node(node) for node in source_nodes]
class _Result(BaseModel):
result: _Message
nodes: List[_SourceNodes]
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
# check preconditions and get last message
if len(data.messages) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No messages provided",
)
last_message = data.messages.pop()
if last_message.role != MessageRole.USER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Last message must be from user",
)
# convert messages coming from the request to type ChatMessage
messages = [
ChatMessage(
role=m.role,
content=m.content,
)
for m in data.messages
]
return last_message.content, messages
logger = logging.getLogger("uvicorn")
# streaming endpoint - delete if not needed
@r.post("")
async def chat(
request: Request,
data: _ChatData,
data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
):
last_message_content, messages = await parse_chat_data(data)
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
try:
response = await chat_engine.astream_chat(last_message_content, messages)
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
async def content_generator():
# Yield the text response
async def _chat_response_generator():
response = await chat_engine.astream_chat(
last_message_content, messages
)
async for token in response.async_response_gen():
yield VercelStreamResponse.convert_text(token)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"type": "sources",
"data": {
"nodes": [
SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
}
)
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_chat_response_generator(), _event_generator())
is_stream_started = False
async with combine.stream() as streamer:
async for output in streamer:
if not is_stream_started:
is_stream_started = True
# Stream a blank message to start the stream
yield VercelStreamResponse.convert_text("")
yield output
if await request.is_disconnected():
break
return VercelStreamResponse(content=content_generator())
except Exception as e:
logger.exception("Error in chat engine", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
)
async def content_generator():
# Yield the text response
async def _text_generator():
async for token in response.async_response_gen():
yield VercelStreamResponse.convert_text(token)
# the text_generator is the leading stream, once it's finished, also finish the event stream
event_handler.is_done = True
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_text_generator(), _event_generator())
async with combine.stream() as streamer:
async for item in streamer:
if await request.is_disconnected():
break
yield item
# Yield the source nodes
yield VercelStreamResponse.convert_data(
{
"type": "sources",
"data": {
"nodes": [
_SourceNodes.from_source_node(node).dict()
for node in response.source_nodes
]
},
}
)
return VercelStreamResponse(content=content_generator())
) from e
# non-streaming endpoint - delete if not needed
@r.post("/request")
async def chat_request(
data: _ChatData,
data: ChatData,
chat_engine: BaseChatEngine = Depends(get_chat_engine),
) -> _Result:
last_message_content, messages = await parse_chat_data(data)
) -> Result:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
response = await chat_engine.achat(last_message_content, messages)
return _Result(
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
return Result(
result=Message(role=MessageRole.ASSISTANT, content=response.response),
nodes=SourceNodes.from_source_nodes(response.source_nodes),
)
@r.get("/config")
async def chat_config() -> ChatConfig:
starter_questions = None
conversation_starters = os.getenv("CONVERSATION_STARTERS")
if conversation_starters and conversation_starters.strip():
starter_questions = conversation_starters.strip().split("\n")
return ChatConfig(starterQuestions=starter_questions)
@@ -0,0 +1,170 @@
import os
import logging
from pydantic import BaseModel, Field, validator
from pydantic.alias_generators import to_camel
from typing import List, Any, Optional, Dict
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
logger = logging.getLogger("uvicorn")
class CsvFile(BaseModel):
content: str
filename: str
filesize: int
id: str
class AnnotationData(BaseModel):
csv_files: List[CsvFile] | None = Field(
default=None,
description="List of CSV files",
)
class Config:
json_schema_extra = {
"example": {
"csvFiles": [
{
"content": "Name, Age\nAlice, 25\nBob, 30",
"filename": "example.csv",
"filesize": 123,
"id": "123",
"type": "text/csv",
}
]
}
}
alias_generator = to_camel
class Annotation(BaseModel):
type: str
data: AnnotationData
def to_content(self) -> str:
if self.type == "csv":
csv_files = self.data.csv_files
if csv_files is not None and len(csv_files) > 0:
return "Use data from following CSV raw contents\n" + "\n".join(
[f"```csv\n{csv_file.content}\n```" for csv_file in csv_files]
)
raise ValueError(f"Unsupported annotation type: {self.type}")
class Message(BaseModel):
role: MessageRole
content: str
annotations: List[Annotation] | None = None
class ChatData(BaseModel):
messages: List[Message]
class Config:
json_schema_extra = {
"example": {
"messages": [
{
"role": "user",
"content": "What standards for letters exist?",
}
]
}
}
@validator("messages")
def messages_must_not_be_empty(cls, v):
if len(v) == 0:
raise ValueError("Messages must not be empty")
return v
def get_last_message_content(self) -> str:
"""
Get the content of the last message along with the data content if available. Fallback to use data content from previous messages
"""
if len(self.messages) == 0:
raise ValueError("There is not any message in the chat")
last_message = self.messages[-1]
message_content = last_message.content
for message in reversed(self.messages):
if message.role == MessageRole.USER and message.annotations is not None:
annotation_contents = (
annotation.to_content() for annotation in message.annotations
)
annotation_text = "\n".join(annotation_contents)
message_content = f"{message_content}\n{annotation_text}"
break
return message_content
def get_history_messages(self) -> List[Message]:
"""
Get the history messages
"""
return [
ChatMessage(role=message.role, content=message.content)
for message in self.messages[:-1]
]
def is_last_message_from_user(self) -> bool:
return self.messages[-1].role == MessageRole.USER
class SourceNodes(BaseModel):
id: str
metadata: Dict[str, Any]
score: Optional[float]
text: str
url: Optional[str]
@classmethod
def from_source_node(cls, source_node: NodeWithScore):
metadata = source_node.node.metadata
url = metadata.get("URL")
if not url:
file_name = metadata.get("file_name")
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
if not url_prefix:
logger.warning(
"Warning: FILESERVER_URL_PREFIX not set in environment variables"
)
if file_name and url_prefix:
url = f"{url_prefix}/data/{file_name}"
return cls(
id=source_node.node.node_id,
metadata=metadata,
score=source_node.score,
text=source_node.node.text, # type: ignore
url=url,
)
@classmethod
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
return [cls.from_source_node(node) for node in source_nodes]
class Result(BaseModel):
result: Message
nodes: List[SourceNodes]
class ChatConfig(BaseModel):
starter_questions: Optional[List[str]] = Field(
default=None,
description="List of starter questions",
)
class Config:
json_schema_extra = {
"example": {
"starterQuestions": [
"What standards for letters exist?",
"What are the requirements for a letter to be considered a letter?",
]
}
}
alias_generator = to_camel
@@ -1,40 +1,51 @@
import os
from typing import Dict
from llama_index.core.settings import Settings
def init_settings():
model_provider = os.getenv("MODEL_PROVIDER")
if model_provider == "openai":
init_openai()
elif model_provider == "ollama":
init_ollama()
elif model_provider == "anthropic":
init_anthropic()
elif model_provider == "gemini":
init_gemini()
else:
raise ValueError(f"Invalid model provider: {model_provider}")
match model_provider:
case "openai":
init_openai()
case "groq":
init_groq()
case "ollama":
init_ollama()
case "anthropic":
init_anthropic()
case "gemini":
init_gemini()
case "azure-openai":
init_azure_openai()
case _:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
def init_ollama():
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
request_timeout = float(
os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
)
Settings.embed_model = OllamaEmbedding(
base_url=base_url,
model_name=os.getenv("EMBEDDING_MODEL"),
)
Settings.llm = Ollama(base_url=base_url, model=os.getenv("MODEL"))
Settings.llm = Ollama(
base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
)
def init_openai():
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
max_tokens = os.getenv("LLM_MAX_TOKENS")
config = {
@@ -52,9 +63,58 @@ def init_openai():
Settings.embed_model = OpenAIEmbedding(**config)
def init_anthropic():
from llama_index.llms.anthropic import Anthropic
def init_azure_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
max_tokens = os.getenv("LLM_MAX_TOKENS")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
llm_config = {
"api_key": api_key,
"deployment_name": llm_deployment,
"model": os.getenv("MODEL"),
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
"max_tokens": int(max_tokens) if max_tokens is not None else None,
}
Settings.llm = AzureOpenAI(**llm_config)
dimensions = os.getenv("EMBEDDING_DIM")
embedding_config = {
"api_key": api_key,
"deployment_name": embedding_deployment,
"model": os.getenv("EMBEDDING_MODEL"),
"dimensions": int(dimensions) if dimensions is not None else None,
}
Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
def init_groq():
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
model_map: Dict[str, str] = {
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
}
embed_model_map: Dict[str, str] = {
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
}
Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
Settings.embed_model = HuggingFaceEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
def init_anthropic():
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.anthropic import Anthropic
model_map: Dict[str, str] = {
"claude-3-opus": "claude-3-opus-20240229",
@@ -76,21 +136,11 @@ def init_anthropic():
def init_gemini():
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
model_map: Dict[str, str] = {
"gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest",
"gemini-pro": "models/gemini-pro",
"gemini-pro-vision": "models/gemini-pro-vision",
}
model_name = f"models/{os.getenv('MODEL')}"
embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
embed_model_map: Dict[str, str] = {
"embedding-001": "models/embedding-001",
"text-embedding-004": "models/text-embedding-004",
}
Settings.llm = Gemini(model=model_map[os.getenv("MODEL")])
Settings.embed_model = GeminiEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
Settings.llm = Gemini(model=model_name)
Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
@@ -14,8 +14,8 @@ fastapi = "^0.109.1"
uvicorn = { extras = ["standard"], version = "^0.23.2" }
python-dotenv = "^1.0.0"
aiostream = "^0.5.2"
llama-index = "0.10.28"
llama-index-core = "0.10.28"
llama-index = "0.10.41"
llama-index-core = "0.10.41"
cachetools = "^5.3.3"
[build-system]
@@ -0,0 +1,11 @@
import { NextResponse } from "next/server";
/**
* This API is to get config from the backend envs and expose them to the frontend
*/
export async function GET() {
const config = {
starterQuestions: process.env.CONVERSATION_STARTERS?.trim().split("\n"),
};
return NextResponse.json(config, { status: 200 });
}
@@ -4,6 +4,7 @@ import {
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Groq,
OpenAI,
OpenAIEmbedding,
Settings,
@@ -28,6 +29,9 @@ export const initSettings = async () => {
case "ollama":
initOllama();
break;
case "groq":
initGroq();
break;
case "anthropic":
initAnthropic();
break;
@@ -45,7 +49,9 @@ export const initSettings = async () => {
function initOpenAI() {
Settings.llm = new OpenAI({
model: process.env.MODEL ?? "gpt-3.5-turbo",
maxTokens: 512,
maxTokens: process.env.LLM_MAX_TOKENS
? Number(process.env.LLM_MAX_TOKENS)
: undefined,
});
Settings.embedModel = new OpenAIEmbedding({
model: process.env.EMBEDDING_MODEL,
@@ -69,6 +75,27 @@ function initOllama() {
});
}
function initGroq() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
const modelMap: Record<string, string> = {
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
};
Settings.llm = new Groq({
model: modelMap[process.env.MODEL!],
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
@@ -1,4 +1,5 @@
import {
JSONValue,
StreamData,
createCallbacksTransformer,
createStreamDataTransformer,
@@ -6,6 +7,8 @@ import {
type AIStreamCallbacksAndOptions,
} from "ai";
import {
MessageContent,
MessageContentDetail,
Metadata,
NodeWithScore,
Response,
@@ -13,29 +16,85 @@ import {
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
import { CsvFile, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
export const convertMessageContent = (
content: string,
annotations?: JSONValue[],
): MessageContent => {
if (!annotations) return content;
return [
{
type: "text",
text: content,
},
...convertAnnotations(annotations),
];
};
const convertAnnotations = (
annotations: JSONValue[],
): MessageContentDetail[] => {
const content: MessageContentDetail[] = [];
annotations.forEach((annotation: JSONValue) => {
// first skip invalid annotation
if (
!(
annotation &&
typeof annotation === "object" &&
"type" in annotation &&
"data" in annotation &&
annotation.data &&
typeof annotation.data === "object"
)
) {
console.log(
"Client sent invalid annotation. Missing data and type",
annotation,
);
return;
}
const { type, data } = annotation;
// convert image
if (type === "image" && "url" in data && typeof data.url === "string") {
content.push({
type: "image_url",
image_url: {
url: data.url,
},
});
}
// convert CSV files to text
if (type === "csv" && "csvFiles" in data && Array.isArray(data.csvFiles)) {
const rawContents = data.csvFiles.map((csv) => {
return "```csv\n" + (csv as CsvFile).content + "\n```";
});
const csvContent =
"Use data from following CSV raw contents:\n" +
rawContents.join("\n\n");
content.push({
type: "text",
text: csvContent,
});
}
});
return content;
};
function createParser(
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
const it = res[Symbol.asyncIterator]();
const trimStartOfStream = trimStartOfStreamHelper();
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
return new ReadableStream<string>({
start() {
appendImageData(data, opts?.image_url);
},
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
@@ -72,10 +131,9 @@ export function LlamaIndexStream(
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
return createParser(response, data)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -1,11 +1,11 @@
import { initObservability } from "@/app/observability";
import { Message, StreamData, StreamingTextResponse } from "ai";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { ChatMessage, Settings } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "./engine/chat";
import { initSettings } from "./engine/settings";
import { LlamaIndexStream } from "./llamaindex-stream";
import { createCallbackManager } from "./stream-helper";
import { LlamaIndexStream, convertMessageContent } from "./llamaindex-stream";
import { createCallbackManager, createStreamTimeout } from "./stream-helper";
initObservability();
initSettings();
@@ -13,29 +13,14 @@ initSettings();
export const runtime = "nodejs";
export const dynamic = "force-dynamic";
const convertMessageContent = (
textMessage: string,
imageUrl: string | undefined,
): MessageContent => {
if (!imageUrl) return textMessage;
return [
{
type: "text",
text: textMessage,
},
{
type: "image_url",
image_url: {
url: imageUrl,
},
},
];
};
export async function POST(request: NextRequest) {
// Init Vercel AI StreamData and timeout
const vercelStreamData = new StreamData();
const streamTimeout = createStreamTimeout(vercelStreamData);
try {
const body = await request.json();
const { messages, data }: { messages: Message[]; data: any } = body;
const { messages }: { messages: Message[] } = body;
const userMessage = messages.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return NextResponse.json(
@@ -49,15 +34,25 @@ export async function POST(request: NextRequest) {
const chatEngine = await createChatEngine();
let annotations = userMessage.annotations;
if (!annotations) {
// the user didn't send any new annotations with the last message
// so use the annotations from the last user message that has annotations
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
annotations = messages
.slice()
.reverse()
.find(
(message) => message.role === "user" && message.annotations,
)?.annotations;
}
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
const userMessageContent = convertMessageContent(
userMessage.content,
data?.imageUrl,
annotations,
);
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
@@ -71,11 +66,7 @@ export async function POST(request: NextRequest) {
});
// Transform LlamaIndex stream to Vercel/AI format
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
const stream = LlamaIndexStream(response, vercelStreamData);
// Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
return new StreamingTextResponse(stream, {}, vercelStreamData);
@@ -89,5 +80,7 @@ export async function POST(request: NextRequest) {
status: 500,
},
);
} finally {
clearTimeout(streamTimeout);
}
}
@@ -7,14 +7,20 @@ import {
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
data.appendMessageAnnotation({
type: "image",
data: {
url: imageUrl,
},
});
function getNodeUrl(metadata: Metadata) {
const url = metadata["URL"];
if (url) return url;
const fileName = metadata["file_name"];
if (!process.env.FILESERVER_URL_PREFIX) {
console.warn(
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
);
return undefined;
}
if (fileName) {
return `${process.env.FILESERVER_URL_PREFIX}/data/${fileName}`;
}
return undefined;
}
export function appendSourceData(
@@ -29,6 +35,7 @@ export function appendSourceData(
...node.node.toMutableJSON(),
id: node.node.id_,
score: node.score ?? null,
url: getNodeUrl(node.node.metadata),
})),
},
});
@@ -65,6 +72,15 @@ export function appendToolData(
});
}
export function createStreamTimeout(stream: StreamData) {
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
const t = setTimeout(() => {
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
stream.close();
}, timeout);
return t;
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
@@ -95,3 +111,10 @@ export function createCallbackManager(stream: StreamData) {
return callbackManager;
}
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
@@ -2,8 +2,10 @@
import { useChat } from "ai/react";
import { ChatInput, ChatMessages } from "./ui/chat";
import { useClientConfig } from "./ui/chat/hooks/use-config";
export default function ChatSection() {
const { chatAPI } = useClientConfig();
const {
messages,
input,
@@ -12,8 +14,10 @@ export default function ChatSection() {
handleInputChange,
reload,
stop,
append,
setInput,
} = useChat({
api: process.env.NEXT_PUBLIC_CHAT_API,
api: chatAPI,
headers: {
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
},
@@ -25,19 +29,22 @@ export default function ChatSection() {
});
return (
<div className="space-y-4 max-w-5xl w-full">
<div className="space-y-4 w-full h-full flex flex-col">
<ChatMessages
messages={messages}
isLoading={isLoading}
reload={reload}
stop={stop}
append={append}
/>
<ChatInput
input={input}
handleSubmit={handleSubmit}
handleInputChange={handleInputChange}
isLoading={isLoading}
multiModal={true}
messages={messages}
append={append}
setInput={setInput}
/>
</div>
);
@@ -7,7 +7,7 @@ export default function Header() {
Get started by editing&nbsp;
<code className="font-mono font-bold">app/page.tsx</code>
</p>
<div className="fixed bottom-0 left-0 flex h-48 w-full items-end justify-center bg-gradient-to-t from-white via-white dark:from-black dark:via-black lg:static lg:h-auto lg:w-auto lg:bg-none">
<div className="fixed bottom-0 left-0 mb-4 flex h-auto w-full items-end justify-center bg-gradient-to-t from-white via-white dark:from-black dark:via-black lg:static lg:w-auto lg:bg-none lg:mb-0">
<a
href="https://www.llamaindex.ai/"
className="flex items-center justify-center font-nunito text-lg font-bold gap-2"
@@ -1,9 +1,14 @@
import { JSONValue } from "ai";
import { useState } from "react";
import { v4 as uuidv4 } from "uuid";
import { MessageAnnotation, MessageAnnotationType } from ".";
import { Button } from "../button";
import FileUploader from "../file-uploader";
import { Input } from "../input";
import UploadCsvPreview from "../upload-csv-preview";
import UploadImagePreview from "../upload-image-preview";
import { ChatHandler } from "./chat.interface";
import { useCsv } from "./hooks/use-csv";
export default function ChatInput(
props: Pick<
@@ -14,18 +19,61 @@ export default function ChatInput(
| "onFileError"
| "handleSubmit"
| "handleInputChange"
> & {
multiModal?: boolean;
},
| "messages"
| "setInput"
| "append"
>,
) {
const [imageUrl, setImageUrl] = useState<string | null>(null);
const { files: csvFiles, upload, remove, reset } = useCsv();
const getAnnotations = () => {
if (!imageUrl && csvFiles.length === 0) return undefined;
const annotations: MessageAnnotation[] = [];
if (imageUrl) {
annotations.push({
type: MessageAnnotationType.IMAGE,
data: { url: imageUrl },
});
}
if (csvFiles.length > 0) {
annotations.push({
type: MessageAnnotationType.CSV,
data: {
csvFiles: csvFiles.map((file) => ({
id: file.id,
content: file.content,
filename: file.filename,
filesize: file.filesize,
})),
},
});
}
return annotations as JSONValue[];
};
// default submit function does not handle including annotations in the message
// so we need to use append function to submit new message with annotations
const handleSubmitWithAnnotations = (
e: React.FormEvent<HTMLFormElement>,
annotations: JSONValue[] | undefined,
) => {
e.preventDefault();
props.append!({
content: props.input,
role: "user",
createdAt: new Date(),
annotations,
});
props.setInput!("");
};
const onSubmit = (e: React.FormEvent<HTMLFormElement>) => {
if (imageUrl) {
props.handleSubmit(e, {
data: { imageUrl: imageUrl },
});
setImageUrl(null);
const annotations = getAnnotations();
if (annotations) {
handleSubmitWithAnnotations(e, annotations);
imageUrl && setImageUrl(null);
csvFiles.length && reset();
return;
}
props.handleSubmit(e);
@@ -33,21 +81,50 @@ export default function ChatInput(
const onRemovePreviewImage = () => setImageUrl(null);
const handleUploadImageFile = async (file: File) => {
const base64 = await new Promise<string>((resolve, reject) => {
const readContent = async (file: File): Promise<string> => {
const content = await new Promise<string>((resolve, reject) => {
const reader = new FileReader();
reader.readAsDataURL(file);
if (file.type.startsWith("image/")) {
reader.readAsDataURL(file);
} else {
reader.readAsText(file);
}
reader.onload = () => resolve(reader.result as string);
reader.onerror = (error) => reject(error);
});
return content;
};
const handleUploadImageFile = async (file: File) => {
const base64 = await readContent(file);
setImageUrl(base64);
};
const handleUploadCsvFile = async (file: File) => {
const content = await readContent(file);
const isSuccess = upload({
id: uuidv4(),
content,
filename: file.name,
filesize: file.size,
});
if (!isSuccess) {
alert("File already exists in the list.");
}
};
const handleUploadFile = async (file: File) => {
try {
if (props.multiModal && file.type.startsWith("image/")) {
if (file.type.startsWith("image/")) {
return await handleUploadImageFile(file);
}
if (file.type === "text/csv") {
if (csvFiles.length > 0) {
alert("You can only upload one csv file at a time.");
return;
}
return await handleUploadCsvFile(file);
}
props.onFileUpload?.(file);
} catch (error: any) {
props.onFileError?.(error.message);
@@ -57,11 +134,24 @@ export default function ChatInput(
return (
<form
onSubmit={onSubmit}
className="rounded-xl bg-white p-4 shadow-xl space-y-4"
className="rounded-xl bg-white p-4 shadow-xl space-y-4 shrink-0"
>
{imageUrl && (
<UploadImagePreview url={imageUrl} onRemove={onRemovePreviewImage} />
)}
{csvFiles.length > 0 && (
<div className="flex gap-4 w-full overflow-auto py-2">
{csvFiles.map((csv) => {
return (
<UploadCsvPreview
key={csv.id}
csv={csv}
onRemove={() => remove(csv)}
/>
);
})}
</div>
)}
<div className="flex w-full items-start justify-between gap-4 ">
<Input
autoFocus
@@ -75,7 +165,7 @@ export default function ChatInput(
onFileUpload={handleUploadFile}
onFileError={props.onFileError}
/>
<Button type="submit" disabled={props.isLoading}>
<Button type="submit" disabled={props.isLoading || !props.input.trim()}>
Send message
</Button>
</div>
@@ -1,12 +1,12 @@
import { ChevronDown, ChevronRight, Loader2 } from "lucide-react";
import { useState } from "react";
import { Button } from "../button";
import { Button } from "../../button";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "../collapsible";
import { EventData } from "./index";
} from "../../collapsible";
import { EventData } from "../index";
export function ChatEvents({
data,
@@ -1,5 +1,5 @@
import Image from "next/image";
import { type ImageData } from "./index";
import { type ImageData } from "../index";
export function ChatImage({ data }: { data: ImageData }) {
return (
@@ -1,13 +1,15 @@
import { Check, Copy } from "lucide-react";
import { useMemo } from "react";
import { Button } from "../button";
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
import { getStaticFileDataUrl } from "../lib/url";
import { SourceData, SourceNode } from "./index";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
import PdfDialog from "./widgets/PdfDialog";
import { Button } from "../../button";
import {
HoverCard,
HoverCardContent,
HoverCardTrigger,
} from "../../hover-card";
import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard";
import { SourceData } from "../index";
import PdfDialog from "../widgets/PdfDialog";
const DATA_SOURCE_FOLDER = "data";
const SCORE_THRESHOLD = 0.3;
function SourceNumberButton({ index }: { index: number }) {
@@ -18,46 +20,11 @@ function SourceNumberButton({ index }: { index: number }) {
);
}
enum NODE_TYPE {
URL,
FILE,
UNKNOWN,
}
type NodeInfo = {
id: string;
type: NODE_TYPE;
path?: string;
url?: string;
};
function getNodeInfo(node: SourceNode): NodeInfo {
if (typeof node.metadata["URL"] === "string") {
const url = node.metadata["URL"];
return {
id: node.id,
type: NODE_TYPE.URL,
path: url,
url,
};
}
if (typeof node.metadata["file_path"] === "string") {
const fileName = node.metadata["file_name"] as string;
const filePath = `${DATA_SOURCE_FOLDER}/${fileName}`;
return {
id: node.id,
type: NODE_TYPE.FILE,
path: node.metadata["file_path"],
url: getStaticFileDataUrl(filePath),
};
}
return {
id: node.id,
type: NODE_TYPE.UNKNOWN,
};
}
export function ChatSources({ data }: { data: SourceData }) {
const sources: NodeInfo[] = useMemo(() => {
// aggregate nodes by url or file_path (get the highest one by score)
@@ -67,8 +34,11 @@ export function ChatSources({ data }: { data: SourceData }) {
.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1))
.forEach((node) => {
const nodeInfo = getNodeInfo(node);
const key = nodeInfo.path ?? nodeInfo.id; // use id as key for UNKNOWN type
const nodeInfo = {
id: node.id,
url: node.url,
};
const key = nodeInfo.url ?? nodeInfo.id; // use id as key for UNKNOWN type
if (!nodesByPath[key]) {
nodesByPath[key] = nodeInfo;
}
@@ -84,13 +54,12 @@ export function ChatSources({ data }: { data: SourceData }) {
<span className="font-semibold">Sources:</span>
<div className="inline-flex gap-1 items-center">
{sources.map((nodeInfo: NodeInfo, index: number) => {
if (nodeInfo.path?.endsWith(".pdf")) {
if (nodeInfo.url?.endsWith(".pdf")) {
return (
<PdfDialog
key={nodeInfo.id}
documentId={nodeInfo.id}
url={nodeInfo.url!}
path={nodeInfo.path}
trigger={<SourceNumberButton index={index} />}
/>
);
@@ -116,16 +85,16 @@ export function ChatSources({ data }: { data: SourceData }) {
function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) {
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
if (nodeInfo.type !== NODE_TYPE.UNKNOWN) {
if (nodeInfo.url) {
// this is a node generated by the web loader or file loader,
// add a link to view its URL and a button to copy the URL to the clipboard
return (
<div className="flex items-center my-2">
<a className="hover:text-blue-900" href={nodeInfo.url} target="_blank">
<span>{nodeInfo.path}</span>
<span>{nodeInfo.url}</span>
</a>
<Button
onClick={() => copyToClipboard(nodeInfo.path!)}
onClick={() => copyToClipboard(nodeInfo.url!)}
size="icon"
variant="ghost"
className="h-12 w-12 shrink-0"
@@ -1,5 +1,5 @@
import { ToolData } from "./index";
import { WeatherCard, WeatherData } from "./widgets/WeatherCard";
import { ToolData } from "../index";
import { WeatherCard, WeatherData } from "../widgets/WeatherCard";
// TODO: If needed, add displaying more tool outputs here
export default function ChatTools({ data }: { data: ToolData }) {
@@ -5,8 +5,8 @@ import { FC, memo } from "react";
import { Prism, SyntaxHighlighterProps } from "react-syntax-highlighter";
import { coldarkDark } from "react-syntax-highlighter/dist/cjs/styles/prism";
import { Button } from "../button";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
import { Button } from "../../button";
import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard";
// TODO: Remove this when @type/react-syntax-highlighter is updated
const SyntaxHighlighter = Prism as unknown as FC<SyntaxHighlighterProps>;
@@ -0,0 +1,13 @@
import UploadCsvPreview from "../../upload-csv-preview";
import { CsvData } from "../index";
export default function CsvContent({ data }: { data: CsvData }) {
if (!data.csvFiles.length) return null;
return (
<div className="flex gap-2 items-center">
{data.csvFiles.map((csv, index) => (
<UploadCsvPreview key={index} csv={csv} />
))}
</div>
);
}
@@ -2,36 +2,31 @@ import { Check, Copy } from "lucide-react";
import { Message } from "ai";
import { Fragment } from "react";
import { Button } from "../button";
import ChatAvatar from "./chat-avatar";
import { ChatEvents } from "./chat-events";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import ChatTools from "./chat-tools";
import { Button } from "../../button";
import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard";
import {
AnnotationData,
CsvData,
EventData,
ImageData,
MessageAnnotation,
MessageAnnotationType,
SourceData,
ToolData,
} from "./index";
getAnnotationData,
} from "../index";
import ChatAvatar from "./chat-avatar";
import { ChatEvents } from "./chat-events";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import ChatTools from "./chat-tools";
import CsvContent from "./csv-content";
import Markdown from "./markdown";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
type ContentDisplayConfig = {
order: number;
component: JSX.Element | null;
};
function getAnnotationData<T extends AnnotationData>(
annotations: MessageAnnotation[],
type: MessageAnnotationType,
): T[] {
return annotations.filter((a) => a.type === type).map((a) => a.data as T);
}
function ChatMessageContent({
message,
isLoading,
@@ -46,6 +41,10 @@ function ChatMessageContent({
annotations,
MessageAnnotationType.IMAGE,
);
const csvData = getAnnotationData<CsvData>(
annotations,
MessageAnnotationType.CSV,
);
const eventData = getAnnotationData<EventData>(
annotations,
MessageAnnotationType.EVENTS,
@@ -61,16 +60,20 @@ function ChatMessageContent({
const contents: ContentDisplayConfig[] = [
{
order: -3,
order: 1,
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
},
{
order: -2,
order: -3,
component:
eventData.length > 0 ? (
<ChatEvents isLoading={isLoading} data={eventData} />
) : null,
},
{
order: 2,
component: csvData[0] ? <CsvContent data={csvData[0]} /> : null,
},
{
order: -1,
component: toolData[0] ? <ChatTools data={toolData[0]} /> : null,
@@ -80,7 +83,7 @@ function ChatMessageContent({
component: <Markdown content={message.content} />,
},
{
order: 1,
order: 3,
component: sourceData[0] ? <ChatSources data={sourceData[0]} /> : null,
},
];
@@ -1,13 +1,19 @@
import { Loader2 } from "lucide-react";
import { useEffect, useRef } from "react";
import { Button } from "../button";
import ChatActions from "./chat-actions";
import ChatMessage from "./chat-message";
import { ChatHandler } from "./chat.interface";
import { useClientConfig } from "./hooks/use-config";
export default function ChatMessages(
props: Pick<ChatHandler, "messages" | "isLoading" | "reload" | "stop">,
props: Pick<
ChatHandler,
"messages" | "isLoading" | "reload" | "stop" | "append"
>,
) {
const { starterQuestions } = useClientConfig();
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
const messageLength = props.messages.length;
const lastMessage = props.messages[messageLength - 1];
@@ -35,11 +41,11 @@ export default function ChatMessages(
}, [messageLength, lastMessage]);
return (
<div className="w-full rounded-xl bg-white p-4 shadow-xl pb-0">
<div
className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4"
ref={scrollableChatContainerRef}
>
<div
className="flex-1 w-full rounded-xl bg-white p-4 shadow-xl relative overflow-y-auto"
ref={scrollableChatContainerRef}
>
<div className="flex flex-col gap-5 divide-y">
{props.messages.map((m, i) => {
const isLoadingMessage = i === messageLength - 1 && props.isLoading;
return (
@@ -56,14 +62,33 @@ export default function ChatMessages(
</div>
)}
</div>
<div className="flex justify-end py-4">
<ChatActions
reload={props.reload}
stop={props.stop}
showReload={showReload}
showStop={showStop}
/>
</div>
{(showReload || showStop) && (
<div className="flex justify-end py-4">
<ChatActions
reload={props.reload}
stop={props.stop}
showReload={showReload}
showStop={showStop}
/>
</div>
)}
{!messageLength && starterQuestions?.length && props.append && (
<div className="absolute bottom-6 left-0 w-full">
<div className="grid grid-cols-2 gap-2 mx-20">
{starterQuestions.map((question, i) => (
<Button
variant="outline"
key={i}
onClick={() =>
props.append!({ role: "user", content: question })
}
>
{question}
</Button>
))}
</div>
</div>
)}
</div>
);
}
@@ -15,4 +15,11 @@ export interface ChatHandler {
stop?: () => void;
onFileUpload?: (file: File) => Promise<void>;
onFileError?: (errMsg: string) => void;
setInput?: (input: string) => void;
append?: (
message: Message | Omit<Message, "id">,
ops?: {
data: any;
},
) => Promise<string | null | undefined>;
}
@@ -0,0 +1,30 @@
"use client";
import { useEffect, useMemo, useState } from "react";
export interface ChatConfig {
chatAPI?: string;
starterQuestions?: string[];
}
export function useClientConfig() {
const API_ROUTE = "/api/chat/config";
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
const [config, setConfig] = useState<ChatConfig>({
chatAPI,
});
const configAPI = useMemo(() => {
const backendOrigin = chatAPI ? new URL(chatAPI).origin : "";
return `${backendOrigin}${API_ROUTE}`;
}, [chatAPI]);
useEffect(() => {
fetch(configAPI)
.then((response) => response.json())
.then((data) => setConfig({ ...data, chatAPI }))
.catch((error) => console.error("Error fetching config", error));
}, [chatAPI, configAPI]);
return config;
}
@@ -0,0 +1,33 @@
"use client";
import { useState } from "react";
import { CsvFile } from "../index";
export function useCsv() {
const [files, setFiles] = useState<CsvFile[]>([]);
const csvEqual = (a: CsvFile, b: CsvFile) => {
if (a.id === b.id) return true;
if (a.filename === b.filename && a.filesize === b.filesize) return true;
return false;
};
const upload = (file: CsvFile) => {
const existedCsv = files.find((f) => csvEqual(f, file));
if (!existedCsv) {
setFiles((prev) => [...prev, file]);
return true;
}
return false;
};
const remove = (file: CsvFile) => {
setFiles((prev) => prev.filter((f) => f.id !== file.id));
};
const reset = () => {
setFiles([]);
};
return { files, upload, remove, reset };
}
@@ -6,6 +6,7 @@ export { type ChatHandler } from "./chat.interface";
export { ChatInput, ChatMessages };
export enum MessageAnnotationType {
CSV = "csv",
IMAGE = "image",
SOURCES = "sources",
EVENTS = "events",
@@ -16,11 +17,23 @@ export type ImageData = {
url: string;
};
export type CsvFile = {
content: string;
filename: string;
filesize: number;
id: string;
};
export type CsvData = {
csvFiles: CsvFile[];
};
export type SourceNode = {
id: string;
metadata: Record<string, unknown>;
score?: number;
text: string;
url?: string;
};
export type SourceData = {
@@ -46,9 +59,21 @@ export type ToolData = {
};
};
export type AnnotationData = ImageData | SourceData | EventData | ToolData;
export type AnnotationData =
| ImageData
| CsvData
| SourceData
| EventData
| ToolData;
export type MessageAnnotation = {
type: MessageAnnotationType;
data: AnnotationData;
};
export function getAnnotationData<T extends AnnotationData>(
annotations: MessageAnnotation[],
type: MessageAnnotationType,
): T[] {
return annotations.filter((a) => a.type === type).map((a) => a.data as T);
}
@@ -12,7 +12,6 @@ import {
export interface PdfDialogProps {
documentId: string;
path: string;
url: string;
trigger: React.ReactNode;
}
@@ -26,13 +25,13 @@ export default function PdfDialog(props: PdfDialogProps) {
<div className="space-y-2">
<DrawerTitle>PDF Content</DrawerTitle>
<DrawerDescription>
File path:{" "}
File URL:{" "}
<a
className="hover:text-blue-900"
href={props.url}
target="_blank"
>
{props.path}
{props.url}
</a>
</DrawerDescription>
</div>
@@ -0,0 +1,90 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="49px" height="67px" viewBox="0 0 49 67" version="1.1"
xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink">
<title>Sheets-icon</title>
<desc>Created with Sketch.</desc>
<defs>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-1"></path>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-3"></path>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-5"></path>
<linearGradient x1="50.0053945%" y1="8.58610612%" x2="50.0053945%" y2="100.013939%" id="linearGradient-7">
<stop stop-color="#263238" stop-opacity="0.2" offset="0%"></stop>
<stop stop-color="#263238" stop-opacity="0.02" offset="100%"></stop>
</linearGradient>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-8"></path>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-10"></path>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-12"></path>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="path-14"></path>
<radialGradient cx="3.16804688%" cy="2.71744318%" fx="3.16804688%" fy="2.71744318%" r="161.248516%" gradientTransform="translate(0.031680,0.027174),scale(1.000000,0.727273),translate(-0.031680,-0.027174)" id="radialGradient-16">
<stop stop-color="#FFFFFF" stop-opacity="0.1" offset="0%"></stop>
<stop stop-color="#FFFFFF" stop-opacity="0" offset="100%"></stop>
</radialGradient>
</defs>
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="Consumer-Apps-Sheets-Large-VD-R8-" transform="translate(-451.000000, -451.000000)">
<g id="Hero" transform="translate(0.000000, 63.000000)">
<g id="Personal" transform="translate(277.000000, 299.000000)">
<g id="Sheets-icon" transform="translate(174.833333, 89.958333)">
<g id="Group">
<g id="Clipped">
<mask id="mask-2" fill="white">
<use xlink:href="#path-1"></use>
</mask>
<g id="SVGID_1_"></g>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L36.9791667,10.3541667 L29.5833333,0 Z" id="Path" fill="#0F9D58" fill-rule="nonzero" mask="url(#mask-2)"></path>
</g>
<g id="Clipped">
<mask id="mask-4" fill="white">
<use xlink:href="#path-3"></use>
</mask>
<g id="SVGID_1_"></g>
<path d="M11.8333333,31.8020833 L11.8333333,53.25 L35.5,53.25 L35.5,31.8020833 L11.8333333,31.8020833 Z M22.1875,50.2916667 L14.7916667,50.2916667 L14.7916667,46.59375 L22.1875,46.59375 L22.1875,50.2916667 Z M22.1875,44.375 L14.7916667,44.375 L14.7916667,40.6770833 L22.1875,40.6770833 L22.1875,44.375 Z M22.1875,38.4583333 L14.7916667,38.4583333 L14.7916667,34.7604167 L22.1875,34.7604167 L22.1875,38.4583333 Z M32.5416667,50.2916667 L25.1458333,50.2916667 L25.1458333,46.59375 L32.5416667,46.59375 L32.5416667,50.2916667 Z M32.5416667,44.375 L25.1458333,44.375 L25.1458333,40.6770833 L32.5416667,40.6770833 L32.5416667,44.375 Z M32.5416667,38.4583333 L25.1458333,38.4583333 L25.1458333,34.7604167 L32.5416667,34.7604167 L32.5416667,38.4583333 Z" id="Shape" fill="#F1F1F1" fill-rule="nonzero" mask="url(#mask-4)"></path>
</g>
<g id="Clipped">
<mask id="mask-6" fill="white">
<use xlink:href="#path-5"></use>
</mask>
<g id="SVGID_1_"></g>
<polygon id="Path" fill="url(#linearGradient-7)" fill-rule="nonzero" mask="url(#mask-6)" points="30.8813021 16.4520313 47.3333333 32.9003646 47.3333333 17.75"></polygon>
</g>
<g id="Clipped">
<mask id="mask-9" fill="white">
<use xlink:href="#path-8"></use>
</mask>
<g id="SVGID_1_"></g>
<g id="Group" mask="url(#mask-9)">
<g transform="translate(26.625000, -2.958333)">
<path d="M2.95833333,2.95833333 L2.95833333,16.2708333 C2.95833333,18.7225521 4.94411458,20.7083333 7.39583333,20.7083333 L20.7083333,20.7083333 L2.95833333,2.95833333 Z" id="Path" fill="#87CEAC" fill-rule="nonzero"></path>
</g>
</g>
</g>
<g id="Clipped">
<mask id="mask-11" fill="white">
<use xlink:href="#path-10"></use>
</mask>
<g id="SVGID_1_"></g>
<path d="M4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,4.80729167 C0,2.36666667 1.996875,0.369791667 4.4375,0.369791667 L29.5833333,0.369791667 L29.5833333,0 L4.4375,0 Z" id="Path" fill-opacity="0.2" fill="#FFFFFF" fill-rule="nonzero" mask="url(#mask-11)"></path>
</g>
<g id="Clipped">
<mask id="mask-13" fill="white">
<use xlink:href="#path-12"></use>
</mask>
<g id="SVGID_1_"></g>
<path d="M42.8958333,64.7135417 L4.4375,64.7135417 C1.996875,64.7135417 0,62.7166667 0,60.2760417 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,60.2760417 C47.3333333,62.7166667 45.3364583,64.7135417 42.8958333,64.7135417 Z" id="Path" fill-opacity="0.2" fill="#263238" fill-rule="nonzero" mask="url(#mask-13)"></path>
</g>
<g id="Clipped">
<mask id="mask-15" fill="white">
<use xlink:href="#path-14"></use>
</mask>
<g id="SVGID_1_"></g>
<path d="M34.0208333,17.75 C31.5691146,17.75 29.5833333,15.7642188 29.5833333,13.3125 L29.5833333,13.6822917 C29.5833333,16.1340104 31.5691146,18.1197917 34.0208333,18.1197917 L47.3333333,18.1197917 L47.3333333,17.75 L34.0208333,17.75 Z" id="Path" fill-opacity="0.1" fill="#263238" fill-rule="nonzero" mask="url(#mask-15)"></path>
</g>
</g>
<path d="M29.5833333,0 L4.4375,0 C1.996875,0 0,1.996875 0,4.4375 L0,60.6458333 C0,63.0864583 1.996875,65.0833333 4.4375,65.0833333 L42.8958333,65.0833333 C45.3364583,65.0833333 47.3333333,63.0864583 47.3333333,60.6458333 L47.3333333,17.75 L29.5833333,0 Z" id="Path" fill="url(#radialGradient-16)" fill-rule="nonzero"></path>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 8.9 KiB

@@ -1,11 +0,0 @@
const staticFileAPI = "/api/files";
export const getStaticFileDataUrl = (filePath: string) => {
const isUsingBackend = !!process.env.NEXT_PUBLIC_CHAT_API;
const fileUrl = `${staticFileAPI}/${filePath}`;
if (isUsingBackend) {
const backendOrigin = new URL(process.env.NEXT_PUBLIC_CHAT_API!).origin;
return `${backendOrigin}${fileUrl}`;
}
return fileUrl;
};
@@ -0,0 +1,93 @@
import { XCircleIcon } from "lucide-react";
import Image from "next/image";
import SheetIcon from "../ui/icons/sheet.svg";
import { Button } from "./button";
import { CsvFile } from "./chat";
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "./drawer";
import { cn } from "./lib/utils";
export interface UploadCsvPreviewProps {
csv: CsvFile;
onRemove?: () => void;
}
export default function UploadCsvPreview(props: UploadCsvPreviewProps) {
const { filename, filesize, content } = props.csv;
return (
<Drawer direction="left">
<DrawerTrigger asChild>
<div>
<CSVSummaryCard {...props} />
</div>
</DrawerTrigger>
<DrawerContent className="w-3/5 mt-24 h-full max-h-[96%] ">
<DrawerHeader className="flex justify-between">
<div className="space-y-2">
<DrawerTitle>Csv Raw Content</DrawerTitle>
<DrawerDescription>
{filename} ({inKB(filesize)} KB)
</DrawerDescription>
</div>
<DrawerClose asChild>
<Button variant="outline">Close</Button>
</DrawerClose>
</DrawerHeader>
<div className="m-4 max-h-[80%] overflow-auto">
<pre className="bg-secondary rounded-md p-4 block text-sm">
{content}
</pre>
</div>
</DrawerContent>
</Drawer>
);
}
function CSVSummaryCard(props: UploadCsvPreviewProps) {
const { onRemove, csv } = props;
return (
<div className="p-2 w-60 max-w-60 bg-secondary rounded-lg text-sm relative cursor-pointer">
<div className="flex flex-row items-center gap-2">
<div className="relative h-10 w-10 shrink-0 overflow-hidden rounded-md">
<Image
className="h-full w-auto"
priority
src={SheetIcon}
alt="SheetIcon"
/>
</div>
<div className="overflow-hidden">
<div className="truncate font-semibold">
{csv.filename} ({inKB(csv.filesize)} KB)
</div>
<div className="truncate text-token-text-tertiary flex items-center gap-2">
<span>Spreadsheet</span>
</div>
</div>
</div>
{onRemove && (
<div
className={cn(
"absolute -top-2 -right-2 w-6 h-6 z-10 bg-gray-500 text-white rounded-full",
)}
>
<XCircleIcon
className="w-6 h-6 bg-gray-500 text-white rounded-full"
onClick={onRemove}
/>
</div>
)}
</div>
);
}
function inKB(size: number) {
return Math.round((size / 1024) * 10) / 10;
}
@@ -74,8 +74,11 @@
* {
@apply border-border;
}
html {
@apply h-full;
}
body {
@apply bg-background text-foreground;
@apply bg-background text-foreground h-full;
font-feature-settings:
"rlig" 1,
"calt" 1;
@@ -3,9 +3,13 @@ import ChatSection from "./components/chat-section";
export default function Home() {
return (
<main className="flex min-h-screen flex-col items-center gap-10 p-24 background-gradient">
<Header />
<ChatSection />
<main className="h-full w-full flex justify-center items-center background-gradient">
<div className="space-y-2 lg:space-y-10 w-[90%] lg:w-[60rem]">
<Header />
<div className="h-[65vh] flex">
<ChatSection />
</div>
</div>
</main>
);
}
@@ -18,7 +18,8 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"dotenv": "^16.3.1",
"llamaindex": "0.3.13",
"duck-duck-scrape": "^2.2.5",
"llamaindex": "0.3.16",
"lucide-react": "^0.294.0",
"next": "^14.0.3",
"pdf2json": "3.0.5",
@@ -35,7 +36,10 @@
"tailwind-merge": "^2.1.0",
"vaul": "^0.9.1",
"@llamaindex/pdf-viewer": "^1.1.1",
"@e2b/code-interpreter": "^0.0.5"
"@e2b/code-interpreter": "^0.0.5",
"uuid": "^9.0.1",
"got": "10.7.0",
"@apidevtools/swagger-parser": "^10.1.0"
},
"devDependencies": {
"@types/node": "^20.10.3",
@@ -52,6 +56,7 @@
"prettier-plugin-organize-imports": "^3.2.4",
"tailwindcss": "^3.3.6",
"tsx": "^4.7.2",
"typescript": "^5.3.2"
"typescript": "^5.3.2",
"@types/uuid": "^9.0.8"
}
}