mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-04 00:16:55 -04:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c3215ccc7b | |||
| 18ca18123f | |||
| 5ecb0c9fb7 | |||
| 7e45f604e6 | |||
| bbacf0f199 | |||
| c0c6df80c7 | |||
| 3b39a12ad6 | |||
| c981eb1423 | |||
| c094b0c6bf | |||
| e2567ffc03 | |||
| 5d8d752b16 | |||
| a0b04be23c | |||
| 94a2809ecd | |||
| e29ef92564 | |||
| 6bdd4ac69d | |||
| 1ad25451a6 | |||
| cfb5257a1e | |||
| 046ff06157 | |||
| 8b81b17984 | |||
| f1c3e8df69 | |||
| 089916a148 | |||
| 3bb94da804 | |||
| 418bf9ba8a | |||
| e5d20b66f6 | |||
| ae7b30106d | |||
| 5fb64b74ca | |||
| e4665b6c0d | |||
| 5463d3bf4b | |||
| 7225e916fd | |||
| 897feb9914 | |||
| 66b5f38eda |
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"$schema": "https://unpkg.com/@changesets/config@3.0.0/schema.json",
|
||||
"changelog": "@changesets/cli/changelog",
|
||||
"commit": true,
|
||||
"commit": false,
|
||||
"fixed": [],
|
||||
"linked": [],
|
||||
"access": "public",
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use `gpt-4-turbo` model as default. Upgrade Python llama-index to 0.10.28
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Remove asking for AI models and use defaults instead (OpenAIs GPT-4 Vision Preview and Embeddings v3). Use `--ask-models` CLI parameter to select models.
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Add observability for Python
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use poetry run generate to generate embeddings for FastAPI
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use Settings object for LlamaIndex configuration
|
||||
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Add Qdrant support
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use ingestion pipeline for Python
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Display events (e.g. retrieving nodes) per chat message
|
||||
@@ -35,12 +35,13 @@ jobs:
|
||||
with:
|
||||
version: ${{ env.POETRY_VERSION }}
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
|
||||
- name: Setup Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
cache: "pnpm"
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
@@ -14,12 +14,13 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: ".nvmrc"
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
cache: "pnpm"
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
name: Publish to GitHub Releases
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: ".nvmrc"
|
||||
cache: "pnpm"
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Build tarball
|
||||
run: |
|
||||
pnpm pack
|
||||
|
||||
- name: Create release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
artifacts: "create-llama-*.tgz"
|
||||
name: Release ${{ github.ref }}
|
||||
bodyFile: "CHANGELOG.md"
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -15,17 +15,41 @@ jobs:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version-file: ".nvmrc"
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
cache: "pnpm"
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install
|
||||
|
||||
- name: Create Release Pull Request
|
||||
- name: Add auth token to .npmrc file
|
||||
run: |
|
||||
cat << EOF >> ".npmrc"
|
||||
//registry.npmjs.org/:_authToken=$NPM_TOKEN
|
||||
EOF
|
||||
env:
|
||||
NPM_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
- name: Get changeset status
|
||||
id: get-changeset-status
|
||||
run: |
|
||||
pnpm changeset status --output .changeset/status.json
|
||||
new_version=$(jq -r '.releases[0].newVersion' < .changeset/status.json)
|
||||
rm -v .changeset/status.json
|
||||
echo "new-version=${new_version}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Create Release Pull Request or Publish to npm
|
||||
id: changesets
|
||||
uses: changesets/action@v1
|
||||
with:
|
||||
commit: Release ${{ steps.get-changeset-status.outputs.new-version }}
|
||||
title: Release ${{ steps.get-changeset-status.outputs.new-version }}
|
||||
# build package and call changeset publish
|
||||
publish: pnpm release
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
NPM_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
# create-llama
|
||||
|
||||
## 0.1.0
|
||||
|
||||
### Minor Changes
|
||||
|
||||
- f1c3e8d: Add Llama3 and Phi3 support using Ollama
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a0dec80: Use `gpt-4-turbo` model as default. Upgrade Python llama-index to 0.10.28
|
||||
- 753229d: Remove asking for AI models and use defaults instead (OpenAIs GPT-4 Vision Preview and Embeddings v3). Use `--ask-models` CLI parameter to select models.
|
||||
- 1d78202: Add observability for Python
|
||||
- 6acccd2: Use poetry run generate to generate embeddings for FastAPI
|
||||
- 9efcffe: Use Settings object for LlamaIndex configuration
|
||||
- 418bf9b: refactor: use tsx instead of ts-node
|
||||
- 1be69a5: Add Qdrant support
|
||||
|
||||
## 0.0.32
|
||||
|
||||
### Patch Changes
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ Install NodeJS. Preferably v18 using nvm or n.
|
||||
Inside the `create-llama` directory:
|
||||
|
||||
```
|
||||
npm i -g pnpm ts-node
|
||||
npm i -g pnpm
|
||||
pnpm install
|
||||
```
|
||||
|
||||
|
||||
+2
-6
@@ -30,10 +30,8 @@ export async function createApp({
|
||||
appPath,
|
||||
packageManager,
|
||||
frontend,
|
||||
openAiKey,
|
||||
modelConfig,
|
||||
llamaCloudKey,
|
||||
model,
|
||||
embeddingModel,
|
||||
communityProjectConfig,
|
||||
llamapack,
|
||||
vectorDb,
|
||||
@@ -77,10 +75,8 @@ export async function createApp({
|
||||
ui,
|
||||
packageManager,
|
||||
isOnline,
|
||||
openAiKey,
|
||||
modelConfig,
|
||||
llamaCloudKey,
|
||||
model,
|
||||
embeddingModel,
|
||||
communityProjectConfig,
|
||||
llamapack,
|
||||
vectorDb,
|
||||
|
||||
@@ -52,6 +52,7 @@ export const copy = async (
|
||||
export const assetRelocator = (name: string) => {
|
||||
switch (name) {
|
||||
case "gitignore":
|
||||
case "npmrc":
|
||||
case "eslintrc.json": {
|
||||
return `.${name}`;
|
||||
}
|
||||
|
||||
+98
-95
@@ -1,6 +1,7 @@
|
||||
import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import {
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateFramework,
|
||||
TemplateVectorDB,
|
||||
@@ -28,7 +29,10 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
|
||||
);
|
||||
};
|
||||
|
||||
const getVectorDBEnvs = (vectorDb: TemplateVectorDB) => {
|
||||
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
if (!vectorDb) {
|
||||
return [];
|
||||
}
|
||||
switch (vectorDb) {
|
||||
case "mongo":
|
||||
return [
|
||||
@@ -130,84 +134,70 @@ const getVectorDBEnvs = (vectorDb: TemplateVectorDB) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
openAiKey?: string;
|
||||
llamaCloudKey?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
model?: string;
|
||||
embeddingModel?: string;
|
||||
framework?: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
port?: number;
|
||||
},
|
||||
) => {
|
||||
// Init env values
|
||||
const envFileName = ".env";
|
||||
const defaultEnvs = [
|
||||
const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
return [
|
||||
{
|
||||
name: "MODEL_PROVIDER",
|
||||
description: "The provider for the AI models to use.",
|
||||
value: modelConfig.provider,
|
||||
},
|
||||
{
|
||||
render: true,
|
||||
name: "MODEL",
|
||||
description: "The name of LLM model to use.",
|
||||
value: opts.model,
|
||||
},
|
||||
{
|
||||
render: true,
|
||||
name: "OPENAI_API_KEY",
|
||||
description: "The OpenAI API key to use.",
|
||||
value: opts.openAiKey,
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_API_KEY",
|
||||
description: `The Llama Cloud API key.`,
|
||||
value: opts.llamaCloudKey,
|
||||
value: modelConfig.model,
|
||||
},
|
||||
{
|
||||
name: "EMBEDDING_MODEL",
|
||||
description: "Name of the embedding model to use.",
|
||||
value: opts.embeddingModel,
|
||||
value: modelConfig.embeddingModel,
|
||||
},
|
||||
{
|
||||
name: "EMBEDDING_DIM",
|
||||
description: "Dimension of the embedding model to use.",
|
||||
value: 1536,
|
||||
value: modelConfig.dimensions.toString(),
|
||||
},
|
||||
// Add vector database environment variables
|
||||
...(opts.vectorDb ? getVectorDBEnvs(opts.vectorDb) : []),
|
||||
...(modelConfig.provider === "openai"
|
||||
? [
|
||||
{
|
||||
name: "OPENAI_API_KEY",
|
||||
description: "The OpenAI API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
{
|
||||
name: "LLM_TEMPERATURE",
|
||||
description: "Temperature for sampling from the model.",
|
||||
},
|
||||
{
|
||||
name: "LLM_MAX_TOKENS",
|
||||
description: "Maximum number of tokens to generate.",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
let envVars: EnvVar[] = [];
|
||||
if (opts.framework === "fastapi") {
|
||||
envVars = [
|
||||
...defaultEnvs,
|
||||
...[
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: opts.port?.toString() || "8000",
|
||||
},
|
||||
{
|
||||
name: "LLM_TEMPERATURE",
|
||||
description: "Temperature for sampling from the model.",
|
||||
},
|
||||
{
|
||||
name: "LLM_MAX_TOKENS",
|
||||
description: "Maximum number of tokens to generate.",
|
||||
},
|
||||
{
|
||||
name: "TOP_K",
|
||||
description:
|
||||
"The number of similar embeddings to return when retrieving documents.",
|
||||
value: "3",
|
||||
},
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: `Custom system prompt.
|
||||
};
|
||||
|
||||
const getFrameworkEnvs = (
|
||||
framework?: TemplateFramework,
|
||||
port?: number,
|
||||
): EnvVar[] => {
|
||||
if (framework !== "fastapi") {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: port?.toString() || "8000",
|
||||
},
|
||||
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: `Custom system prompt.
|
||||
Example:
|
||||
SYSTEM_PROMPT="
|
||||
We have provided context information below.
|
||||
@@ -216,24 +206,48 @@ We have provided context information below.
|
||||
---------------------
|
||||
Given this information, please answer the question: {query_str}
|
||||
"`,
|
||||
},
|
||||
],
|
||||
];
|
||||
} else {
|
||||
envVars = [
|
||||
...defaultEnvs,
|
||||
...[
|
||||
opts.framework === "nextjs"
|
||||
? {
|
||||
name: "NEXT_PUBLIC_MODEL",
|
||||
description:
|
||||
"The LLM model to use (hardcode to front-end artifact).",
|
||||
value: opts.model,
|
||||
}
|
||||
: {},
|
||||
],
|
||||
];
|
||||
}
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const getEngineEnvs = (): EnvVar[] => {
|
||||
return [
|
||||
{
|
||||
name: "TOP_K",
|
||||
description:
|
||||
"The number of similar embeddings to return when retrieving documents.",
|
||||
value: "3",
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
llamaCloudKey?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
modelConfig: ModelConfig;
|
||||
framework?: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
port?: number;
|
||||
},
|
||||
) => {
|
||||
// Init env values
|
||||
const envFileName = ".env";
|
||||
const envVars: EnvVar[] = [
|
||||
{
|
||||
name: "LLAMA_CLOUD_API_KEY",
|
||||
description: `The Llama Cloud API key.`,
|
||||
value: opts.llamaCloudKey,
|
||||
},
|
||||
// Add model environment variables
|
||||
...getModelEnvs(opts.modelConfig),
|
||||
// Add engine environment variables
|
||||
...getEngineEnvs(),
|
||||
// Add vector database environment variables
|
||||
...getVectorDBEnvs(opts.vectorDb),
|
||||
...getFrameworkEnvs(opts.framework, opts.port),
|
||||
];
|
||||
// Render and write env file
|
||||
const content = renderEnvVar(envVars);
|
||||
await fs.writeFile(path.join(root, envFileName), content);
|
||||
@@ -244,20 +258,9 @@ export const createFrontendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
customApiPath?: string;
|
||||
model?: string;
|
||||
},
|
||||
) => {
|
||||
const defaultFrontendEnvs = [
|
||||
{
|
||||
name: "MODEL",
|
||||
description: "The OpenAI model to use.",
|
||||
value: opts.model,
|
||||
},
|
||||
{
|
||||
name: "NEXT_PUBLIC_MODEL",
|
||||
description: "The OpenAI model to use (hardcode to front-end artifact).",
|
||||
value: opts.model,
|
||||
},
|
||||
{
|
||||
name: "NEXT_PUBLIC_CHAT_API",
|
||||
description: "The backend API for chat endpoint.",
|
||||
|
||||
+8
-9
@@ -9,12 +9,14 @@ import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
|
||||
import { PackageManager } from "./get-pkg-manager";
|
||||
import { installLlamapackProject } from "./llama-pack";
|
||||
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
|
||||
import { isModelConfigured } from "./providers";
|
||||
import { installPythonTemplate } from "./python";
|
||||
import { downloadAndExtractRepo } from "./repo";
|
||||
import { ConfigFileType, writeToolsConfig } from "./tools";
|
||||
import {
|
||||
FileSourceConfig,
|
||||
InstallTemplateArgs,
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateFramework,
|
||||
TemplateVectorDB,
|
||||
@@ -24,8 +26,8 @@ import { installTSTemplate } from "./typescript";
|
||||
// eslint-disable-next-line max-params
|
||||
async function generateContextData(
|
||||
framework: TemplateFramework,
|
||||
modelConfig: ModelConfig,
|
||||
packageManager?: PackageManager,
|
||||
openAiKey?: string,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
llamaCloudKey?: string,
|
||||
useLlamaParse?: boolean,
|
||||
@@ -36,12 +38,12 @@ async function generateContextData(
|
||||
? "poetry run generate"
|
||||
: `${packageManager} run generate`,
|
||||
)}`;
|
||||
const openAiKeyConfigured = openAiKey || process.env["OPENAI_API_KEY"];
|
||||
const modelConfigured = isModelConfigured(modelConfig);
|
||||
const llamaCloudKeyConfigured = useLlamaParse
|
||||
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
const hasVectorDb = vectorDb && vectorDb !== "none";
|
||||
if (openAiKeyConfigured && llamaCloudKeyConfigured && !hasVectorDb) {
|
||||
if (modelConfigured && llamaCloudKeyConfigured && !hasVectorDb) {
|
||||
// If all the required environment variables are set, run the generate script
|
||||
if (framework === "fastapi") {
|
||||
if (isHavingPoetryLockFile()) {
|
||||
@@ -63,7 +65,7 @@ async function generateContextData(
|
||||
|
||||
// generate the message of what to do to run the generate script manually
|
||||
const settings = [];
|
||||
if (!openAiKeyConfigured) settings.push("your OpenAI key");
|
||||
if (!modelConfigured) settings.push("your model provider API key");
|
||||
if (!llamaCloudKeyConfigured) settings.push("your Llama Cloud key");
|
||||
if (hasVectorDb) settings.push("your Vector DB environment variables");
|
||||
const settingsMessage =
|
||||
@@ -141,11 +143,9 @@ export const installTemplate = async (
|
||||
|
||||
// Copy the environment file to the target directory.
|
||||
await createBackendEnvFile(props.root, {
|
||||
openAiKey: props.openAiKey,
|
||||
modelConfig: props.modelConfig,
|
||||
llamaCloudKey: props.llamaCloudKey,
|
||||
vectorDb: props.vectorDb,
|
||||
model: props.model,
|
||||
embeddingModel: props.embeddingModel,
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
@@ -163,8 +163,8 @@ export const installTemplate = async (
|
||||
) {
|
||||
await generateContextData(
|
||||
props.framework,
|
||||
props.modelConfig,
|
||||
props.packageManager,
|
||||
props.openAiKey,
|
||||
props.vectorDb,
|
||||
props.llamaCloudKey,
|
||||
props.useLlamaParse,
|
||||
@@ -174,7 +174,6 @@ export const installTemplate = async (
|
||||
} else {
|
||||
// this is a frontend for a full-stack app, create .env file with model information
|
||||
await createFrontendEnvFile(props.root, {
|
||||
model: props.model,
|
||||
customApiPath: props.customApiPath,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { questionHandlers } from "../../questions";
|
||||
import { ModelConfig, ModelProvider } from "../types";
|
||||
import { askOllamaQuestions } from "./ollama";
|
||||
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
|
||||
|
||||
const DEFAULT_MODEL_PROVIDER = "openai";
|
||||
|
||||
export type ModelConfigQuestionsParams = {
|
||||
openAiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export type ModelConfigParams = Omit<ModelConfig, "provider">;
|
||||
|
||||
export async function askModelConfig({
|
||||
askModels,
|
||||
openAiKey,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfig> {
|
||||
let modelProvider: ModelProvider = DEFAULT_MODEL_PROVIDER;
|
||||
if (askModels && !ciInfo.isCI) {
|
||||
const { provider } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "provider",
|
||||
message: "Which model provider would you like to use",
|
||||
choices: [
|
||||
{
|
||||
title: "OpenAI",
|
||||
value: "openai",
|
||||
},
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
modelProvider = provider;
|
||||
}
|
||||
|
||||
let modelConfig: ModelConfigParams;
|
||||
switch (modelProvider) {
|
||||
case "ollama":
|
||||
modelConfig = await askOllamaQuestions({ askModels });
|
||||
break;
|
||||
default:
|
||||
modelConfig = await askOpenAIQuestions({
|
||||
openAiKey,
|
||||
askModels,
|
||||
});
|
||||
}
|
||||
return {
|
||||
...modelConfig,
|
||||
provider: modelProvider,
|
||||
};
|
||||
}
|
||||
|
||||
export function isModelConfigured(modelConfig: ModelConfig): boolean {
|
||||
switch (modelConfig.provider) {
|
||||
case "openai":
|
||||
return isOpenAIConfigured(modelConfig);
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
import ciInfo from "ci-info";
|
||||
import ollama, { type ModelResponse } from "ollama";
|
||||
import { red } from "picocolors";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const MODELS = ["llama3:8b", "wizardlm2:7b", "gemma:7b", "phi3"];
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
// TODO: get embedding vector dimensions from the ollama sdk (currently not supported)
|
||||
const EMBEDDING_MODELS: Record<string, ModelData> = {
|
||||
"nomic-embed-text": { dimensions: 768 },
|
||||
"mxbai-embed-large": { dimensions: 1024 },
|
||||
"all-minilm": { dimensions: 384 },
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL: string = Object.keys(EMBEDDING_MODELS)[0];
|
||||
|
||||
type OllamaQuestionsParams = {
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askOllamaQuestions({
|
||||
askModels,
|
||||
}: OllamaQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
|
||||
};
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
await ensureModel(model);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
await ensureModel(embeddingModel);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
async function ensureModel(modelName: string) {
|
||||
try {
|
||||
if (modelName.split(":").length === 1) {
|
||||
// model doesn't have a version suffix, use latest
|
||||
modelName = modelName + ":latest";
|
||||
}
|
||||
const { models } = await ollama.list();
|
||||
const found =
|
||||
models.find((model: ModelResponse) => model.name === modelName) !==
|
||||
undefined;
|
||||
if (!found) {
|
||||
console.log(
|
||||
red(
|
||||
`Model ${modelName} was not pulled yet. Call 'ollama pull ${modelName}' and try again.`,
|
||||
),
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(
|
||||
red("Listing Ollama models failed. Is 'ollama' running? " + error),
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
import ciInfo from "ci-info";
|
||||
import got from "got";
|
||||
import ora from "ora";
|
||||
import { red } from "picocolors";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams, ModelConfigQuestionsParams } from ".";
|
||||
import { questionHandlers } from "../../questions";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askOpenAIQuestions({
|
||||
openAiKey,
|
||||
askModels,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey: openAiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message: askModels
|
||||
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
|
||||
: "Please provide your OpenAI API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
console.log(value);
|
||||
if (askModels && !value) {
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
return true;
|
||||
}
|
||||
return "OPENAI_API_KEY env variable is not set - key is required";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: await getAvailableModelChoices(false, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: await getAvailableModelChoices(true, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = getDimensions(embeddingModel);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
|
||||
if (params.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["OPENAI_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async function getAvailableModelChoices(
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
throw new Error("need OpenAI key to retrieve model choices");
|
||||
}
|
||||
const isLLMModel = (modelId: string) => {
|
||||
return modelId.startsWith("gpt");
|
||||
};
|
||||
|
||||
const isEmbeddingModel = (modelId: string) => {
|
||||
return modelId.includes("embedding");
|
||||
};
|
||||
|
||||
const spinner = ora("Fetching available models").start();
|
||||
try {
|
||||
const response = await got(`${OPENAI_API_URL}/models`, {
|
||||
headers: {
|
||||
Authorization: "Bearer " + apiKey,
|
||||
},
|
||||
timeout: 5000,
|
||||
responseType: "json",
|
||||
});
|
||||
const data: any = await response.body;
|
||||
spinner.stop();
|
||||
return data.data
|
||||
.filter((model: any) =>
|
||||
selectEmbedding ? isEmbeddingModel(model.id) : isLLMModel(model.id),
|
||||
)
|
||||
.map((el: any) => {
|
||||
return {
|
||||
title: el.id,
|
||||
value: el.id,
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
spinner.stop();
|
||||
if ((error as any).response?.statusCode === 401) {
|
||||
console.log(
|
||||
red(
|
||||
"Invalid OpenAI API key provided! Please provide a valid key and try again!",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(red("Request failed: " + error));
|
||||
}
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
function getDimensions(modelName: string) {
|
||||
// at 2024-04-24 all OpenAI embedding models support 1536 dimensions except
|
||||
// "text-embedding-3-large", see https://openai.com/blog/new-embedding-models-and-api-updates
|
||||
return modelName === "text-embedding-3-large" ? 1024 : 1536;
|
||||
}
|
||||
+24
-3
@@ -10,6 +10,7 @@ import { isPoetryAvailable, tryPoetryInstall } from "./poetry";
|
||||
import { Tool } from "./tools";
|
||||
import {
|
||||
InstallTemplateArgs,
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
@@ -21,6 +22,7 @@ interface Dependency {
|
||||
}
|
||||
|
||||
const getAdditionalDependencies = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
dataSource?: TemplateDataSource,
|
||||
tools?: Tool[],
|
||||
@@ -108,6 +110,25 @@ const getAdditionalDependencies = (
|
||||
});
|
||||
});
|
||||
|
||||
switch (modelConfig.provider) {
|
||||
case "ollama":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-ollama",
|
||||
version: "0.1.2",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-ollama",
|
||||
version: "0.1.2",
|
||||
});
|
||||
break;
|
||||
case "openai":
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "0.2.2",
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
return dependencies;
|
||||
};
|
||||
|
||||
@@ -205,8 +226,8 @@ export const installPythonTemplate = async ({
|
||||
dataSources,
|
||||
tools,
|
||||
postInstallAction,
|
||||
useLlamaParse,
|
||||
observability,
|
||||
modelConfig,
|
||||
}: Pick<
|
||||
InstallTemplateArgs,
|
||||
| "root"
|
||||
@@ -215,9 +236,9 @@ export const installPythonTemplate = async ({
|
||||
| "vectorDb"
|
||||
| "dataSources"
|
||||
| "tools"
|
||||
| "useLlamaParse"
|
||||
| "postInstallAction"
|
||||
| "observability"
|
||||
| "modelConfig"
|
||||
>) => {
|
||||
console.log("\nInitializing Python project with template:", template, "\n");
|
||||
const templatePath = path.join(templatesDir, "types", template, framework);
|
||||
@@ -258,7 +279,7 @@ export const installPythonTemplate = async ({
|
||||
});
|
||||
|
||||
const addOnDependencies = dataSources
|
||||
.map((ds) => getAdditionalDependencies(vectorDb, ds, tools))
|
||||
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
|
||||
.flat();
|
||||
|
||||
if (observability === "opentelemetry") {
|
||||
|
||||
+9
-3
@@ -1,6 +1,14 @@
|
||||
import { PackageManager } from "../helpers/get-pkg-manager";
|
||||
import { Tool } from "./tools";
|
||||
|
||||
export type ModelProvider = "openai" | "ollama";
|
||||
export type ModelConfig = {
|
||||
provider: ModelProvider;
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
embeddingModel: string;
|
||||
dimensions: number;
|
||||
};
|
||||
export type TemplateType = "streaming" | "community" | "llamapack";
|
||||
export type TemplateFramework = "nextjs" | "express" | "fastapi";
|
||||
export type TemplateUI = "html" | "shadcn";
|
||||
@@ -59,11 +67,9 @@ export interface InstallTemplateArgs {
|
||||
ui: TemplateUI;
|
||||
dataSources: TemplateDataSource[];
|
||||
customApiPath?: string;
|
||||
openAiKey?: string;
|
||||
modelConfig: ModelConfig;
|
||||
llamaCloudKey?: string;
|
||||
useLlamaParse?: boolean;
|
||||
model: string;
|
||||
embeddingModel: string;
|
||||
communityProjectConfig?: CommunityProjectConfig;
|
||||
llamapack?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
|
||||
@@ -205,7 +205,7 @@ async function updatePackageJson({
|
||||
// add generate script if using context engine
|
||||
packageJson.scripts = {
|
||||
...packageJson.scripts,
|
||||
generate: `ts-node ${path.join(
|
||||
generate: `tsx ${path.join(
|
||||
relativeEngineDestPath,
|
||||
"engine",
|
||||
"generate.ts",
|
||||
|
||||
@@ -107,20 +107,6 @@ const program = new Commander.Command(packageJson.name)
|
||||
`
|
||||
|
||||
Whether to generate a frontend for your backend.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
"--model <model>",
|
||||
`
|
||||
|
||||
Select OpenAI model to use. E.g. gpt-3.5-turbo.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
"--embedding-model <embeddingModel>",
|
||||
`
|
||||
|
||||
Select OpenAI embedding model to use. E.g. text-embedding-ada-002.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
@@ -201,9 +187,7 @@ if (process.argv.includes("--tools")) {
|
||||
if (process.argv.includes("--no-llama-parse")) {
|
||||
program.useLlamaParse = false;
|
||||
}
|
||||
if (process.argv.includes("--ask-models")) {
|
||||
program.askModels = true;
|
||||
}
|
||||
program.askModels = process.argv.includes("--ask-models");
|
||||
if (process.argv.includes("--no-files")) {
|
||||
program.dataSources = [];
|
||||
} else {
|
||||
@@ -290,7 +274,11 @@ async function run(): Promise<void> {
|
||||
}
|
||||
|
||||
const preferences = (conf.get("preferences") || {}) as QuestionArgs;
|
||||
await askQuestions(program as unknown as QuestionArgs, preferences);
|
||||
await askQuestions(
|
||||
program as unknown as QuestionArgs,
|
||||
preferences,
|
||||
program.openAiKey,
|
||||
);
|
||||
|
||||
await createApp({
|
||||
template: program.template,
|
||||
@@ -299,10 +287,8 @@ async function run(): Promise<void> {
|
||||
appPath: resolvedProjectPath,
|
||||
packageManager,
|
||||
frontend: program.frontend,
|
||||
openAiKey: program.openAiKey,
|
||||
modelConfig: program.modelConfig,
|
||||
llamaCloudKey: program.llamaCloudKey,
|
||||
model: program.model,
|
||||
embeddingModel: program.embeddingModel,
|
||||
communityProjectConfig: program.communityProjectConfig,
|
||||
llamapack: program.llamapack,
|
||||
vectorDb: program.vectorDb,
|
||||
|
||||
+26
-23
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "create-llama",
|
||||
"version": "0.0.32",
|
||||
"version": "0.1.0",
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"keywords": [
|
||||
"rag",
|
||||
"llamaindex",
|
||||
"next.js"
|
||||
],
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/run-llama/LlamaIndexTS",
|
||||
@@ -20,32 +20,30 @@
|
||||
"dist"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "rimraf --glob ./dist ./templates/**/__pycache__ ./templates/**/node_modules ./templates/**/poetry.lock",
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write .",
|
||||
"dev": "ncc build ./index.ts -w -o dist/",
|
||||
"build": "bash ./scripts/build.sh",
|
||||
"build:ncc": "pnpm run clean && ncc build ./index.ts -o ./dist/ --minify --no-cache --no-source-map-register",
|
||||
"lint": "eslint . --ignore-pattern dist --ignore-pattern e2e/cache",
|
||||
"clean": "rimraf --glob ./dist ./templates/**/__pycache__ ./templates/**/node_modules ./templates/**/poetry.lock",
|
||||
"dev": "ncc build ./index.ts -w -o dist/",
|
||||
"e2e": "playwright test",
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write .",
|
||||
"lint": "eslint . --ignore-pattern dist --ignore-pattern e2e/cache",
|
||||
"new-snapshot": "pnpm run build && changeset version --snapshot",
|
||||
"new-version": "pnpm run build && changeset version",
|
||||
"pack-install": "bash ./scripts/pack.sh",
|
||||
"prepare": "husky",
|
||||
"release": "pnpm run build && changeset publish",
|
||||
"new-version": "pnpm run build && changeset version",
|
||||
"release-snapshot": "pnpm run build && changeset publish --tag snapshot",
|
||||
"new-snapshot": "pnpm run build && changeset version --snapshot",
|
||||
"pack-install": "bash ./scripts/pack.sh"
|
||||
"release-snapshot": "pnpm run build && changeset publish --tag snapshot"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.41.1",
|
||||
"dependencies": {
|
||||
"@types/async-retry": "1.4.2",
|
||||
"@types/ci-info": "2.0.0",
|
||||
"@types/cross-spawn": "6.0.0",
|
||||
"@types/fs-extra": "11.0.4",
|
||||
"@types/node": "^20.11.7",
|
||||
"@types/prompts": "2.0.1",
|
||||
"@types/tar": "6.1.5",
|
||||
"@types/validate-npm-package-name": "3.0.0",
|
||||
"@types/fs-extra": "11.0.4",
|
||||
"@vercel/ncc": "0.38.1",
|
||||
"async-retry": "1.3.1",
|
||||
"async-sema": "3.0.1",
|
||||
"ci-info": "github:watson/ci-info#f43f6a1cefff47fb361c88cf4b943fdbcaafe540",
|
||||
@@ -53,29 +51,34 @@
|
||||
"conf": "10.2.0",
|
||||
"cross-spawn": "7.0.3",
|
||||
"fast-glob": "3.3.1",
|
||||
"fs-extra": "11.2.0",
|
||||
"got": "10.7.0",
|
||||
"ollama": "^0.5.0",
|
||||
"ora": "^8.0.1",
|
||||
"picocolors": "1.0.0",
|
||||
"prompts": "2.1.0",
|
||||
"rimraf": "^5.0.5",
|
||||
"smol-toml": "^1.1.4",
|
||||
"tar": "6.1.15",
|
||||
"terminal-link": "^3.0.0",
|
||||
"update-check": "1.5.4",
|
||||
"validate-npm-package-name": "3.0.0",
|
||||
"wait-port": "^1.1.0",
|
||||
"yaml": "2.4.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@changesets/cli": "^2.27.1",
|
||||
"@playwright/test": "^1.41.1",
|
||||
"@vercel/ncc": "0.38.1",
|
||||
"eslint": "^8.56.0",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"husky": "^9.0.10",
|
||||
"prettier": "^3.2.5",
|
||||
"prettier-plugin-organize-imports": "^3.2.4",
|
||||
"rimraf": "^5.0.5",
|
||||
"typescript": "^5.3.3",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"ora": "^8.0.1",
|
||||
"fs-extra": "11.2.0",
|
||||
"yaml": "2.4.1"
|
||||
"wait-port": "^1.1.0"
|
||||
},
|
||||
"packageManager": "pnpm@9.0.5",
|
||||
"engines": {
|
||||
"node": ">=16.14.0"
|
||||
},
|
||||
"packageManager": "pnpm@8.15.1"
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+2412
-1915
File diff suppressed because it is too large
Load Diff
+32
-137
@@ -1,8 +1,6 @@
|
||||
import { execSync } from "child_process";
|
||||
import ciInfo from "ci-info";
|
||||
import fs from "fs";
|
||||
import got from "got";
|
||||
import ora from "ora";
|
||||
import path from "path";
|
||||
import { blue, green, red } from "picocolors";
|
||||
import prompts from "prompts";
|
||||
@@ -16,11 +14,10 @@ import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
|
||||
import { EXAMPLE_FILE } from "./helpers/datasources";
|
||||
import { templatesDir } from "./helpers/dir";
|
||||
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
|
||||
import { askModelConfig, isModelConfigured } from "./helpers/providers";
|
||||
import { getProjectOptions } from "./helpers/repo";
|
||||
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
export type QuestionArgs = Omit<
|
||||
InstallAppArgs,
|
||||
"appPath" | "packageManager"
|
||||
@@ -67,16 +64,13 @@ if ($dialogResult -eq [System.Windows.Forms.DialogResult]::OK)
|
||||
}
|
||||
`;
|
||||
|
||||
const defaults: QuestionArgs = {
|
||||
const defaults: Omit<QuestionArgs, "modelConfig"> = {
|
||||
template: "streaming",
|
||||
framework: "nextjs",
|
||||
ui: "shadcn",
|
||||
frontend: false,
|
||||
openAiKey: "",
|
||||
llamaCloudKey: "",
|
||||
useLlamaParse: false,
|
||||
model: "gpt-4-turbo",
|
||||
embeddingModel: "text-embedding-3-large",
|
||||
communityProjectConfig: undefined,
|
||||
llamapack: "",
|
||||
postInstallAction: "dependencies",
|
||||
@@ -84,7 +78,7 @@ const defaults: QuestionArgs = {
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const handlers = {
|
||||
export const questionHandlers = {
|
||||
onCancel: () => {
|
||||
console.error("Exiting.");
|
||||
process.exit(1);
|
||||
@@ -232,63 +226,15 @@ export const onPromptState = (state: any) => {
|
||||
}
|
||||
};
|
||||
|
||||
const getAvailableModelChoices = async (
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
) => {
|
||||
const isLLMModel = (modelId: string) => {
|
||||
return modelId.startsWith("gpt");
|
||||
};
|
||||
|
||||
const isEmbeddingModel = (modelId: string) => {
|
||||
return modelId.includes("embedding");
|
||||
};
|
||||
|
||||
if (apiKey) {
|
||||
const spinner = ora("Fetching available models").start();
|
||||
try {
|
||||
const response = await got(`${OPENAI_API_URL}/models`, {
|
||||
headers: {
|
||||
Authorization: "Bearer " + apiKey,
|
||||
},
|
||||
timeout: 5000,
|
||||
responseType: "json",
|
||||
});
|
||||
const data: any = await response.body;
|
||||
spinner.stop();
|
||||
return data.data
|
||||
.filter((model: any) =>
|
||||
selectEmbedding ? isEmbeddingModel(model.id) : isLLMModel(model.id),
|
||||
)
|
||||
.map((el: any) => {
|
||||
return {
|
||||
title: el.id,
|
||||
value: el.id,
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
spinner.stop();
|
||||
if ((error as any).response?.statusCode === 401) {
|
||||
console.log(
|
||||
red(
|
||||
"Invalid OpenAI API key provided! Please provide a valid key and try again!",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(red("Request failed: " + error));
|
||||
}
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const askQuestions = async (
|
||||
program: QuestionArgs,
|
||||
preferences: QuestionArgs,
|
||||
openAiKey?: string,
|
||||
) => {
|
||||
const getPrefOrDefault = <K extends keyof QuestionArgs>(
|
||||
const getPrefOrDefault = <K extends keyof Omit<QuestionArgs, "modelConfig">>(
|
||||
field: K,
|
||||
): QuestionArgs[K] => preferences[field] ?? defaults[field];
|
||||
): Omit<QuestionArgs, "modelConfig">[K] =>
|
||||
preferences[field] ?? defaults[field];
|
||||
|
||||
// Ask for next action after installation
|
||||
async function askPostInstallAction() {
|
||||
@@ -311,8 +257,7 @@ export const askQuestions = async (
|
||||
},
|
||||
];
|
||||
|
||||
const openAiKeyConfigured =
|
||||
program.openAiKey || process.env["OPENAI_API_KEY"];
|
||||
const modelConfigured = isModelConfigured(program.modelConfig);
|
||||
// If using LlamaParse, require LlamaCloud API key
|
||||
const llamaCloudKeyConfigured = program.useLlamaParse
|
||||
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
@@ -321,7 +266,7 @@ export const askQuestions = async (
|
||||
// Can run the app if all tools do not require configuration
|
||||
if (
|
||||
!hasVectorDb &&
|
||||
openAiKeyConfigured &&
|
||||
modelConfigured &&
|
||||
llamaCloudKeyConfigured &&
|
||||
!toolsRequireConfig(program.tools) &&
|
||||
!program.llamapack
|
||||
@@ -341,7 +286,7 @@ export const askQuestions = async (
|
||||
choices: actionChoices,
|
||||
initial: 1,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
|
||||
program.postInstallAction = action;
|
||||
@@ -374,7 +319,7 @@ export const askQuestions = async (
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.template = template;
|
||||
preferences.template = template;
|
||||
@@ -397,7 +342,7 @@ export const askQuestions = async (
|
||||
})),
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
const projectConfig = JSON.parse(communityProjectConfig);
|
||||
program.communityProjectConfig = projectConfig;
|
||||
@@ -418,7 +363,7 @@ export const askQuestions = async (
|
||||
})),
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamapack = llamapack;
|
||||
preferences.llamapack = llamapack;
|
||||
@@ -444,7 +389,7 @@ export const askQuestions = async (
|
||||
choices,
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.framework = framework;
|
||||
preferences.framework = framework;
|
||||
@@ -504,7 +449,7 @@ export const askQuestions = async (
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
|
||||
program.observability = observability;
|
||||
@@ -512,67 +457,13 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.openAiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message: program.askModels
|
||||
? "Please provide your OpenAI API key (or leave blank to reuse OPENAI_API_KEY env variable):"
|
||||
: "Please provide your OpenAI API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
if (program.askModels && !value) {
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
return true;
|
||||
}
|
||||
return "OpenAI API key is required";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
},
|
||||
handlers,
|
||||
);
|
||||
|
||||
program.openAiKey = key || process.env.OPENAI_API_KEY;
|
||||
preferences.openAiKey = key || process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
if (!program.model) {
|
||||
if (ciInfo.isCI || !program.askModels) {
|
||||
program.model = defaults.model;
|
||||
} else {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: await getAvailableModelChoices(false, program.openAiKey),
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
);
|
||||
program.model = model;
|
||||
preferences.model = model;
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.embeddingModel) {
|
||||
if (ciInfo.isCI || !program.askModels) {
|
||||
program.embeddingModel = defaults.embeddingModel;
|
||||
} else {
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: await getAvailableModelChoices(true, program.openAiKey),
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
);
|
||||
program.embeddingModel = embeddingModel;
|
||||
preferences.embeddingModel = embeddingModel;
|
||||
}
|
||||
if (!program.modelConfig) {
|
||||
const modelConfig = await askModelConfig({
|
||||
openAiKey,
|
||||
askModels: program.askModels ?? false,
|
||||
});
|
||||
program.modelConfig = modelConfig;
|
||||
preferences.modelConfig = modelConfig;
|
||||
}
|
||||
|
||||
if (!program.dataSources) {
|
||||
@@ -596,7 +487,7 @@ export const askQuestions = async (
|
||||
),
|
||||
initial: firstQuestion ? 1 : 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
|
||||
if (selectedSource === "no" || selectedSource === "none") {
|
||||
@@ -642,7 +533,7 @@ export const askQuestions = async (
|
||||
return true;
|
||||
},
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
|
||||
program.dataSources.push({
|
||||
@@ -687,7 +578,7 @@ export const askQuestions = async (
|
||||
];
|
||||
program.dataSources.push({
|
||||
type: "db",
|
||||
config: await prompts(dbPrompts, handlers),
|
||||
config: await prompts(dbPrompts, questionHandlers),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -714,7 +605,7 @@ export const askQuestions = async (
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
|
||||
@@ -727,7 +618,7 @@ export const askQuestions = async (
|
||||
message:
|
||||
"Please provide your LlamaIndex Cloud API key (leave blank to skip):",
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamaCloudKey = llamaCloudKey;
|
||||
}
|
||||
@@ -746,7 +637,7 @@ export const askQuestions = async (
|
||||
choices: getVectorDbChoices(program.framework),
|
||||
initial: 0,
|
||||
},
|
||||
handlers,
|
||||
questionHandlers,
|
||||
);
|
||||
program.vectorDb = vectorDb;
|
||||
preferences.vectorDb = vectorDb;
|
||||
@@ -781,3 +672,7 @@ export const askQuestions = async (
|
||||
|
||||
await askPostInstallAction();
|
||||
};
|
||||
|
||||
export const toChoice = (value: string) => {
|
||||
return { title: value, value };
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from app.engine.index import get_index
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def get_chat_engine():
|
||||
@@ -8,8 +9,11 @@ def get_chat_engine():
|
||||
|
||||
index = get_index()
|
||||
if index is None:
|
||||
raise Exception(
|
||||
"StorageContext is empty - call 'poetry run generate' to generate the storage first"
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(
|
||||
"StorageContext is empty - call 'poetry run generate' to generate the storage first"
|
||||
),
|
||||
)
|
||||
|
||||
return index.as_chat_engine(
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import {
|
||||
BaseTool,
|
||||
OpenAIAgent,
|
||||
QueryEngineTool,
|
||||
Settings,
|
||||
ToolFactory,
|
||||
} from "llamaindex";
|
||||
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { getDataSource } from "./index";
|
||||
@@ -33,12 +28,10 @@ export async function createChatEngine() {
|
||||
const config = JSON.parse(
|
||||
await fs.readFile(path.join("config", "tools.json"), "utf8"),
|
||||
);
|
||||
tools = tools.concat(await ToolFactory.createTools(config));
|
||||
tools = tools.concat(await ToolsFactory.createTools(config));
|
||||
} catch {}
|
||||
|
||||
return new OpenAIAgent({
|
||||
tools,
|
||||
llm: Settings.llm,
|
||||
verbose: true,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@ export async function createChatEngine() {
|
||||
);
|
||||
}
|
||||
const retriever = index.asRetriever();
|
||||
retriever.similarityTopK = 3;
|
||||
retriever.similarityTopK = process.env.TOP_K
|
||||
? parseInt(process.env.TOP_K)
|
||||
: 3;
|
||||
|
||||
return new ContextChatEngine({
|
||||
chatModel: Settings.llm,
|
||||
|
||||
@@ -27,10 +27,7 @@ def llama_parse_parser():
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
)
|
||||
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
documents = get_documents()
|
||||
store = AstraDBVectorStore(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(f"Successfully created embeddings in the AstraDB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
+2
-11
@@ -1,21 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from AstraDB...")
|
||||
def get_vector_store():
|
||||
store = AstraDBVectorStore(
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from AstraDB.")
|
||||
return index
|
||||
return store
|
||||
@@ -1,39 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = MilvusVectorStore(
|
||||
uri=os.environ["MILVUS_ADDRESS"],
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(f"Successfully created embeddings in the Milvus")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,22 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from Milvus...")
|
||||
store = MilvusVectorStore(
|
||||
uri=os.getenv("MILVUS_ADDRESS"),
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from Milvus.")
|
||||
return index
|
||||
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
store = MilvusVectorStore(
|
||||
uri=os.environ["MILVUS_ADDRESS"],
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
@@ -1,43 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings in the MongoDB collection {os.environ['MONGODB_VECTORS']}"
|
||||
)
|
||||
logger.info(
|
||||
"""IMPORTANT: You can't query your index yet because you need to create a vector search index in MongoDB's UI now.
|
||||
See https://github.com/run-llama/mongodb-demo/tree/main?tab=readme-ov-file#create-a-vector-search-index"""
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from MongoDB...")
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from MongoDB.")
|
||||
return index
|
||||
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
return store
|
||||
@@ -1 +0,0 @@
|
||||
STORAGE_DIR = "storage" # directory to cache the generated index
|
||||
@@ -1,32 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
from llama_index.core.indices import (
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from app.engine.constants import STORAGE_DIR
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
)
|
||||
# store it for later
|
||||
index.storage_context.persist(STORAGE_DIR)
|
||||
logger.info(f"Finished creating new index. Stored in {STORAGE_DIR}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.constants import STORAGE_DIR
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import load_index_from_storage
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
# check if storage already exists
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
return None
|
||||
# load the existing index
|
||||
logger.info(f"Loading index from {STORAGE_DIR}...")
|
||||
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
|
||||
index = load_index_from_storage(storage_context)
|
||||
logger.info(f"Finished loading index from {STORAGE_DIR}")
|
||||
return index
|
||||
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
from llama_index.core.vector_stores import SimpleVectorStore
|
||||
from app.constants import STORAGE_DIR
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
# Vector store hasn't been persisted before, create a new one
|
||||
vector_store = SimpleVectorStore()
|
||||
else:
|
||||
# Vector store has already been persisted before at STORAGE_DIR - load it
|
||||
vector_store = SimpleVectorStore.from_persist_dir(
|
||||
STORAGE_DIR, namespace="default"
|
||||
)
|
||||
return vector_store
|
||||
@@ -1,2 +0,0 @@
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
@@ -1,35 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.core.storage import StorageContext
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from app.engine.utils import init_pg_vector_store_from_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = init_pg_vector_store_from_env()
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings in the PG vector store, schema={store.schema_name} table={store.table_name}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,13 +0,0 @@
|
||||
import logging
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from app.engine.utils import init_pg_vector_store_from_env
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from PGVector...")
|
||||
store = init_pg_vector_store_from_env()
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from PGVector.")
|
||||
return index
|
||||
+6
-2
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
from urllib.parse import urlparse
|
||||
from app.engine.constants import PGVECTOR_SCHEMA, PGVECTOR_TABLE
|
||||
|
||||
STORAGE_DIR = "storage"
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
|
||||
|
||||
def init_pg_vector_store_from_env():
|
||||
def get_vector_store():
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
@@ -24,4 +27,5 @@ def init_pg_vector_store_from_env():
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully created embeddings and save to your Pinecone index {os.environ['PINECONE_INDEX_NAME']}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to index from Pinecone...")
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to index from Pinecone.")
|
||||
return index
|
||||
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
store = PineconeVectorStore(
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
return store
|
||||
@@ -1,37 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from dotenv import load_dotenv
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index with Qdrant")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=store)
|
||||
VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
show_progress=True, # this will show you a progress bar as the embeddings are created
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully uploaded documents to the {os.getenv('QDRANT_COLLECTION')} collection."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,20 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from llama_index.core.indices import VectorStoreIndex
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Connecting to Qdrant collection..")
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
logger.info("Finished connecting to Qdrant collection.")
|
||||
return index
|
||||
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
store = QdrantVectorStore(
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
return store
|
||||
@@ -2,7 +2,7 @@ export const PGVECTOR_COLLECTION = "data";
|
||||
export const PGVECTOR_SCHEMA = "public";
|
||||
export const PGVECTOR_TABLE = "llamaindex_embedding";
|
||||
|
||||
const REQUIRED_ENV_VARS = ["PG_CONNECTION_STRING", "OPENAI_API_KEY"];
|
||||
const REQUIRED_ENV_VARS = ["PG_CONNECTION_STRING"];
|
||||
|
||||
export function checkRequiredEnvVars() {
|
||||
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
node-linker=hoisted
|
||||
@@ -10,11 +10,11 @@
|
||||
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon -q dist/index.mjs\""
|
||||
},
|
||||
"dependencies": {
|
||||
"ai": "^2.2.25",
|
||||
"ai": "^3.0.21",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.3.1",
|
||||
"express": "^4.18.2",
|
||||
"llamaindex": "latest"
|
||||
"llamaindex": "0.2.10"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/cors": "^2.8.16",
|
||||
@@ -22,12 +22,12 @@
|
||||
"@types/node": "^20.9.5",
|
||||
"concurrently": "^8.2.2",
|
||||
"eslint": "^8.54.0",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"nodemon": "^3.0.1",
|
||||
"tsup": "^8.0.1",
|
||||
"typescript": "^5.3.2",
|
||||
"prettier": "^3.2.5",
|
||||
"prettier-plugin-organize-imports": "^3.2.4",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"ts-node": "^10.9.2"
|
||||
"tsx": "^4.7.2",
|
||||
"tsup": "^8.0.1",
|
||||
"typescript": "^5.3.2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ export const chatRequest = async (req: Request, res: Response) => {
|
||||
} catch (error) {
|
||||
console.error("[LlamaIndex]", error);
|
||||
return res.status(500).json({
|
||||
error: (error as Error).message,
|
||||
detail: (error as Error).message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { streamToResponse } from "ai";
|
||||
import { Message, StreamData, streamToResponse } from "ai";
|
||||
import { Request, Response } from "express";
|
||||
import { ChatMessage, MessageContent } from "llamaindex";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { appendEventData } from "./stream-helper";
|
||||
|
||||
const convertMessageContent = (
|
||||
textMessage: string,
|
||||
@@ -25,7 +26,7 @@ const convertMessageContent = (
|
||||
|
||||
export const chat = async (req: Request, res: Response) => {
|
||||
try {
|
||||
const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
|
||||
const { messages, data }: { messages: Message[]; data: any } = req.body;
|
||||
const userMessage = messages.pop();
|
||||
if (!messages || !userMessage || userMessage.role !== "user") {
|
||||
return res.status(400).json({
|
||||
@@ -42,22 +43,38 @@ export const chat = async (req: Request, res: Response) => {
|
||||
data?.imageUrl,
|
||||
);
|
||||
|
||||
// Init Vercel AI StreamData
|
||||
const vercelStreamData = new StreamData();
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieving context for query: '${userMessage.content}'`,
|
||||
);
|
||||
|
||||
// Setup callback for streaming data before chatting
|
||||
Settings.callbackManager.on("retrieve", (data) => {
|
||||
const { nodes } = data.detail;
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
// Calling LlamaIndex's ChatEngine to get a streamed response
|
||||
const response = await chatEngine.chat({
|
||||
message: userMessageContent,
|
||||
chatHistory: messages,
|
||||
chatHistory: messages as ChatMessage[],
|
||||
stream: true,
|
||||
});
|
||||
|
||||
// Return a stream, which can be consumed by the Vercel/AI client
|
||||
const { stream, data: streamData } = LlamaIndexStream(response, {
|
||||
const { stream } = LlamaIndexStream(response, vercelStreamData, {
|
||||
parserOptions: {
|
||||
image_url: data?.imageUrl,
|
||||
},
|
||||
});
|
||||
|
||||
// Pipe LlamaIndexStream to response
|
||||
const processedStream = stream.pipeThrough(streamData.stream);
|
||||
const processedStream = stream.pipeThrough(vercelStreamData.stream);
|
||||
return streamToResponse(processedStream, res, {
|
||||
headers: {
|
||||
// response MUST have the `X-Experimental-Stream-Data: 'true'` header
|
||||
@@ -71,7 +88,7 @@ export const chat = async (req: Request, res: Response) => {
|
||||
} catch (error) {
|
||||
console.error("[LlamaIndex]", error);
|
||||
return res.status(500).json({
|
||||
error: (error as Error).message,
|
||||
detail: (error as Error).message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,19 +1,52 @@
|
||||
import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex";
|
||||
import {
|
||||
Ollama,
|
||||
OllamaEmbedding,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
|
||||
const CHUNK_SIZE = 512;
|
||||
const CHUNK_OVERLAP = 20;
|
||||
|
||||
export const initSettings = async () => {
|
||||
// HINT: you can delete the initialization code for unused model providers
|
||||
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
|
||||
switch (process.env.MODEL_PROVIDER) {
|
||||
case "ollama":
|
||||
initOllama();
|
||||
break;
|
||||
default:
|
||||
initOpenAI();
|
||||
break;
|
||||
}
|
||||
Settings.chunkSize = CHUNK_SIZE;
|
||||
Settings.chunkOverlap = CHUNK_OVERLAP;
|
||||
};
|
||||
|
||||
function initOpenAI() {
|
||||
Settings.llm = new OpenAI({
|
||||
model: process.env.MODEL ?? "gpt-3.5-turbo",
|
||||
maxTokens: 512,
|
||||
});
|
||||
Settings.chunkSize = CHUNK_SIZE;
|
||||
Settings.chunkOverlap = CHUNK_OVERLAP;
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
dimensions: process.env.EMBEDDING_DIM
|
||||
? parseInt(process.env.EMBEDDING_DIM)
|
||||
: undefined,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error(
|
||||
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
|
||||
);
|
||||
}
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import {
|
||||
JSONValue,
|
||||
StreamData,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer,
|
||||
experimental_StreamData,
|
||||
trimStartOfStreamHelper,
|
||||
type AIStreamCallbacksAndOptions,
|
||||
} from "ai";
|
||||
import { Response, StreamingAgentChatResponse } from "llamaindex";
|
||||
import {
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
Response,
|
||||
StreamingAgentChatResponse,
|
||||
} from "llamaindex";
|
||||
import { appendImageData, appendSourceData } from "./stream-helper";
|
||||
|
||||
type ParserOptions = {
|
||||
image_url?: string;
|
||||
@@ -14,35 +19,30 @@ type ParserOptions = {
|
||||
|
||||
function createParser(
|
||||
res: AsyncIterable<Response>,
|
||||
data: experimental_StreamData,
|
||||
data: StreamData,
|
||||
opts?: ParserOptions,
|
||||
) {
|
||||
const it = res[Symbol.asyncIterator]();
|
||||
const trimStartOfStream = trimStartOfStreamHelper();
|
||||
|
||||
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
|
||||
return new ReadableStream<string>({
|
||||
start() {
|
||||
// if image_url is provided, send it via the data stream
|
||||
if (opts?.image_url) {
|
||||
const message: JSONValue = {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: opts.image_url,
|
||||
},
|
||||
};
|
||||
data.append(message);
|
||||
} else {
|
||||
data.append({}); // send an empty image response for the user's message
|
||||
}
|
||||
appendImageData(data, opts?.image_url);
|
||||
},
|
||||
async pull(controller): Promise<void> {
|
||||
const { value, done } = await it.next();
|
||||
if (done) {
|
||||
appendSourceData(data, sourceNodes);
|
||||
controller.close();
|
||||
data.append({}); // send an empty image response for the assistant's message
|
||||
data.close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sourceNodes) {
|
||||
// get source nodes from the first response
|
||||
sourceNodes = value.sourceNodes;
|
||||
}
|
||||
const text = trimStartOfStream(value.response ?? "");
|
||||
if (text) {
|
||||
controller.enqueue(text);
|
||||
@@ -53,12 +53,12 @@ function createParser(
|
||||
|
||||
export function LlamaIndexStream(
|
||||
response: StreamingAgentChatResponse | AsyncIterable<Response>,
|
||||
data: StreamData,
|
||||
opts?: {
|
||||
callbacks?: AIStreamCallbacksAndOptions;
|
||||
parserOptions?: ParserOptions;
|
||||
},
|
||||
): { stream: ReadableStream; data: experimental_StreamData } {
|
||||
const data = new experimental_StreamData();
|
||||
): { stream: ReadableStream; data: StreamData } {
|
||||
const res =
|
||||
response instanceof StreamingAgentChatResponse
|
||||
? response.response
|
||||
@@ -66,7 +66,7 @@ export function LlamaIndexStream(
|
||||
return {
|
||||
stream: createParser(res, data, opts?.parserOptions)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer(true)),
|
||||
.pipeThrough(createStreamDataTransformer()),
|
||||
data,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { StreamData } from "ai";
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "image",
|
||||
data: {
|
||||
url: imageUrl,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendSourceData(
|
||||
data: StreamData,
|
||||
sourceNodes?: NodeWithScore<Metadata>[],
|
||||
) {
|
||||
if (!sourceNodes?.length) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "sources",
|
||||
data: {
|
||||
nodes: sourceNodes.map((node) => ({
|
||||
...node.node.toMutableJSON(),
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
})),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendEventData(data: StreamData, title?: string) {
|
||||
if (!title) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "events",
|
||||
data: {
|
||||
title,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -9,10 +9,5 @@
|
||||
"paths": {
|
||||
"@/*": ["./*"]
|
||||
}
|
||||
},
|
||||
"ts-node": {
|
||||
"compilerOptions": {
|
||||
"module": "commonjs"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,13 +11,7 @@ poetry install
|
||||
poetry shell
|
||||
```
|
||||
|
||||
By default, we use the OpenAI LLM (though you can customize, see `app/settings.py`). As a result, you need to specify an `OPENAI_API_KEY` in an .env file in this directory.
|
||||
|
||||
Example `.env` file:
|
||||
|
||||
```
|
||||
OPENAI_API_KEY=<openai_api_key>
|
||||
```
|
||||
Then check the parameters that have been pre-configured in the `.env` file in this directory. (E.g. you might need to configure an `OPENAI_API_KEY` if you're using OpenAI as model provider).
|
||||
|
||||
If you are using any tools or data sources, you can update their config files in the `config` folder.
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import (
|
||||
BaseChatEngine,
|
||||
StreamingAgentChatResponse,
|
||||
)
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.routers.messaging import EventCallbackHandler
|
||||
from aiostream import stream
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
@@ -37,6 +40,7 @@ class _SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
@@ -44,6 +48,7 @@ class _SourceNodes(BaseModel):
|
||||
id=source_node.node.node_id,
|
||||
metadata=source_node.node.metadata,
|
||||
score=source_node.score,
|
||||
text=source_node.node.text, # type: ignore
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -89,15 +94,49 @@ async def chat(
|
||||
):
|
||||
last_message_content, messages = await parse_chat_data(data)
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
|
||||
async def event_generator():
|
||||
async for token in response.async_response_gen():
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield token
|
||||
async def content_generator():
|
||||
# Yield the text response
|
||||
async def _text_generator():
|
||||
async for token in response.async_response_gen():
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/plain")
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "events",
|
||||
"data": {"title": event.get_title()},
|
||||
}
|
||||
)
|
||||
|
||||
combine = stream.merge(_text_generator(), _event_generator())
|
||||
async with combine.stream() as streamer:
|
||||
async for item in streamer:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield item
|
||||
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
_SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return VercelStreamResponse(content=content_generator())
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||||
|
||||
from llama_index.core.callbacks.base import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CallbackEvent(BaseModel):
|
||||
event_type: CBEventType
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
event_id: str = ""
|
||||
|
||||
def get_title(self) -> str | None:
|
||||
# Return as None for the unhandled event types
|
||||
# to avoid showing them in the UI
|
||||
match self.event_type:
|
||||
case "retrieve":
|
||||
if self.payload:
|
||||
nodes = self.payload.get("nodes")
|
||||
if nodes:
|
||||
return f"Retrieved {len(nodes)} sources to use as context for the query"
|
||||
else:
|
||||
return f"Retrieving context for query: '{self.payload.get('query_str')}'"
|
||||
else:
|
||||
return None
|
||||
case _:
|
||||
return None
|
||||
|
||||
|
||||
class EventCallbackHandler(BaseCallbackHandler):
|
||||
_aqueue: asyncio.Queue
|
||||
is_done: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize the base callback handler."""
|
||||
ignored_events = [
|
||||
CBEventType.CHUNKING,
|
||||
CBEventType.NODE_PARSING,
|
||||
CBEventType.EMBEDDING,
|
||||
CBEventType.LLM,
|
||||
CBEventType.TEMPLATING,
|
||||
]
|
||||
super().__init__(ignored_events, ignored_events)
|
||||
self._aqueue = asyncio.Queue()
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def on_event_end(
|
||||
self,
|
||||
event_type: CBEventType,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
event_id: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
|
||||
if event.get_title() is not None:
|
||||
self._aqueue.put_nowait(event)
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
"""No-op."""
|
||||
|
||||
def end_trace(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
trace_map: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""No-op."""
|
||||
|
||||
async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
|
||||
while not self._aqueue.empty() or not self.is_done:
|
||||
try:
|
||||
yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
|
||||
"""
|
||||
Class to convert the response from the chat engine to the streaming format expected by Vercel
|
||||
"""
|
||||
|
||||
TEXT_PREFIX = "0:"
|
||||
DATA_PREFIX = "8:"
|
||||
|
||||
@classmethod
|
||||
def convert_text(cls, token: str):
|
||||
# Escape newlines and double quotes to avoid breaking the stream
|
||||
token = json.dumps(token)
|
||||
return f"{cls.TEXT_PREFIX}{token}\n"
|
||||
|
||||
@classmethod
|
||||
def convert_data(cls, data: dict):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}[{data_str}]\n"
|
||||
|
||||
def __init__(self, content: Any, **kwargs):
|
||||
super().__init__(
|
||||
content=content,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
STORAGE_DIR = "storage" # directory to save the stores to (document store and if used, the `SimpleVectorStore`)
|
||||
@@ -0,0 +1,96 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.vector_stores import SimpleVectorStore
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from app.constants import STORAGE_DIR
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_doc_store():
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
docstore = SimpleDocumentStore()
|
||||
return docstore
|
||||
else:
|
||||
return SimpleDocumentStore.from_persist_dir(STORAGE_DIR)
|
||||
|
||||
|
||||
def run_ingestion_pipeline(docstore, vector_store, documents):
|
||||
# Create ingestion pipeline
|
||||
ingestion_pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
SentenceSplitter(
|
||||
chunk_size=Settings.chunk_size,
|
||||
chunk_overlap=Settings.chunk_overlap,
|
||||
),
|
||||
Settings.embed_model,
|
||||
],
|
||||
docstore=docstore,
|
||||
docstore_strategy="upserts_and_delete",
|
||||
)
|
||||
|
||||
# llama_index having an typing issue when passing vector_store to IngestionPipeline
|
||||
# so we need to set it manually after initialization
|
||||
ingestion_pipeline.vector_store = vector_store
|
||||
|
||||
# Run the ingestion pipeline and store the results
|
||||
nodes = ingestion_pipeline.run(show_progress=True, documents=documents)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def persist_storage(docstore, vector_store, nodes):
|
||||
storage_context = StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
vector_store=vector_store,
|
||||
)
|
||||
# SimpleVectorStore does not include index by default
|
||||
# so we need to create the index manually
|
||||
# can be removed if using other vector store
|
||||
if isinstance(vector_store, SimpleVectorStore):
|
||||
VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
store_nodes_override=True, # Need enable this to store the nodes and index's id
|
||||
)
|
||||
storage_context.persist(STORAGE_DIR)
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
# Get the stores and documents or create new ones
|
||||
documents = get_documents()
|
||||
docstore = get_doc_store()
|
||||
vector_store = get_vector_store()
|
||||
|
||||
# Run the ingestion pipeline
|
||||
nodes = run_ingestion_pipeline(
|
||||
docstore=docstore,
|
||||
vector_store=vector_store,
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
# Build the index and persist storage
|
||||
persist_storage(docstore, vector_store, nodes)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from llama_index.core import load_index_from_storage
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
from llama_index.core.vector_stores.simple import SimpleVectorStore
|
||||
from app.constants import STORAGE_DIR
|
||||
from app.engine.vectordb import get_vector_store
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
logger.info("Loading the index...")
|
||||
store = get_vector_store()
|
||||
# If the store is a SimpleVectorStore, we need to load the index from the storage
|
||||
if isinstance(store, SimpleVectorStore):
|
||||
index = load_index_from_storage(
|
||||
StorageContext.from_defaults(
|
||||
vector_store=store,
|
||||
persist_dir=STORAGE_DIR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
index = VectorStoreIndex.from_vector_store(store)
|
||||
|
||||
logger.info("Loaded index successfully.")
|
||||
return index
|
||||
@@ -1,41 +1,44 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
|
||||
def llm_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
model = os.getenv("MODEL")
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
|
||||
config = {
|
||||
"model": model,
|
||||
"temperature": float(temperature),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def embedding_config_from_env() -> Dict:
|
||||
model = os.getenv("EMBEDDING_MODEL")
|
||||
dimension = os.getenv("EMBEDDING_DIM")
|
||||
|
||||
config = {
|
||||
"model": model,
|
||||
"dimension": int(dimension) if dimension is not None else None,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def init_settings():
|
||||
llm_configs = llm_config_from_env()
|
||||
embedding_configs = embedding_config_from_env()
|
||||
|
||||
Settings.llm = OpenAI(**llm_configs)
|
||||
Settings.embed_model = OpenAIEmbedding(**embedding_configs)
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
if model_provider == "openai":
|
||||
init_openai()
|
||||
elif model_provider == "ollama":
|
||||
init_ollama()
|
||||
else:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
Settings.embed_model = OllamaEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
Settings.llm = Ollama(model=os.getenv("MODEL"))
|
||||
|
||||
|
||||
def init_openai():
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
@@ -13,9 +13,9 @@ python = "^3.11,<3.12"
|
||||
fastapi = "^0.109.1"
|
||||
uvicorn = { extras = ["standard"], version = "^0.23.2" }
|
||||
python-dotenv = "^1.0.0"
|
||||
aiostream = "^0.5.2"
|
||||
llama-index = "0.10.28"
|
||||
llama-index-core = "0.10.28"
|
||||
llama-index-agent-openai = "0.2.2"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -1,19 +1,52 @@
|
||||
import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex";
|
||||
import {
|
||||
Ollama,
|
||||
OllamaEmbedding,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
} from "llamaindex";
|
||||
|
||||
const CHUNK_SIZE = 512;
|
||||
const CHUNK_OVERLAP = 20;
|
||||
|
||||
export const initSettings = async () => {
|
||||
// HINT: you can delete the initialization code for unused model providers
|
||||
console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`);
|
||||
switch (process.env.MODEL_PROVIDER) {
|
||||
case "ollama":
|
||||
initOllama();
|
||||
break;
|
||||
default:
|
||||
initOpenAI();
|
||||
break;
|
||||
}
|
||||
Settings.chunkSize = CHUNK_SIZE;
|
||||
Settings.chunkOverlap = CHUNK_OVERLAP;
|
||||
};
|
||||
|
||||
function initOpenAI() {
|
||||
Settings.llm = new OpenAI({
|
||||
model: process.env.MODEL ?? "gpt-3.5-turbo",
|
||||
maxTokens: 512,
|
||||
});
|
||||
Settings.chunkSize = CHUNK_SIZE;
|
||||
Settings.chunkOverlap = CHUNK_OVERLAP;
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
dimensions: process.env.EMBEDDING_DIM
|
||||
? parseInt(process.env.EMBEDDING_DIM)
|
||||
: undefined,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) {
|
||||
throw new Error(
|
||||
"Using Ollama as model provider, 'MODEL' and 'EMBEDDING_MODEL' env variables must be set.",
|
||||
);
|
||||
}
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
});
|
||||
Settings.embedModel = new OllamaEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL ?? "",
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import {
|
||||
JSONValue,
|
||||
StreamData,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer,
|
||||
experimental_StreamData,
|
||||
trimStartOfStreamHelper,
|
||||
type AIStreamCallbacksAndOptions,
|
||||
} from "ai";
|
||||
import { Response, StreamingAgentChatResponse } from "llamaindex";
|
||||
import {
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
Response,
|
||||
StreamingAgentChatResponse,
|
||||
} from "llamaindex";
|
||||
import { appendImageData, appendSourceData } from "./stream-helper";
|
||||
|
||||
type ParserOptions = {
|
||||
image_url?: string;
|
||||
@@ -14,35 +19,30 @@ type ParserOptions = {
|
||||
|
||||
function createParser(
|
||||
res: AsyncIterable<Response>,
|
||||
data: experimental_StreamData,
|
||||
data: StreamData,
|
||||
opts?: ParserOptions,
|
||||
) {
|
||||
const it = res[Symbol.asyncIterator]();
|
||||
const trimStartOfStream = trimStartOfStreamHelper();
|
||||
|
||||
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
|
||||
return new ReadableStream<string>({
|
||||
start() {
|
||||
// if image_url is provided, send it via the data stream
|
||||
if (opts?.image_url) {
|
||||
const message: JSONValue = {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: opts.image_url,
|
||||
},
|
||||
};
|
||||
data.append(message);
|
||||
} else {
|
||||
data.append({}); // send an empty image response for the user's message
|
||||
}
|
||||
appendImageData(data, opts?.image_url);
|
||||
},
|
||||
async pull(controller): Promise<void> {
|
||||
const { value, done } = await it.next();
|
||||
if (done) {
|
||||
appendSourceData(data, sourceNodes);
|
||||
controller.close();
|
||||
data.append({}); // send an empty image response for the assistant's message
|
||||
data.close();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!sourceNodes) {
|
||||
// get source nodes from the first response
|
||||
sourceNodes = value.sourceNodes;
|
||||
}
|
||||
const text = trimStartOfStream(value.response ?? "");
|
||||
if (text) {
|
||||
controller.enqueue(text);
|
||||
@@ -53,12 +53,12 @@ function createParser(
|
||||
|
||||
export function LlamaIndexStream(
|
||||
response: StreamingAgentChatResponse | AsyncIterable<Response>,
|
||||
data: StreamData,
|
||||
opts?: {
|
||||
callbacks?: AIStreamCallbacksAndOptions;
|
||||
parserOptions?: ParserOptions;
|
||||
},
|
||||
): { stream: ReadableStream; data: experimental_StreamData } {
|
||||
const data = new experimental_StreamData();
|
||||
): { stream: ReadableStream; data: StreamData } {
|
||||
const res =
|
||||
response instanceof StreamingAgentChatResponse
|
||||
? response.response
|
||||
@@ -66,7 +66,7 @@ export function LlamaIndexStream(
|
||||
return {
|
||||
stream: createParser(res, data, opts?.parserOptions)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer(true)),
|
||||
.pipeThrough(createStreamDataTransformer()),
|
||||
data,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { initObservability } from "@/app/observability";
|
||||
import { StreamingTextResponse } from "ai";
|
||||
import { ChatMessage, MessageContent } from "llamaindex";
|
||||
import { Message, StreamData, StreamingTextResponse } from "ai";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { initSettings } from "./engine/settings";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { appendEventData } from "./stream-helper";
|
||||
|
||||
initObservability();
|
||||
initSettings();
|
||||
@@ -34,7 +35,7 @@ const convertMessageContent = (
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { messages, data }: { messages: ChatMessage[]; data: any } = body;
|
||||
const { messages, data }: { messages: Message[]; data: any } = body;
|
||||
const userMessage = messages.pop();
|
||||
if (!messages || !userMessage || userMessage.role !== "user") {
|
||||
return NextResponse.json(
|
||||
@@ -54,27 +55,43 @@ export async function POST(request: NextRequest) {
|
||||
data?.imageUrl,
|
||||
);
|
||||
|
||||
// Init Vercel AI StreamData
|
||||
const vercelStreamData = new StreamData();
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieving context for query: '${userMessage.content}'`,
|
||||
);
|
||||
|
||||
// Setup callback for streaming data before chatting
|
||||
Settings.callbackManager.on("retrieve", (data) => {
|
||||
const { nodes } = data.detail;
|
||||
appendEventData(
|
||||
vercelStreamData,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
// Calling LlamaIndex's ChatEngine to get a streamed response
|
||||
const response = await chatEngine.chat({
|
||||
message: userMessageContent,
|
||||
chatHistory: messages,
|
||||
chatHistory: messages as ChatMessage[],
|
||||
stream: true,
|
||||
});
|
||||
|
||||
// Transform LlamaIndex stream to Vercel/AI format
|
||||
const { stream, data: streamData } = LlamaIndexStream(response, {
|
||||
const { stream } = LlamaIndexStream(response, vercelStreamData, {
|
||||
parserOptions: {
|
||||
image_url: data?.imageUrl,
|
||||
},
|
||||
});
|
||||
|
||||
// Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
|
||||
return new StreamingTextResponse(stream, {}, streamData);
|
||||
return new StreamingTextResponse(stream, {}, vercelStreamData);
|
||||
} catch (error) {
|
||||
console.error("[LlamaIndex]", error);
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: (error as Error).message,
|
||||
detail: (error as Error).message,
|
||||
},
|
||||
{
|
||||
status: 500,
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { StreamData } from "ai";
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "image",
|
||||
data: {
|
||||
url: imageUrl,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendSourceData(
|
||||
data: StreamData,
|
||||
sourceNodes?: NodeWithScore<Metadata>[],
|
||||
) {
|
||||
if (!sourceNodes?.length) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "sources",
|
||||
data: {
|
||||
nodes: sourceNodes.map((node) => ({
|
||||
...node.node.toMutableJSON(),
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
})),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendEventData(data: StreamData, title?: string) {
|
||||
if (!title) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "events",
|
||||
data: {
|
||||
title,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useChat } from "ai/react";
|
||||
import { useMemo } from "react";
|
||||
import { insertDataIntoMessages } from "./transform";
|
||||
import { ChatInput, ChatMessages } from "./ui/chat";
|
||||
|
||||
export default function ChatSection() {
|
||||
@@ -14,22 +12,21 @@ export default function ChatSection() {
|
||||
handleInputChange,
|
||||
reload,
|
||||
stop,
|
||||
data,
|
||||
} = useChat({
|
||||
api: process.env.NEXT_PUBLIC_CHAT_API,
|
||||
headers: {
|
||||
"Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
|
||||
},
|
||||
onError: (error) => {
|
||||
const message = JSON.parse(error.message);
|
||||
alert(message.detail);
|
||||
},
|
||||
});
|
||||
|
||||
const transformedMessages = useMemo(() => {
|
||||
return insertDataIntoMessages(messages, data);
|
||||
}, [messages, data]);
|
||||
|
||||
return (
|
||||
<div className="space-y-4 max-w-5xl w-full">
|
||||
<ChatMessages
|
||||
messages={transformedMessages}
|
||||
messages={messages}
|
||||
isLoading={isLoading}
|
||||
reload={reload}
|
||||
stop={stop}
|
||||
@@ -39,7 +36,7 @@ export default function ChatSection() {
|
||||
handleSubmit={handleSubmit}
|
||||
handleInputChange={handleInputChange}
|
||||
isLoading={isLoading}
|
||||
multiModal={process.env.NEXT_PUBLIC_MODEL === "gpt-4-turbo"}
|
||||
multiModal={true}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
import { JSONValue, Message } from "ai";
|
||||
|
||||
export const isValidMessageData = (rawData: JSONValue | undefined) => {
|
||||
if (!rawData || typeof rawData !== "object") return false;
|
||||
if (Object.keys(rawData).length === 0) return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
export const insertDataIntoMessages = (
|
||||
messages: Message[],
|
||||
data: JSONValue[] | undefined,
|
||||
) => {
|
||||
if (!data) return messages;
|
||||
messages.forEach((message, i) => {
|
||||
const rawData = data[i];
|
||||
if (isValidMessageData(rawData)) message.data = rawData;
|
||||
});
|
||||
return messages;
|
||||
};
|
||||
@@ -0,0 +1,48 @@
|
||||
import { ChevronDown, ChevronRight, Loader2 } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { Button } from "../button";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "../collapsible";
|
||||
import { EventData } from "./index";
|
||||
|
||||
export function ChatEvents({
|
||||
data,
|
||||
isLoading,
|
||||
}: {
|
||||
data: EventData[];
|
||||
isLoading: boolean;
|
||||
}) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
const buttonLabel = isOpen ? "Hide events" : "Show events";
|
||||
|
||||
const EventIcon = isOpen ? (
|
||||
<ChevronDown className="h-4 w-4" />
|
||||
) : (
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="border-l-2 border-indigo-400 pl-2">
|
||||
<Collapsible open={isOpen} onOpenChange={setIsOpen}>
|
||||
<CollapsibleTrigger asChild>
|
||||
<Button variant="secondary" className="space-x-2">
|
||||
{isLoading ? <Loader2 className="h-4 w-4 animate-spin" /> : null}
|
||||
<span>{buttonLabel}</span>
|
||||
{EventIcon}
|
||||
</Button>
|
||||
</CollapsibleTrigger>
|
||||
<CollapsibleContent asChild>
|
||||
<div className="mt-4 text-sm space-y-2">
|
||||
{data.map((eventItem, index) => (
|
||||
<div key={index}>{eventItem.title}</div>
|
||||
))}
|
||||
</div>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
import Image from "next/image";
|
||||
import { type ImageData } from "./index";
|
||||
|
||||
export function ChatImage({ data }: { data: ImageData }) {
|
||||
return (
|
||||
<div className="rounded-md max-w-[200px] shadow-md">
|
||||
<Image
|
||||
src={data.url}
|
||||
width={0}
|
||||
height={0}
|
||||
sizes="100vw"
|
||||
style={{ width: "100%", height: "auto" }}
|
||||
alt=""
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,51 +1,104 @@
|
||||
import { Check, Copy } from "lucide-react";
|
||||
|
||||
import { JSONValue, Message } from "ai";
|
||||
import Image from "next/image";
|
||||
import { Message } from "ai";
|
||||
import { Fragment } from "react";
|
||||
import { Button } from "../button";
|
||||
import ChatAvatar from "./chat-avatar";
|
||||
import { ChatEvents } from "./chat-events";
|
||||
import { ChatImage } from "./chat-image";
|
||||
import { ChatSources } from "./chat-sources";
|
||||
import {
|
||||
AnnotationData,
|
||||
EventData,
|
||||
ImageData,
|
||||
MessageAnnotation,
|
||||
MessageAnnotationType,
|
||||
SourceData,
|
||||
} from "./index";
|
||||
import Markdown from "./markdown";
|
||||
import { useCopyToClipboard } from "./use-copy-to-clipboard";
|
||||
|
||||
interface ChatMessageImageData {
|
||||
type: "image_url";
|
||||
image_url: {
|
||||
url: string;
|
||||
};
|
||||
type ContentDisplayConfig = {
|
||||
order: number;
|
||||
component: JSX.Element | null;
|
||||
};
|
||||
|
||||
function getAnnotationData<T extends AnnotationData>(
|
||||
annotations: MessageAnnotation[],
|
||||
type: MessageAnnotationType,
|
||||
): T[] {
|
||||
return annotations.filter((a) => a.type === type).map((a) => a.data as T);
|
||||
}
|
||||
|
||||
// This component will parse message data and render the appropriate UI.
|
||||
function ChatMessageData({ messageData }: { messageData: JSONValue }) {
|
||||
const { image_url, type } = messageData as unknown as ChatMessageImageData;
|
||||
if (type === "image_url") {
|
||||
return (
|
||||
<div className="rounded-md max-w-[200px] shadow-md">
|
||||
<Image
|
||||
src={image_url.url}
|
||||
width={0}
|
||||
height={0}
|
||||
sizes="100vw"
|
||||
style={{ width: "100%", height: "auto" }}
|
||||
alt=""
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
function ChatMessageContent({
|
||||
message,
|
||||
isLoading,
|
||||
}: {
|
||||
message: Message;
|
||||
isLoading: boolean;
|
||||
}) {
|
||||
const annotations = message.annotations as MessageAnnotation[] | undefined;
|
||||
if (!annotations?.length) return <Markdown content={message.content} />;
|
||||
|
||||
const imageData = getAnnotationData<ImageData>(
|
||||
annotations,
|
||||
MessageAnnotationType.IMAGE,
|
||||
);
|
||||
const eventData = getAnnotationData<EventData>(
|
||||
annotations,
|
||||
MessageAnnotationType.EVENTS,
|
||||
);
|
||||
const sourceData = getAnnotationData<SourceData>(
|
||||
annotations,
|
||||
MessageAnnotationType.SOURCES,
|
||||
);
|
||||
|
||||
const contents: ContentDisplayConfig[] = [
|
||||
{
|
||||
order: -2,
|
||||
component: imageData[0] ? <ChatImage data={imageData[0]} /> : null,
|
||||
},
|
||||
{
|
||||
order: -1,
|
||||
component:
|
||||
eventData.length > 0 ? (
|
||||
<ChatEvents isLoading={isLoading} data={eventData} />
|
||||
) : null,
|
||||
},
|
||||
{
|
||||
order: 0,
|
||||
component: <Markdown content={message.content} />,
|
||||
},
|
||||
{
|
||||
order: 1,
|
||||
component: sourceData[0] ? <ChatSources data={sourceData[0]} /> : null,
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="flex-1 gap-4 flex flex-col">
|
||||
{contents
|
||||
.sort((a, b) => a.order - b.order)
|
||||
.map((content, index) => (
|
||||
<Fragment key={index}>{content.component}</Fragment>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function ChatMessage(chatMessage: Message) {
|
||||
export default function ChatMessage({
|
||||
chatMessage,
|
||||
isLoading,
|
||||
}: {
|
||||
chatMessage: Message;
|
||||
isLoading: boolean;
|
||||
}) {
|
||||
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 });
|
||||
return (
|
||||
<div className="flex items-start gap-4 pr-5 pt-5">
|
||||
<ChatAvatar role={chatMessage.role} />
|
||||
<div className="group flex flex-1 justify-between gap-2">
|
||||
<div className="flex-1 space-y-4">
|
||||
{chatMessage.data && (
|
||||
<ChatMessageData messageData={chatMessage.data} />
|
||||
)}
|
||||
<Markdown content={chatMessage.content} />
|
||||
</div>
|
||||
<ChatMessageContent message={chatMessage} isLoading={isLoading} />
|
||||
<Button
|
||||
onClick={() => copyToClipboard(chatMessage.content)}
|
||||
size="icon"
|
||||
|
||||
@@ -41,7 +41,7 @@ export default function ChatMessages(
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
{props.messages.map((m) => (
|
||||
<ChatMessage key={m.id} {...m} />
|
||||
<ChatMessage key={m.id} chatMessage={m} isLoading={props.isLoading} />
|
||||
))}
|
||||
{isPending && (
|
||||
<div className="flex justify-center items-center pt-10">
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
import { ArrowUpRightSquare, Check, Copy } from "lucide-react";
|
||||
import { useMemo } from "react";
|
||||
import { Button } from "../button";
|
||||
import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card";
|
||||
import { SourceData, SourceNode } from "./index";
|
||||
import { useCopyToClipboard } from "./use-copy-to-clipboard";
|
||||
|
||||
const SCORE_THRESHOLD = 0.5;
|
||||
|
||||
export function ChatSources({ data }: { data: SourceData }) {
|
||||
const sources = useMemo(() => {
|
||||
return (
|
||||
data.nodes
|
||||
?.filter((node) => Object.keys(node.metadata).length > 0)
|
||||
?.filter((node) => (node.score ?? 1) > SCORE_THRESHOLD)
|
||||
.sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) || []
|
||||
);
|
||||
}, [data.nodes]);
|
||||
|
||||
if (sources.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div className="space-x-2 text-sm">
|
||||
<span className="font-semibold">Sources:</span>
|
||||
<div className="inline-flex gap-1 items-center">
|
||||
{sources.map((node: SourceNode, index: number) => (
|
||||
<div key={node.id}>
|
||||
<HoverCard>
|
||||
<HoverCardTrigger>
|
||||
<div className="text-xs w-5 h-5 rounded-full bg-gray-100 mb-2 flex items-center justify-center hover:text-white hover:bg-primary hover:cursor-pointer">
|
||||
{index + 1}
|
||||
</div>
|
||||
</HoverCardTrigger>
|
||||
<HoverCardContent>
|
||||
<NodeInfo node={node} />
|
||||
</HoverCardContent>
|
||||
</HoverCard>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function NodeInfo({ node }: { node: SourceNode }) {
|
||||
const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 });
|
||||
|
||||
if (typeof node.metadata["URL"] === "string") {
|
||||
// this is a node generated by the web loader, it contains an external URL
|
||||
// add a link to view this URL
|
||||
return (
|
||||
<a
|
||||
className="space-x-2 flex items-center my-2 hover:text-blue-900"
|
||||
href={node.metadata["URL"]}
|
||||
target="_blank"
|
||||
>
|
||||
<span>{node.metadata["URL"]}</span>
|
||||
<ArrowUpRightSquare className="w-4 h-4" />
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof node.metadata["file_path"] === "string") {
|
||||
// this is a node generated by the file loader, it contains file path
|
||||
// add a button to copy the path to the clipboard
|
||||
const filePath = node.metadata["file_path"];
|
||||
return (
|
||||
<div className="flex items-center px-2 py-1 justify-between my-2">
|
||||
<span>{filePath}</span>
|
||||
<Button
|
||||
onClick={() => copyToClipboard(filePath)}
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
className="h-12 w-12"
|
||||
>
|
||||
{isCopied ? (
|
||||
<Check className="h-4 w-4" />
|
||||
) : (
|
||||
<Copy className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// node generated by unknown loader, implement renderer by analyzing logged out metadata
|
||||
console.log("Node metadata", node.metadata);
|
||||
return (
|
||||
<p>
|
||||
Sorry, unknown node type. Please add a new renderer in the NodeInfo
|
||||
component.
|
||||
</p>
|
||||
);
|
||||
}
|
||||
@@ -3,3 +3,36 @@ import ChatMessages from "./chat-messages";
|
||||
|
||||
export { type ChatHandler } from "./chat.interface";
|
||||
export { ChatInput, ChatMessages };
|
||||
|
||||
export enum MessageAnnotationType {
|
||||
IMAGE = "image",
|
||||
SOURCES = "sources",
|
||||
EVENTS = "events",
|
||||
}
|
||||
|
||||
export type ImageData = {
|
||||
url: string;
|
||||
};
|
||||
|
||||
export type SourceNode = {
|
||||
id: string;
|
||||
metadata: Record<string, unknown>;
|
||||
score?: number;
|
||||
text: string;
|
||||
};
|
||||
|
||||
export type SourceData = {
|
||||
nodes: SourceNode[];
|
||||
};
|
||||
|
||||
export type EventData = {
|
||||
title: string;
|
||||
isCollapsed: boolean;
|
||||
};
|
||||
|
||||
export type AnnotationData = ImageData | SourceData | EventData;
|
||||
|
||||
export type MessageAnnotation = {
|
||||
type: MessageAnnotationType;
|
||||
data: AnnotationData;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible";
|
||||
|
||||
const Collapsible = CollapsiblePrimitive.Root;
|
||||
|
||||
const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger;
|
||||
|
||||
const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent;
|
||||
|
||||
export { Collapsible, CollapsibleContent, CollapsibleTrigger };
|
||||
@@ -0,0 +1,29 @@
|
||||
"use client";
|
||||
|
||||
import * as HoverCardPrimitive from "@radix-ui/react-hover-card";
|
||||
import * as React from "react";
|
||||
|
||||
import { cn } from "./lib/utils";
|
||||
|
||||
const HoverCard = HoverCardPrimitive.Root;
|
||||
|
||||
const HoverCardTrigger = HoverCardPrimitive.Trigger;
|
||||
|
||||
const HoverCardContent = React.forwardRef<
|
||||
React.ElementRef<typeof HoverCardPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof HoverCardPrimitive.Content>
|
||||
>(({ className, align = "center", sideOffset = 4, ...props }, ref) => (
|
||||
<HoverCardPrimitive.Content
|
||||
ref={ref}
|
||||
align={align}
|
||||
sideOffset={sideOffset}
|
||||
className={cn(
|
||||
"z-50 w-64 rounded-md border bg-popover p-4 text-popover-foreground shadow-md outline-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
HoverCardContent.displayName = HoverCardPrimitive.Content.displayName;
|
||||
|
||||
export { HoverCard, HoverCardContent, HoverCardTrigger };
|
||||
@@ -2,6 +2,7 @@
|
||||
"experimental": {
|
||||
"outputFileTracingIncludes": {
|
||||
"/*": ["./cache/**/*"]
|
||||
}
|
||||
},
|
||||
"serverComponentsExternalPackages": ["sharp", "onnxruntime-node"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,14 @@
|
||||
"lint": "next lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"@radix-ui/react-collapsible": "^1.0.3",
|
||||
"@radix-ui/react-hover-card": "^1.0.7",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"ai": "^2.2.27",
|
||||
"ai": "^3.0.21",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^1.2.1",
|
||||
"dotenv": "^16.3.1",
|
||||
"llamaindex": "latest",
|
||||
"llamaindex": "0.2.10",
|
||||
"lucide-react": "^0.294.0",
|
||||
"next": "^14.0.3",
|
||||
"react": "^18.2.0",
|
||||
@@ -33,17 +35,17 @@
|
||||
"@types/node": "^20.10.3",
|
||||
"@types/react": "^18.2.42",
|
||||
"@types/react-dom": "^18.2.17",
|
||||
"@types/react-syntax-highlighter": "^15.5.11",
|
||||
"autoprefixer": "^10.4.16",
|
||||
"cross-env": "^7.0.3",
|
||||
"eslint": "^8.55.0",
|
||||
"eslint-config-next": "^14.0.3",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"postcss": "^8.4.32",
|
||||
"tailwindcss": "^3.3.6",
|
||||
"typescript": "^5.3.2",
|
||||
"@types/react-syntax-highlighter": "^15.5.11",
|
||||
"cross-env": "^7.0.3",
|
||||
"prettier": "^3.2.5",
|
||||
"prettier-plugin-organize-imports": "^3.2.4",
|
||||
"eslint-config-prettier": "^8.10.0",
|
||||
"ts-node": "^10.9.2"
|
||||
"tailwindcss": "^3.3.6",
|
||||
"tsx": "^4.7.2",
|
||||
"typescript": "^5.3.2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,10 +24,5 @@
|
||||
"forceConsistentCasingInFileNames": true
|
||||
},
|
||||
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
|
||||
"exclude": ["node_modules"],
|
||||
"ts-node": {
|
||||
"compilerOptions": {
|
||||
"module": "commonjs"
|
||||
}
|
||||
}
|
||||
"exclude": ["node_modules"]
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
// webpack config must be a function in NextJS that is used to patch the default webpack config provided by NextJS, see https://nextjs.org/docs/pages/api-reference/next-config-js/webpack
|
||||
export default function webpack(config) {
|
||||
// See https://webpack.js.org/configuration/resolve/#resolvealias
|
||||
config.resolve.alias = {
|
||||
...config.resolve.alias,
|
||||
sharp$: false,
|
||||
"onnxruntime-node$": false,
|
||||
config.resolve.fallback = {
|
||||
aws4: false,
|
||||
};
|
||||
|
||||
// Following lines will fix issues with onnxruntime-node when using pnpm
|
||||
// See: https://github.com/vercel/next.js/issues/43433
|
||||
config.externals.push({
|
||||
"onnxruntime-node": "commonjs onnxruntime-node",
|
||||
sharp: "commonjs sharp",
|
||||
});
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
export default function webpack(config, isServer) {
|
||||
// See https://webpack.js.org/configuration/resolve/#resolvealias
|
||||
config.resolve.alias = {
|
||||
...config.resolve.alias,
|
||||
sharp$: false,
|
||||
"onnxruntime-node$": false,
|
||||
config.resolve.fallback = {
|
||||
aws4: false,
|
||||
};
|
||||
config.module.rules.push({
|
||||
test: /\.node$/,
|
||||
|
||||
Reference in New Issue
Block a user