Compare commits

..

9 Commits

Author SHA1 Message Date
github-actions[bot] 988bfc2a60 Release 0.1.2 (#79)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-10 14:12:31 +07:00
Thuc Pham 056e376ee0 feat: add weather widget and weather tool (#72)
---------
Co-authored-by: leehuwuj <leehuwuj@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-10 14:00:16 +07:00
Thuc Pham 819cccb11a feat: use 3.5 as default model (#77) 2024-05-09 15:48:25 +07:00
Huu Le (Lee) 8a5ece10c2 chores: update wrong example system prompt and fix missing switch breaking (#75) 2024-05-08 10:14:34 +07:00
github-actions[bot] 63bb0505d6 Release 0.1.1 (#60)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-05-03 10:38:01 +07:00
Huu Le (Lee) 2e80ef47ee Fix typo in settings.py (#73) 2024-05-03 10:36:12 +07:00
Marcus Schiesser a1feb524e9 Revert "Use ingestion pipeline in Python code (#61)"
This reverts commit c094b0c6bf.
2024-05-03 11:06:02 +08:00
Marcus Schiesser 06823da849 fix: stream type 2024-05-02 17:25:49 +08:00
Thuc Pham 7bd3ed551f feat: support anthropic and gemini model providers and update to LITS 0.3.3 (#63)
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-05-02 16:13:31 +07:00
76 changed files with 1685 additions and 459 deletions
-5
View File
@@ -1,5 +0,0 @@
---
"create-llama": patch
---
Use ingestion pipeline for Python
-5
View File
@@ -1,5 +0,0 @@
---
"create-llama": patch
---
Display events (e.g. retrieving nodes) per chat message
+14
View File
@@ -1,5 +1,19 @@
# create-llama
## 0.1.2
### Patch Changes
- 056e376: Add support for displaying tool outputs (including weather widget as example)
## 0.1.1
### Patch Changes
- 7bd3ed5: Support Anthropic and Gemini as model providers
- 7bd3ed5: Support new agents from LITS 0.3
- cfb5257: Display events (e.g. retrieving nodes) per chat message
## 0.1.0
### Minor Changes
+19 -7
View File
@@ -173,6 +173,24 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
},
]
: []),
...(modelConfig.provider === "anthropic"
? [
{
name: "ANTHROPIC_API_KEY",
description: "The Anthropic API key to use.",
value: modelConfig.apiKey,
},
]
: []),
...(modelConfig.provider === "gemini"
? [
{
name: "GOOGLE_API_KEY",
description: "The Google API key to use.",
value: modelConfig.apiKey,
},
]
: []),
];
};
@@ -199,13 +217,7 @@ const getFrameworkEnvs = (
name: "SYSTEM_PROMPT",
description: `Custom system prompt.
Example:
SYSTEM_PROMPT="
We have provided context information below.
---------------------
{context_str}
---------------------
Given this information, please answer the question: {query_str}
"`,
SYSTEM_PROMPT="You are a helpful assistant who helps users with their questions."`,
},
];
};
+1 -2
View File
@@ -9,7 +9,6 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
import { PackageManager } from "./get-pkg-manager";
import { installLlamapackProject } from "./llama-pack";
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
import { isModelConfigured } from "./providers";
import { installPythonTemplate } from "./python";
import { downloadAndExtractRepo } from "./repo";
import { ConfigFileType, writeToolsConfig } from "./tools";
@@ -38,7 +37,7 @@ async function generateContextData(
? "poetry run generate"
: `${packageManager} run generate`,
)}`;
const modelConfigured = isModelConfigured(modelConfig);
const modelConfigured = modelConfig.isConfigured();
const llamaCloudKeyConfigured = useLlamaParse
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
: true;
+106
View File
@@ -0,0 +1,106 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = [
"claude-3-opus",
"claude-3-sonnet",
"claude-3-haiku",
"claude-2.1",
"claude-instant-1.2",
];
const DEFAULT_MODEL = MODELS[0];
// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported)
// Use huggingface embedding models for now
enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
}
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
dimensions: 384,
},
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
dimensions: 768,
},
};
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type AnthropicQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askAnthropicQuestions({
askModels,
apiKey,
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["ANTHROPIC_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.ANTHROPIC_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions =
EMBEDDING_MODELS[
embeddingModel as HuggingFaceEmbeddingModelType
].dimensions;
}
return config;
}
+87
View File
@@ -0,0 +1,87 @@
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";
const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"];
type ModelData = {
dimensions: number;
};
const EMBEDDING_MODELS: Record<string, ModelData> = {
"embedding-001": { dimensions: 768 },
"text-embedding-004": { dimensions: 768 },
};
const DEFAULT_MODEL = MODELS[0];
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
type GeminiQuestionsParams = {
apiKey?: string;
askModels: boolean;
};
export async function askGeminiQuestions({
askModels,
apiKey,
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: DEFAULT_DIMENSIONS,
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["GOOGLE_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message:
"Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
},
questionHandlers,
);
config.apiKey = key || process.env.GOOGLE_API_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
}
return config;
}
+11 -10
View File
@@ -2,8 +2,10 @@ import ciInfo from "ci-info";
import prompts from "prompts";
import { questionHandlers } from "../../questions";
import { ModelConfig, ModelProvider } from "../types";
import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
import { askOpenAIQuestions } from "./openai";
const DEFAULT_MODEL_PROVIDER = "openai";
@@ -31,6 +33,8 @@ export async function askModelConfig({
value: "openai",
},
{ title: "Ollama", value: "ollama" },
{ title: "Anthropic", value: "anthropic" },
{ title: "Gemini", value: "gemini" },
],
initial: 0,
},
@@ -44,6 +48,12 @@ export async function askModelConfig({
case "ollama":
modelConfig = await askOllamaQuestions({ askModels });
break;
case "anthropic":
modelConfig = await askAnthropicQuestions({ askModels });
break;
case "gemini":
modelConfig = await askGeminiQuestions({ askModels });
break;
default:
modelConfig = await askOpenAIQuestions({
openAiKey,
@@ -55,12 +65,3 @@ export async function askModelConfig({
provider: modelProvider,
};
}
export function isModelConfigured(modelConfig: ModelConfig): boolean {
switch (modelConfig.provider) {
case "openai":
return isOpenAIConfigured(modelConfig);
default:
return true;
}
}
+3
View File
@@ -29,6 +29,9 @@ export async function askOllamaQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
isConfigured(): boolean {
return true;
},
};
// use default model values in CI or if user should not be asked
+10 -12
View File
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
const OPENAI_API_URL = "https://api.openai.com/v1";
const DEFAULT_MODEL = "gpt-4-turbo";
const DEFAULT_MODEL = "gpt-3.5-turbo";
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
export async function askOpenAIQuestions({
@@ -20,6 +20,15 @@ export async function askOpenAIQuestions({
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
isConfigured(): boolean {
if (config.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
},
};
if (!config.apiKey) {
@@ -31,7 +40,6 @@ export async function askOpenAIQuestions({
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
: "Please provide your OpenAI API key (leave blank to skip):",
validate: (value: string) => {
console.log(value);
if (askModels && !value) {
if (process.env.OPENAI_API_KEY) {
return true;
@@ -78,16 +86,6 @@ export async function askOpenAIQuestions({
return config;
}
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
if (params.apiKey) {
return true;
}
if (process.env["OPENAI_API_KEY"]) {
return true;
}
return false;
}
async function getAvailableModelChoices(
selectEmbedding: boolean,
apiKey?: string,
+64 -33
View File
@@ -24,7 +24,7 @@ interface Dependency {
const getAdditionalDependencies = (
modelConfig: ModelConfig,
vectorDb?: TemplateVectorDB,
dataSource?: TemplateDataSource,
dataSources?: TemplateDataSource[],
tools?: Tool[],
) => {
const dependencies: Dependency[] = [];
@@ -43,6 +43,7 @@ const getAdditionalDependencies = (
name: "llama-index-vector-stores-postgres",
version: "^0.1.1",
});
break;
}
case "pinecone": {
dependencies.push({
@@ -72,38 +73,43 @@ const getAdditionalDependencies = (
}
// Add data source dependencies
const dataSourceType = dataSource?.type;
switch (dataSourceType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
if (dataSources) {
for (const ds of dataSources) {
const dsType = ds?.type;
switch (dsType) {
case "file":
dependencies.push({
name: "docx2txt",
version: "^0.8",
});
break;
case "web":
dependencies.push({
name: "llama-index-readers-web",
version: "^0.1.6",
});
break;
case "db":
dependencies.push({
name: "llama-index-readers-database",
version: "^0.1.3",
});
dependencies.push({
name: "pymysql",
version: "^1.1.0",
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
version: "^2.9.9",
});
break;
}
}
}
// Add tools dependencies
console.log("Adding tools dependencies");
tools?.forEach((tool) => {
tool.dependencies?.forEach((dep) => {
dependencies.push(dep);
@@ -127,6 +133,26 @@ const getAdditionalDependencies = (
version: "0.2.2",
});
break;
case "anthropic":
dependencies.push({
name: "llama-index-llms-anthropic",
version: "0.1.10",
});
dependencies.push({
name: "llama-index-embeddings-huggingface",
version: "0.2.0",
});
break;
case "gemini":
dependencies.push({
name: "llama-index-llms-gemini",
version: "0.1.7",
});
dependencies.push({
name: "llama-index-embeddings-gemini",
version: "0.1.6",
});
break;
}
return dependencies;
@@ -278,9 +304,14 @@ export const installPythonTemplate = async ({
cwd: path.join(compPath, "engines", "python", engine),
});
const addOnDependencies = dataSources
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
.flat();
console.log("Adding additional dependencies");
const addOnDependencies = getAdditionalDependencies(
modelConfig,
vectorDb,
dataSources,
tools,
);
if (observability === "opentelemetry") {
addOnDependencies.push({
+27 -2
View File
@@ -5,12 +5,18 @@ import yaml from "yaml";
import { makeDir } from "./make-dir";
import { TemplateFramework } from "./types";
export enum ToolType {
LLAMAHUB = "llamahub",
LOCAL = "local",
}
export type Tool = {
display: string;
name: string;
config?: Record<string, any>;
dependencies?: ToolDependencies[];
supportedFrameworks?: Array<TemplateFramework>;
type: ToolType;
};
export type ToolDependencies = {
@@ -35,6 +41,7 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi"],
type: ToolType.LLAMAHUB,
},
{
display: "Wikipedia",
@@ -46,6 +53,14 @@ export const supportedTools: Tool[] = [
},
],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LLAMAHUB,
},
{
display: "Weather",
name: "weather",
dependencies: [],
supportedFrameworks: ["fastapi", "express", "nextjs"],
type: ToolType.LOCAL,
},
];
@@ -90,9 +105,19 @@ export const writeToolsConfig = async (
type: ConfigFileType = ConfigFileType.YAML,
) => {
if (tools.length === 0) return; // no tools selected, no config need
const configContent: Record<string, any> = {};
const configContent: {
[key in ToolType]: Record<string, any>;
} = {
local: {},
llamahub: {},
};
tools.forEach((tool) => {
configContent[tool.name] = tool.config ?? {};
if (tool.type === ToolType.LLAMAHUB) {
configContent.llamahub[tool.name] = tool.config ?? {};
}
if (tool.type === ToolType.LOCAL) {
configContent.local[tool.name] = tool.config ?? {};
}
});
const configPath = path.join(root, "config");
await makeDir(configPath);
+2 -1
View File
@@ -1,13 +1,14 @@
import { PackageManager } from "../helpers/get-pkg-manager";
import { Tool } from "./tools";
export type ModelProvider = "openai" | "ollama";
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
export type ModelConfig = {
provider: ModelProvider;
apiKey?: string;
model: string;
embeddingModel: string;
dimensions: number;
isConfigured(): boolean;
};
export type TemplateType = "streaming" | "community" | "llamapack";
export type TemplateFramework = "nextjs" | "express" | "fastapi";
+1 -1
View File
@@ -105,7 +105,7 @@ export const installTSTemplate = async ({
const enginePath = path.join(root, relativeEngineDestPath, "engine");
// copy vector db component
console.log("\nUsing vector DB:", vectorDb, "\n");
console.log("\nUsing vector DB:", vectorDb ?? "none", "\n");
await copy("**", enginePath, {
parents: true,
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "create-llama",
"version": "0.1.0",
"version": "0.1.2",
"description": "Create LlamaIndex-powered apps with one command",
"keywords": [
"rag",
+4 -5
View File
@@ -14,7 +14,7 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
import { EXAMPLE_FILE } from "./helpers/datasources";
import { templatesDir } from "./helpers/dir";
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
import { askModelConfig, isModelConfigured } from "./helpers/providers";
import { askModelConfig } from "./helpers/providers";
import { getProjectOptions } from "./helpers/repo";
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
@@ -257,7 +257,8 @@ export const askQuestions = async (
},
];
const modelConfigured = isModelConfigured(program.modelConfig);
const modelConfigured =
!program.llamapack && program.modelConfig.isConfigured();
// If using LlamaParse, require LlamaCloud API key
const llamaCloudKeyConfigured = program.useLlamaParse
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
@@ -268,8 +269,7 @@ export const askQuestions = async (
!hasVectorDb &&
modelConfigured &&
llamaCloudKeyConfigured &&
!toolsRequireConfig(program.tools) &&
!program.llamapack
!toolsRequireConfig(program.tools)
) {
actionChoices.push({
title:
@@ -398,7 +398,6 @@ export const askQuestions = async (
if (program.framework === "express" || program.framework === "fastapi") {
// if a backend-only framework is selected, ask whether we should create a frontend
// (only for streaming backends)
if (program.frontend === undefined) {
if (ciInfo.isCI) {
program.frontend = getPrefOrDefault("frontend");
@@ -1,35 +0,0 @@
import os
import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolFactory:
@staticmethod
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
try:
tool_package, tool_cls_name = tool_name.split(".")
module_name = f"llama_index.tools.{tool_package}"
module = importlib.import_module(module_name)
tool_class = getattr(module, tool_cls_name)
tool_spec: BaseToolSpec = tool_class(**kwargs)
return tool_spec.to_tool_list()
except (ImportError, AttributeError) as e:
raise ValueError(f"Unsupported tool: {tool_name}") from e
except TypeError as e:
raise ValueError(
f"Could not create tool: {tool_name}. With config: {kwargs}"
) from e
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for name, config in tool_configs.items():
tools += ToolFactory.create_tool(name, **config)
return tools
@@ -0,0 +1,56 @@
import os
import yaml
import importlib
from llama_index.core.tools.tool_spec.base import BaseToolSpec
from llama_index.core.tools.function_tool import FunctionTool
class ToolType:
LLAMAHUB = "llamahub"
LOCAL = "local"
class ToolFactory:
TOOL_SOURCE_PACKAGE_MAP = {
ToolType.LLAMAHUB: "llama_index.tools",
ToolType.LOCAL: "app.engine.tools",
}
@staticmethod
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
try:
if "ToolSpec" in tool_name:
tool_package, tool_cls_name = tool_name.split(".")
module_name = f"{source_package}.{tool_package}"
module = importlib.import_module(module_name)
tool_class = getattr(module, tool_cls_name)
tool_spec: BaseToolSpec = tool_class(**config)
return tool_spec.to_tool_list()
else:
module = importlib.import_module(f"{source_package}.{tool_name}")
tools = getattr(module, "tools")
if not all(isinstance(tool, FunctionTool) for tool in tools):
raise ValueError(
f"The module {module} does not contain valid tools"
)
return tools
except ImportError as e:
raise ValueError(f"Failed to import tool {tool_name}: {e}")
except AttributeError as e:
raise ValueError(f"Failed to load tool {tool_name}: {e}")
@staticmethod
def from_env() -> list[FunctionTool]:
tools = []
if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for tool_type, config_entries in tool_configs.items():
for tool_name, config in config_entries.items():
tools.extend(
ToolFactory.load_tools(tool_type, tool_name, config)
)
return tools
@@ -0,0 +1,72 @@
"""Open Meteo weather map tool spec."""
import logging
import requests
import pytz
from llama_index.core.tools import FunctionTool
logger = logging.getLogger(__name__)
class OpenMeteoWeather:
geo_api = "https://geocoding-api.open-meteo.com/v1"
weather_api = "https://api.open-meteo.com/v1"
@classmethod
def _get_geo_location(cls, location: str) -> dict:
"""Get geo location from location name."""
params = {"name": location, "count": 10, "language": "en", "format": "json"}
response = requests.get(f"{cls.geo_api}/search", params=params)
if response.status_code != 200:
raise Exception(f"Failed to fetch geo location: {response.status_code}")
else:
data = response.json()
result = data["results"][0]
geo_location = {
"id": result["id"],
"name": result["name"],
"latitude": result["latitude"],
"longitude": result["longitude"],
}
return geo_location
@classmethod
def get_weather_information(cls, location: str) -> dict:
"""Use this function to get the weather of any given location.
Note that the weather code should follow WMO Weather interpretation codes (WW):
0: Clear sky
1, 2, 3: Mainly clear, partly cloudy, and overcast
45, 48: Fog and depositing rime fog
51, 53, 55: Drizzle: Light, moderate, and dense intensity
56, 57: Freezing Drizzle: Light and dense intensity
61, 63, 65: Rain: Slight, moderate and heavy intensity
66, 67: Freezing Rain: Light and heavy intensity
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
77: Snow grains
80, 81, 82: Rain showers: Slight, moderate, and violent
85, 86: Snow showers slight and heavy
95: Thunderstorm: Slight or moderate
96, 99: Thunderstorm with slight and heavy hail
"""
logger.info(
f"Calling open-meteo api to get weather information of location: {location}"
)
geo_location = cls._get_geo_location(location)
timezone = pytz.timezone("UTC").zone
params = {
"latitude": geo_location["latitude"],
"longitude": geo_location["longitude"],
"current": "temperature_2m,weather_code",
"hourly": "temperature_2m,weather_code",
"daily": "weather_code",
"timezone": timezone,
}
response = requests.get(f"{cls.weather_api}/forecast", params=params)
if response.status_code != 200:
raise Exception(
f"Failed to fetch weather information: {response.status_code}"
)
return response.json()
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
@@ -1,12 +1,13 @@
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
import fs from "node:fs/promises";
import path from "node:path";
import { getDataSource } from "./index";
import { STORAGE_CACHE_DIR } from "./shared";
import { createLocalTools } from "./tools";
export async function createChatEngine() {
let tools: BaseTool[] = [];
const tools: BaseToolWithCall[] = [];
// Add a query engine tool if we have a data source
// Delete this code if you don't have a data source
@@ -28,7 +29,14 @@ export async function createChatEngine() {
const config = JSON.parse(
await fs.readFile(path.join("config", "tools.json"), "utf8"),
);
tools = tools.concat(await ToolsFactory.createTools(config));
// add local tools from the 'tools' folder (if configured)
const localTools = createLocalTools(config.local);
tools.push(...localTools);
// add tools from LlamaIndexTS (if configured)
const llamaTools = await ToolsFactory.createTools(config.llamahub);
tools.push(...llamaTools);
} catch {}
return new OpenAIAgent({
@@ -0,0 +1,26 @@
import { BaseToolWithCall } from "llamaindex";
import { WeatherTool, WeatherToolParams } from "./weather";
type ToolCreator = (config: unknown) => BaseToolWithCall;
const toolFactory: Record<string, ToolCreator> = {
weather: (config: unknown) => {
return new WeatherTool(config as WeatherToolParams);
},
};
export function createLocalTools(
localConfig: Record<string, unknown>,
): BaseToolWithCall[] {
const tools: BaseToolWithCall[] = [];
Object.keys(localConfig).forEach((key) => {
if (key in toolFactory) {
const toolConfig = localConfig[key];
const tool = toolFactory[key](toolConfig);
tools.push(tool);
}
});
return tools;
}
@@ -0,0 +1,81 @@
import type { JSONSchemaType } from "ajv";
import { BaseTool, ToolMetadata } from "llamaindex";
interface GeoLocation {
id: string;
name: string;
latitude: number;
longitude: number;
}
export type WeatherParameter = {
location: string;
};
export type WeatherToolParams = {
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
};
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
name: "get_weather_information",
description: `
Use this function to get the weather of any given location.
Note that the weather code should follow WMO Weather interpretation codes (WW):
0: Clear sky
1, 2, 3: Mainly clear, partly cloudy, and overcast
45, 48: Fog and depositing rime fog
51, 53, 55: Drizzle: Light, moderate, and dense intensity
56, 57: Freezing Drizzle: Light and dense intensity
61, 63, 65: Rain: Slight, moderate and heavy intensity
66, 67: Freezing Rain: Light and heavy intensity
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
77: Snow grains
80, 81, 82: Rain showers: Slight, moderate, and violent
85, 86: Snow showers slight and heavy
95: Thunderstorm: Slight or moderate
96, 99: Thunderstorm with slight and heavy hail
`,
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The location to get the weather information",
},
},
required: ["location"],
},
};
export class WeatherTool implements BaseTool<WeatherParameter> {
metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;
private getGeoLocation = async (location: string): Promise<GeoLocation> => {
const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`;
const response = await fetch(apiUrl);
const data = await response.json();
const { id, name, latitude, longitude } = data.results[0];
return { id, name, latitude, longitude };
};
private getWeatherByLocation = async (location: string) => {
console.log(
"Calling open-meteo api to get weather information of location:",
location,
);
const { latitude, longitude } = await this.getGeoLocation(location);
const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&current=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`;
const response = await fetch(apiUrl);
const data = await response.json();
return data;
};
constructor(params?: WeatherToolParams) {
this.metadata = params?.metadata || DEFAULT_META_DATA;
}
async call(input: WeatherParameter) {
return await this.getWeatherByLocation(input.location);
}
}
+4 -1
View File
@@ -27,7 +27,10 @@ def llama_parse_parser():
def get_file_documents(config: FileLoaderConfig):
from llama_index.core.readers import SimpleDirectoryReader
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
reader = SimpleDirectoryReader(
config.data_dir,
recursive=True,
)
if config.use_llama_parse:
parser = llama_parse_parser()
reader.file_extractor = {".pdf": parser}
@@ -0,0 +1,37 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.storage import StorageContext
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.astra_db import AstraDBVectorStore
from app.settings import init_settings
from app.engine.loaders import get_documents
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
documents = get_documents()
store = AstraDBVectorStore(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
collection_name=os.environ["ASTRA_DB_COLLECTION"],
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
)
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(f"Successfully created embeddings in the AstraDB")
if __name__ == "__main__":
generate_datasource()
@@ -1,12 +1,21 @@
import logging
import os
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.astra_db import AstraDBVectorStore
def get_vector_store():
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to index from AstraDB...")
store = AstraDBVectorStore(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
collection_name=os.environ["ASTRA_DB_COLLECTION"],
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
)
return store
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to index from AstraDB.")
return index
@@ -0,0 +1,39 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.storage import StorageContext
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.milvus import MilvusVectorStore
from app.settings import init_settings
from app.engine.loaders import get_documents
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
# load the documents and create the index
documents = get_documents()
store = MilvusVectorStore(
uri=os.environ["MILVUS_ADDRESS"],
user=os.getenv("MILVUS_USERNAME"),
password=os.getenv("MILVUS_PASSWORD"),
collection_name=os.getenv("MILVUS_COLLECTION"),
dim=int(os.getenv("EMBEDDING_DIM")),
)
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(f"Successfully created embeddings in the Milvus")
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,22 @@
import logging
import os
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.milvus import MilvusVectorStore
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to index from Milvus...")
store = MilvusVectorStore(
uri=os.getenv("MILVUS_ADDRESS"),
user=os.getenv("MILVUS_USERNAME"),
password=os.getenv("MILVUS_PASSWORD"),
collection_name=os.getenv("MILVUS_COLLECTION"),
dim=int(os.getenv("EMBEDDING_DIM")),
)
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to index from Milvus.")
return index
@@ -1,13 +0,0 @@
import os
from llama_index.vector_stores.milvus import MilvusVectorStore
def get_vector_store():
store = MilvusVectorStore(
uri=os.environ["MILVUS_ADDRESS"],
user=os.getenv("MILVUS_USERNAME"),
password=os.getenv("MILVUS_PASSWORD"),
collection_name=os.getenv("MILVUS_COLLECTION"),
dim=int(os.getenv("EMBEDDING_DIM")),
)
return store
@@ -0,0 +1,43 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.storage import StorageContext
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
from app.settings import init_settings
from app.engine.loaders import get_documents
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
# load the documents and create the index
documents = get_documents()
store = MongoDBAtlasVectorSearch(
db_name=os.environ["MONGODB_DATABASE"],
collection_name=os.environ["MONGODB_VECTORS"],
index_name=os.environ["MONGODB_VECTOR_INDEX"],
)
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(
f"Successfully created embeddings in the MongoDB collection {os.environ['MONGODB_VECTORS']}"
)
logger.info(
"""IMPORTANT: You can't query your index yet because you need to create a vector search index in MongoDB's UI now.
See https://github.com/run-llama/mongodb-demo/tree/main?tab=readme-ov-file#create-a-vector-search-index"""
)
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,20 @@
import logging
import os
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to index from MongoDB...")
store = MongoDBAtlasVectorSearch(
db_name=os.environ["MONGODB_DATABASE"],
collection_name=os.environ["MONGODB_VECTORS"],
index_name=os.environ["MONGODB_VECTOR_INDEX"],
)
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to index from MongoDB.")
return index
@@ -1,11 +0,0 @@
import os
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
def get_vector_store():
store = MongoDBAtlasVectorSearch(
db_name=os.environ["MONGODB_DATABASE"],
collection_name=os.environ["MONGODB_VECTORS"],
index_name=os.environ["MONGODB_VECTOR_INDEX"],
)
return store
@@ -0,0 +1 @@
STORAGE_DIR = "storage" # directory to cache the generated index
@@ -0,0 +1,32 @@
from dotenv import load_dotenv
load_dotenv()
import logging
from llama_index.core.indices import (
VectorStoreIndex,
)
from app.engine.constants import STORAGE_DIR
from app.engine.loaders import get_documents
from app.settings import init_settings
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
# load the documents and create the index
documents = get_documents()
index = VectorStoreIndex.from_documents(
documents,
)
# store it for later
index.storage_context.persist(STORAGE_DIR)
logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,20 @@
import logging
import os
from app.engine.constants import STORAGE_DIR
from llama_index.core.storage import StorageContext
from llama_index.core.indices import load_index_from_storage
logger = logging.getLogger("uvicorn")
def get_index():
# check if storage already exists
if not os.path.exists(STORAGE_DIR):
return None
# load the existing index
logger.info(f"Loading index from {STORAGE_DIR}...")
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
index = load_index_from_storage(storage_context)
logger.info(f"Finished loading index from {STORAGE_DIR}")
return index
@@ -1,16 +0,0 @@
import os
from llama_index.core.vector_stores import SimpleVectorStore
from app.constants import STORAGE_DIR
def get_vector_store():
if not os.path.exists(STORAGE_DIR):
# Vector store hasn't been persisted before, create a new one
vector_store = SimpleVectorStore()
else:
# Vector store has already been persisted before at STORAGE_DIR - load it
vector_store = SimpleVectorStore.from_persist_dir(
STORAGE_DIR, namespace="default"
)
return vector_store
@@ -0,0 +1,2 @@
PGVECTOR_SCHEMA = "public"
PGVECTOR_TABLE = "llamaindex_embedding"
@@ -0,0 +1,35 @@
from dotenv import load_dotenv
load_dotenv()
import logging
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.storage import StorageContext
from app.engine.loaders import get_documents
from app.settings import init_settings
from app.engine.utils import init_pg_vector_store_from_env
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
# load the documents and create the index
documents = get_documents()
store = init_pg_vector_store_from_env()
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(
f"Successfully created embeddings in the PG vector store, schema={store.schema_name} table={store.table_name}"
)
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,13 @@
import logging
from llama_index.core.indices.vector_store import VectorStoreIndex
from app.engine.utils import init_pg_vector_store_from_env
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to index from PGVector...")
store = init_pg_vector_store_from_env()
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to index from PGVector.")
return index
@@ -1,13 +1,10 @@
import os
from llama_index.vector_stores.postgres import PGVectorStore
from urllib.parse import urlparse
STORAGE_DIR = "storage"
PGVECTOR_SCHEMA = "public"
PGVECTOR_TABLE = "llamaindex_embedding"
from app.engine.constants import PGVECTOR_SCHEMA, PGVECTOR_TABLE
def get_vector_store():
def init_pg_vector_store_from_env():
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
if original_conn_string is None or original_conn_string == "":
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
@@ -27,5 +24,4 @@ def get_vector_store():
async_connection_string=async_conn_string,
schema_name=PGVECTOR_SCHEMA,
table_name=PGVECTOR_TABLE,
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
)
@@ -0,0 +1,39 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.storage import StorageContext
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.pinecone import PineconeVectorStore
from app.settings import init_settings
from app.engine.loaders import get_documents
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index")
# load the documents and create the index
documents = get_documents()
store = PineconeVectorStore(
api_key=os.environ["PINECONE_API_KEY"],
index_name=os.environ["PINECONE_INDEX_NAME"],
environment=os.environ["PINECONE_ENVIRONMENT"],
)
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(
f"Successfully created embeddings and save to your Pinecone index {os.environ['PINECONE_INDEX_NAME']}"
)
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,20 @@
import logging
import os
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.pinecone import PineconeVectorStore
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to index from Pinecone...")
store = PineconeVectorStore(
api_key=os.environ["PINECONE_API_KEY"],
index_name=os.environ["PINECONE_INDEX_NAME"],
environment=os.environ["PINECONE_ENVIRONMENT"],
)
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to index from Pinecone.")
return index
@@ -1,11 +0,0 @@
import os
from llama_index.vector_stores.pinecone import PineconeVectorStore
def get_vector_store():
store = PineconeVectorStore(
api_key=os.environ["PINECONE_API_KEY"],
index_name=os.environ["PINECONE_INDEX_NAME"],
environment=os.environ["PINECONE_ENVIRONMENT"],
)
return store
@@ -0,0 +1,37 @@
import logging
import os
from app.engine.loaders import get_documents
from app.settings import init_settings
from dotenv import load_dotenv
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.storage import StorageContext
from llama_index.vector_stores.qdrant import QdrantVectorStore
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def generate_datasource():
init_settings()
logger.info("Creating new index with Qdrant")
# load the documents and create the index
documents = get_documents()
store = QdrantVectorStore(
collection_name=os.getenv("QDRANT_COLLECTION"),
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
storage_context = StorageContext.from_defaults(vector_store=store)
VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True, # this will show you a progress bar as the embeddings are created
)
logger.info(
f"Successfully uploaded documents to the {os.getenv('QDRANT_COLLECTION')} collection."
)
if __name__ == "__main__":
generate_datasource()
@@ -0,0 +1,20 @@
import logging
import os
from llama_index.core.indices import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Connecting to Qdrant collection..")
store = QdrantVectorStore(
collection_name=os.getenv("QDRANT_COLLECTION"),
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
index = VectorStoreIndex.from_vector_store(store)
logger.info("Finished connecting to Qdrant collection.")
return index
@@ -1,11 +0,0 @@
import os
from llama_index.vector_stores.qdrant import QdrantVectorStore
def get_vector_store():
store = QdrantVectorStore(
collection_name=os.getenv("QDRANT_COLLECTION"),
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
return store
@@ -1,4 +1,5 @@
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
import { VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import * as dotenv from "dotenv";
@@ -1,8 +1,5 @@
import {
SimpleDocumentStore,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
import { STORAGE_CACHE_DIR } from "./shared";
export async function getDataSource() {
@@ -14,7 +14,9 @@
"cors": "^2.8.5",
"dotenv": "^16.3.1",
"express": "^4.18.2",
"llamaindex": "0.2.10"
"llamaindex": "0.3.7",
"pdf2json": "3.0.5",
"ajv": "^8.12.0"
},
"devDependencies": {
"@types/cors": "^2.8.16",
@@ -3,7 +3,7 @@ import { Request, Response } from "express";
import { ChatMessage, MessageContent, Settings } from "llamaindex";
import { createChatEngine } from "./engine/chat";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
import { createCallbackManager } from "./stream-helper";
const convertMessageContent = (
textMessage: string,
@@ -45,46 +45,28 @@ export const chat = async (req: Request, res: Response) => {
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
});
// Return a stream, which can be consumed by the Vercel/AI client
const { stream } = LlamaIndexStream(response, vercelStreamData, {
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
});
// Pipe LlamaIndexStream to response
const processedStream = stream.pipeThrough(vercelStreamData.stream);
return streamToResponse(processedStream, res, {
headers: {
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
// so that the client uses the correct parsing logic, see
// https://sdk.vercel.ai/docs/api-reference/stream-data#on-the-server
"X-Experimental-Stream-Data": "true",
"Content-Type": "text/plain; charset=utf-8",
"Access-Control-Expose-Headers": "X-Experimental-Stream-Data",
},
});
return streamToResponse(processedStream, res);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
@@ -1,10 +1,17 @@
import {
Ollama,
OllamaEmbedding,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -38,11 +56,6 @@ function initOpenAI() {
}
function initOllama() {
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
});
@@ -50,3 +63,25 @@ function initOllama() {
model: process.env.EMBEDDING_MODEL ?? "",
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
@@ -9,16 +9,22 @@ import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
ToolCallLLMMessageOptions,
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<Response>,
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
@@ -33,17 +39,27 @@ function createParser(
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
controller.close();
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
}
const text = trimStartOfStream(value.response ?? "");
const text = trimStartOfStream(delta ?? "");
if (text) {
controller.enqueue(text);
}
@@ -52,21 +68,14 @@ function createParser(
}
export function LlamaIndexStream(
response: StreamingAgentChatResponse | AsyncIterable<Response>,
response: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -1,5 +1,11 @@
import { StreamData } from "ai";
import { Metadata, NodeWithScore } from "llamaindex";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
export function appendToolData(
data: StreamData,
toolCall: ToolCall,
toolOutput: ToolOutput,
) {
data.appendMessageAnnotation({
type: "tools",
data: {
toolCall: {
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
},
toolOutput: {
output: toolOutput.output,
isError: toolOutput.isError,
},
},
});
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
callbackManager.on("retrieve", (data) => {
const { nodes, query } = data.detail;
appendEventData(stream, `Retrieving context for query: '${query}'`);
appendEventData(
stream,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
callbackManager.on("llm-tool-call", (event) => {
const { name, input } = event.detail.payload.toolCall;
const inputString = Object.entries(input)
.map(([key, value]) => `${key}: ${value}`)
.join(", ");
appendEventData(
stream,
`Using tool: '${name}' with inputs: '${inputString}'`,
);
});
callbackManager.on("llm-tool-result", (event) => {
const { toolCall, toolResult } = event.detail.payload;
appendToolData(stream, toolCall, toolResult);
});
return callbackManager;
}
@@ -1,10 +1,7 @@
from pydantic import BaseModel
from typing import List, Any, Optional, Dict, Tuple
from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
@@ -109,12 +106,9 @@ async def chat(
# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
yield VercelStreamResponse.convert_data(
{
"type": "events",
"data": {"title": event.get_title()},
}
)
event_response = event.to_response()
if event_response is not None:
yield VercelStreamResponse.convert_data(event_response)
combine = stream.merge(_text_generator(), _event_generator())
async with combine.stream() as streamer:
@@ -1,8 +1,9 @@
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, List, Optional
from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.tools.types import ToolOutput
from pydantic import BaseModel
@@ -11,19 +12,73 @@ class CallbackEvent(BaseModel):
payload: Optional[Dict[str, Any]] = None
event_id: str = ""
def get_title(self) -> str | None:
# Return as None for the unhandled event types
# to avoid showing them in the UI
def get_retrieval_message(self) -> dict | None:
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
msg = f"Retrieved {len(nodes)} sources to use as context for the query"
else:
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'"
return {
"type": "events",
"data": {"title": msg},
}
else:
return None
def get_tool_message(self) -> dict | None:
func_call_args = self.payload.get("function_call")
if func_call_args is not None and "tool" in self.payload:
tool = self.payload.get("tool")
return {
"type": "events",
"data": {
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}",
},
}
def _is_output_serializable(self, output: Any) -> bool:
try:
json.dumps(output)
return True
except TypeError:
return False
def get_agent_tool_response(self) -> dict | None:
response = self.payload.get("response")
if response is not None:
sources = response.sources
for source in sources:
# Return the tool response here to include the toolCall information
if isinstance(source, ToolOutput):
if self._is_output_serializable(source.raw_output):
output = source.raw_output
else:
output = source.content
return {
"type": "tools",
"data": {
"toolOutput": {
"output": output,
"isError": source.is_error,
},
"toolCall": {
"id": None, # There is no tool id in the ToolOutput
"name": source.tool_name,
"input": source.raw_input,
},
},
}
def to_response(self):
match self.event_type:
case "retrieve":
if self.payload:
nodes = self.payload.get("nodes")
if nodes:
return f"Retrieved {len(nodes)} sources to use as context for the query"
else:
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
else:
return None
return self.get_retrieval_message()
case "function_call":
return self.get_tool_message()
case "agent_step":
return self.get_agent_tool_response()
case _:
return None
@@ -54,7 +109,7 @@ class EventCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> str:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def on_event_end(
@@ -65,7 +120,7 @@ class EventCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> None:
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
if event.get_title() is not None:
if event.to_response() is not None:
self._aqueue.put_nowait(event)
def start_trace(self, trace_id: Optional[str] = None) -> None:
@@ -1 +0,0 @@
STORAGE_DIR = "storage" # directory to save the stores to (document store and if used, the `SimpleVectorStore`)
@@ -1,96 +0,0 @@
from dotenv import load_dotenv
load_dotenv()
import os
import logging
from llama_index.core.settings import Settings
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import SimpleVectorStore
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage import StorageContext
from llama_index.core import VectorStoreIndex
from app.constants import STORAGE_DIR
from app.settings import init_settings
from app.engine.loaders import get_documents
from app.engine.vectordb import get_vector_store
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
def get_doc_store():
if not os.path.exists(STORAGE_DIR):
docstore = SimpleDocumentStore()
return docstore
else:
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
def run_ingestion_pipeline(docstore, vector_store, documents):
# Create ingestion pipeline
ingestion_pipeline = IngestionPipeline(
transformations=[
SentenceSplitter(
chunk_size=Settings.chunk_size,
chunk_overlap=Settings.chunk_overlap,
),
Settings.embed_model,
],
docstore=docstore,
docstore_strategy="upserts_and_delete",
)
# llama_index having an typing issue when passing vector_store to IngestionPipeline
# so we need to set it manually after initialization
ingestion_pipeline.vector_store = vector_store
# Run the ingestion pipeline and store the results
nodes = ingestion_pipeline.run(show_progress=True, documents=documents)
return nodes
def persist_storage(docstore, vector_store, nodes):
storage_context = StorageContext.from_defaults(
docstore=docstore,
vector_store=vector_store,
)
# SimpleVectorStore does not include index by default
# so we need to create the index manually
# can be removed if using other vector store
if isinstance(vector_store, SimpleVectorStore):
VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
store_nodes_override=True, # Need enable this to store the nodes and index's id
)
storage_context.persist(STORAGE_DIR)
def generate_datasource():
init_settings()
logger.info("Generate index for the provided data")
# Get the stores and documents or create new ones
documents = get_documents()
docstore = get_doc_store()
vector_store = get_vector_store()
# Run the ingestion pipeline
nodes = run_ingestion_pipeline(
docstore=docstore,
vector_store=vector_store,
documents=documents,
)
# Build the index and persist storage
persist_storage(docstore, vector_store, nodes)
logger.info("Finished generating the index")
if __name__ == "__main__":
generate_datasource()
@@ -1,27 +0,0 @@
import logging
from llama_index.core import load_index_from_storage
from llama_index.core.storage import StorageContext
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.vector_stores.simple import SimpleVectorStore
from app.constants import STORAGE_DIR
from app.engine.vectordb import get_vector_store
logger = logging.getLogger("uvicorn")
def get_index():
logger.info("Loading the index...")
store = get_vector_store()
# If the store is a SimpleVectorStore, we need to load the index from the storage
if isinstance(store, SimpleVectorStore):
index = load_index_from_storage(
StorageContext.from_defaults(
vector_store=store,
persist_dir=STORAGE_DIR,
)
)
else:
index = VectorStoreIndex.from_vector_store(store)
logger.info("Loaded index successfully.")
return index
@@ -9,6 +9,10 @@ def init_settings():
init_openai()
elif model_provider == "ollama":
init_ollama()
elif model_provider == "anthropic":
init_anthropic()
elif model_provider == "gemini":
init_gemini()
else:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
@@ -42,3 +46,47 @@ def init_openai():
"dimensions": int(dimensions) if dimensions is not None else None,
}
Settings.embed_model = OpenAIEmbedding(**config)
def init_anthropic():
from llama_index.llms.anthropic import Anthropic
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
model_map: Dict[str, str] = {
"claude-3-opus": "claude-3-opus-20240229",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
"claude-2.1": "claude-2.1",
"claude-instant-1.2": "claude-instant-1.2",
}
embed_model_map: Dict[str, str] = {
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
}
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
Settings.embed_model = HuggingFaceEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
def init_gemini():
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
model_map: Dict[str, str] = {
"gemini-1.5-pro-latest": "models/gemini-1.5-pro-latest",
"gemini-pro": "models/gemini-pro",
"gemini-pro-vision": "models/gemini-pro-vision",
}
embed_model_map: Dict[str, str] = {
"embedding-001": "models/embedding-001",
"text-embedding-004": "models/text-embedding-004",
}
Settings.llm = Gemini(model=model_map[os.getenv("MODEL")])
Settings.embed_model = GeminiEmbedding(
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
)
@@ -1,10 +1,17 @@
import {
Ollama,
OllamaEmbedding,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
OpenAI,
OpenAIEmbedding,
Settings,
} from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
import { Ollama } from "llamaindex/llm/ollama";
const CHUNK_SIZE = 512;
const CHUNK_OVERLAP = 20;
@@ -12,10 +19,21 @@ const CHUNK_OVERLAP = 20;
export const initSettings = async () => {
// HINT: you can delete the initialization code for unused model providers
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
}
switch (process.env.MODEL_PROVIDER) {
case "ollama":
initOllama();
break;
case "anthropic":
initAnthropic();
break;
case "gemini":
initGemini();
break;
default:
initOpenAI();
break;
@@ -38,11 +56,6 @@ function initOpenAI() {
}
function initOllama() {
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
throw new Error(
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
);
}
Settings.llm = new Ollama({
model: process.env.MODEL ?? "",
});
@@ -50,3 +63,25 @@ function initOllama() {
model: process.env.EMBEDDING_MODEL ?? "",
});
}
function initAnthropic() {
const embedModelMap: Record<string, string> = {
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};
Settings.llm = new Anthropic({
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
});
Settings.embedModel = new HuggingFaceEmbedding({
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
});
}
function initGemini() {
Settings.llm = new Gemini({
model: process.env.MODEL as GEMINI_MODEL,
});
Settings.embedModel = new GeminiEmbedding({
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
});
}
@@ -9,16 +9,22 @@ import {
Metadata,
NodeWithScore,
Response,
StreamingAgentChatResponse,
ToolCallLLMMessageOptions,
} from "llamaindex";
import { AgentStreamChatResponse } from "llamaindex/agent/base";
import { appendImageData, appendSourceData } from "./stream-helper";
type LlamaIndexResponse =
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
| Response;
type ParserOptions = {
image_url?: string;
};
function createParser(
res: AsyncIterable<Response>,
res: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: ParserOptions,
) {
@@ -33,17 +39,27 @@ function createParser(
async pull(controller): Promise<void> {
const { value, done } = await it.next();
if (done) {
appendSourceData(data, sourceNodes);
if (sourceNodes) {
appendSourceData(data, sourceNodes);
}
controller.close();
data.close();
return;
}
if (!sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
let delta;
if (value instanceof Response) {
// handle Response type
if (value.sourceNodes) {
// get source nodes from the first response
sourceNodes = value.sourceNodes;
}
delta = value.response ?? "";
} else {
// handle other types
delta = value.response.delta;
}
const text = trimStartOfStream(value.response ?? "");
const text = trimStartOfStream(delta ?? "");
if (text) {
controller.enqueue(text);
}
@@ -52,21 +68,14 @@ function createParser(
}
export function LlamaIndexStream(
response: StreamingAgentChatResponse | AsyncIterable<Response>,
response: AsyncIterable<LlamaIndexResponse>,
data: StreamData,
opts?: {
callbacks?: AIStreamCallbacksAndOptions;
parserOptions?: ParserOptions;
},
): { stream: ReadableStream; data: StreamData } {
const res =
response instanceof StreamingAgentChatResponse
? response.response
: response;
return {
stream: createParser(res, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer()),
data,
};
): ReadableStream<Uint8Array> {
return createParser(response, data, opts?.parserOptions)
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
.pipeThrough(createStreamDataTransformer());
}
@@ -5,7 +5,7 @@ import { NextRequest, NextResponse } from "next/server";
import { createChatEngine } from "./engine/chat";
import { initSettings } from "./engine/settings";
import { LlamaIndexStream } from "./llamaindex-stream";
import { appendEventData } from "./stream-helper";
import { createCallbackManager } from "./stream-helper";
initObservability();
initSettings();
@@ -57,29 +57,21 @@ export async function POST(request: NextRequest) {
// Init Vercel AI StreamData
const vercelStreamData = new StreamData();
appendEventData(
vercelStreamData,
`Retrieving context for query: '${userMessage.content}'`,
);
// Setup callback for streaming data before chatting
Settings.callbackManager.on("retrieve", (data) => {
const { nodes } = data.detail;
appendEventData(
vercelStreamData,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
// Setup callbacks
const callbackManager = createCallbackManager(vercelStreamData);
// Calling LlamaIndex's ChatEngine to get a streamed response
const response = await chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
const response = await Settings.withCallbackManager(callbackManager, () => {
return chatEngine.chat({
message: userMessageContent,
chatHistory: messages as ChatMessage[],
stream: true,
});
});
// Transform LlamaIndex stream to Vercel/AI format
const { stream } = LlamaIndexStream(response, vercelStreamData, {
const stream = LlamaIndexStream(response, vercelStreamData, {
parserOptions: {
image_url: data?.imageUrl,
},
@@ -1,5 +1,11 @@
import { StreamData } from "ai";
import { Metadata, NodeWithScore } from "llamaindex";
import {
CallbackManager,
Metadata,
NodeWithScore,
ToolCall,
ToolOutput,
} from "llamaindex";
export function appendImageData(data: StreamData, imageUrl?: string) {
if (!imageUrl) return;
@@ -37,3 +43,55 @@ export function appendEventData(data: StreamData, title?: string) {
},
});
}
export function appendToolData(
data: StreamData,
toolCall: ToolCall,
toolOutput: ToolOutput,
) {
data.appendMessageAnnotation({
type: "tools",
data: {
toolCall: {
id: toolCall.id,
name: toolCall.name,
input: toolCall.input,
},
toolOutput: {
output: toolOutput.output,
isError: toolOutput.isError,
},
},
});
}
export function createCallbackManager(stream: StreamData) {
const callbackManager = new CallbackManager();
callbackManager.on("retrieve", (data) => {
const { nodes, query } = data.detail;
appendEventData(stream, `Retrieving context for query: '${query}'`);
appendEventData(
stream,
`Retrieved ${nodes.length} sources to use as context for the query`,
);
});
callbackManager.on("llm-tool-call", (event) => {
const { name, input } = event.detail.payload.toolCall;
const inputString = Object.entries(input)
.map(([key, value]) => `${key}: ${value}`)
.join(", ");
appendEventData(
stream,
`Using tool: '${name}' with inputs: '${inputString}'`,
);
});
callbackManager.on("llm-tool-result", (event) => {
const { toolCall, toolResult } = event.detail.payload;
appendToolData(stream, toolCall, toolResult);
});
return callbackManager;
}
@@ -17,7 +17,8 @@ export default function ChatSection() {
headers: {
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
},
onError: (error) => {
onError: (error: unknown) => {
if (!(error instanceof Error)) throw error;
const message = JSON.parse(error.message);
alert(message.detail);
},
@@ -7,6 +7,7 @@ import ChatAvatar from "./chat-avatar";
import { ChatEvents } from "./chat-events";
import { ChatImage } from "./chat-image";
import { ChatSources } from "./chat-sources";
import ChatTools from "./chat-tools";
import {
AnnotationData,
EventData,
@@ -14,6 +15,7 @@ import {
MessageAnnotation,
MessageAnnotationType,
SourceData,
ToolData,
} from "./index";
import Markdown from "./markdown";
import { useCopyToClipboard } from "./use-copy-to-clipboard";
@@ -52,19 +54,27 @@ function ChatMessageContent({
annotations,
MessageAnnotationType.SOURCES,
);
const toolData = getAnnotationData<ToolData>(
annotations,
MessageAnnotationType.TOOLS,
);
const contents: ContentDisplayConfig[] = [
{
order: -2,
order: -3,
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
},
{
order: -1,
order: -2,
component:
eventData.length > 0 ? (
<ChatEvents isLoading={isLoading} data={eventData} />
) : null,
},
{
order: -1,
component: toolData[0] ? <ChatTools data={toolData[0]} /> : null,
},
{
order: 0,
component: <Markdown content={message.content} />,
@@ -40,9 +40,16 @@ export default function ChatMessages(
className="flex h-[50vh] flex-col gap-5 divide-y overflow-y-auto pb-4"
ref={scrollableChatContainerRef}
>
{props.messages.map((m) => (
<ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} />
))}
{props.messages.map((m, i) => {
const isLoadingMessage = i === messageLength - 1 && props.isLoading;
return (
<ChatMessage
key={m.id}
chatMessage={m}
isLoading={isLoadingMessage}
/>
);
})}
{isPending && (
<div className="flex justify-center items-center pt-10">
<Loader2 className="h-4 w-4 animate-spin" />
@@ -0,0 +1,26 @@
import { ToolData } from "./index";
import { WeatherCard, WeatherData } from "./widgets/WeatherCard";
// TODO: If needed, add displaying more tool outputs here
export default function ChatTools({ data }: { data: ToolData }) {
if (!data) return null;
const { toolCall, toolOutput } = data;
if (toolOutput.isError) {
return (
<div className="border-l-2 border-red-400 pl-2">
There was an error when calling the tool {toolCall.name} with input:{" "}
<br />
{JSON.stringify(toolCall.input)}
</div>
);
}
switch (toolCall.name) {
case "get_weather_information":
const weatherData = toolOutput.output as unknown as WeatherData;
return <WeatherCard data={weatherData} />;
default:
return null;
}
}
@@ -1,3 +1,4 @@
import { JSONValue } from "ai";
import ChatInput from "./chat-input";
import ChatMessages from "./chat-messages";
@@ -8,6 +9,7 @@ export enum MessageAnnotationType {
IMAGE = "image",
SOURCES = "sources",
EVENTS = "events",
TOOLS = "tools",
}
export type ImageData = {
@@ -30,7 +32,21 @@ export type EventData = {
isCollapsed: boolean;
};
export type AnnotationData = ImageData | SourceData | EventData;
export type ToolData = {
toolCall: {
id: string;
name: string;
input: {
[key: string]: JSONValue;
};
};
toolOutput: {
output: JSONValue;
isError: boolean;
};
};
export type AnnotationData = ImageData | SourceData | EventData | ToolData;
export type MessageAnnotation = {
type: MessageAnnotationType;
@@ -0,0 +1,213 @@
export interface WeatherData {
latitude: number;
longitude: number;
generationtime_ms: number;
utc_offset_seconds: number;
timezone: string;
timezone_abbreviation: string;
elevation: number;
current_units: {
time: string;
interval: string;
temperature_2m: string;
weather_code: string;
};
current: {
time: string;
interval: number;
temperature_2m: number;
weather_code: number;
};
hourly_units: {
time: string;
temperature_2m: string;
weather_code: string;
};
hourly: {
time: string[];
temperature_2m: number[];
weather_code: number[];
};
daily_units: {
time: string;
weather_code: string;
};
daily: {
time: string[];
weather_code: number[];
};
}
// Follow WMO Weather interpretation codes (WW)
const weatherCodeDisplayMap: Record<
string,
{
icon: JSX.Element;
status: string;
}
> = {
"0": {
icon: <span></span>,
status: "Clear sky",
},
"1": {
icon: <span>🌤</span>,
status: "Mainly clear",
},
"2": {
icon: <span></span>,
status: "Partly cloudy",
},
"3": {
icon: <span></span>,
status: "Overcast",
},
"45": {
icon: <span>🌫</span>,
status: "Fog",
},
"48": {
icon: <span>🌫</span>,
status: "Depositing rime fog",
},
"51": {
icon: <span>🌧</span>,
status: "Drizzle",
},
"53": {
icon: <span>🌧</span>,
status: "Drizzle",
},
"55": {
icon: <span>🌧</span>,
status: "Drizzle",
},
"56": {
icon: <span>🌧</span>,
status: "Freezing Drizzle",
},
"57": {
icon: <span>🌧</span>,
status: "Freezing Drizzle",
},
"61": {
icon: <span>🌧</span>,
status: "Rain",
},
"63": {
icon: <span>🌧</span>,
status: "Rain",
},
"65": {
icon: <span>🌧</span>,
status: "Rain",
},
"66": {
icon: <span>🌧</span>,
status: "Freezing Rain",
},
"67": {
icon: <span>🌧</span>,
status: "Freezing Rain",
},
"71": {
icon: <span></span>,
status: "Snow fall",
},
"73": {
icon: <span></span>,
status: "Snow fall",
},
"75": {
icon: <span></span>,
status: "Snow fall",
},
"77": {
icon: <span></span>,
status: "Snow grains",
},
"80": {
icon: <span>🌧</span>,
status: "Rain showers",
},
"81": {
icon: <span>🌧</span>,
status: "Rain showers",
},
"82": {
icon: <span>🌧</span>,
status: "Rain showers",
},
"85": {
icon: <span></span>,
status: "Snow showers",
},
"86": {
icon: <span></span>,
status: "Snow showers",
},
"95": {
icon: <span></span>,
status: "Thunderstorm",
},
"96": {
icon: <span></span>,
status: "Thunderstorm",
},
"99": {
icon: <span></span>,
status: "Thunderstorm",
},
};
const displayDay = (time: string) => {
return new Date(time).toLocaleDateString("en-US", {
weekday: "long",
});
};
export function WeatherCard({ data }: { data: WeatherData }) {
const currentDayString = new Date(data.current.time).toLocaleDateString(
"en-US",
{
weekday: "long",
month: "long",
day: "numeric",
},
);
return (
<div className="bg-[#61B9F2] rounded-2xl shadow-xl p-5 space-y-4 text-white w-fit">
<div className="flex justify-between">
<div className="space-y-2">
<div className="text-xl">{currentDayString}</div>
<div className="text-5xl font-semibold flex gap-4">
<span>
{data.current.temperature_2m} {data.current_units.temperature_2m}
</span>
{weatherCodeDisplayMap[data.current.weather_code].icon}
</div>
</div>
<span className="text-xl">
{weatherCodeDisplayMap[data.current.weather_code].status}
</span>
</div>
<div className="gap-2 grid grid-cols-6">
{data.daily.time.map((time, index) => {
if (index === 0) return null; // skip the current day
return (
<div key={time} className="flex flex-col items-center gap-4">
<span>{displayDay(time)}</span>
<div className="text-4xl">
{weatherCodeDisplayMap[data.daily.weather_code[index]].icon}
</div>
<span className="text-sm">
{weatherCodeDisplayMap[data.daily.weather_code[index]].status}
</span>
</div>
);
})}
</div>
</div>
);
}
@@ -14,12 +14,14 @@
"@radix-ui/react-hover-card": "^1.0.7",
"@radix-ui/react-slot": "^1.0.2",
"ai": "^3.0.21",
"ajv": "^8.12.0",
"class-variance-authority": "^0.7.0",
"clsx": "^1.2.1",
"clsx": "^2.1.1",
"dotenv": "^16.3.1",
"llamaindex": "0.2.10",
"llamaindex": "0.3.9",
"lucide-react": "^0.294.0",
"next": "^14.0.3",
"pdf2json": "3.0.5",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.7",