mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-02 19:14:28 -04:00
Compare commits
99 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| af6ac9a444 | |||
| 22245ca9fd | |||
| 81b67794ef | |||
| 5c13646e55 | |||
| 43474a51ff | |||
| cf11b233c6 | |||
| fd9fb42ace | |||
| 92798f73dd | |||
| e71d8bd6e2 | |||
| e25e112873 | |||
| 048187cce3 | |||
| 6bd76fbfb1 | |||
| a553d5051e | |||
| b0becaa8dc | |||
| 6a42542642 | |||
| f936a470f3 | |||
| df9cca5a52 | |||
| dc9ee895a7 | |||
| 98ff3c2e77 | |||
| 0900413689 | |||
| 8dc6a2bf5a | |||
| 23b735717d | |||
| bd4714ca8d | |||
| 455ab6862e | |||
| 58e6c150c0 | |||
| e57e9813dd | |||
| 7302880c5f | |||
| 624c721ac4 | |||
| d2c66cf550 | |||
| df96159e88 | |||
| 32fb32ab18 | |||
| 3b57bdcf12 | |||
| a221cfc11f | |||
| d3f92f8a69 | |||
| d1026ea784 | |||
| 791ca7c945 | |||
| 07fcefde5d | |||
| 9ecd061262 | |||
| 344d832d3d | |||
| a0aab03226 | |||
| a8073063c5 | |||
| aeb6fef4da | |||
| 64732f05aa | |||
| 588e0d607b | |||
| f2c3389168 | |||
| 5093b37c05 | |||
| f383f0cbe9 | |||
| b3c969dae5 | |||
| 628e16df7c | |||
| aa69014d04 | |||
| 293557cbb4 | |||
| b46d050fc3 | |||
| 02ed277dd0 | |||
| 48b96ff188 | |||
| 9c9decbb88 | |||
| 0748f2e8d7 | |||
| 3079162806 | |||
| 48c19c6e62 | |||
| d75c08e7d8 | |||
| 8f03f8d4bc | |||
| 19c57d945a | |||
| 9112d0801e | |||
| 93b797c162 | |||
| d53b760fd0 | |||
| a880c7c016 | |||
| 7b116ce7f7 | |||
| d1232fb1d5 | |||
| bedf199236 | |||
| c1510bd3fa | |||
| 69b9ce76bf | |||
| 9ced116e1a | |||
| fae9bcd65a | |||
| 2091fea2b4 | |||
| 563b51d76d | |||
| 88c88bf16d | |||
| cd6ebf7295 | |||
| 50b2ddbbf5 | |||
| 5fe2d519d2 | |||
| 09f1db3b5e | |||
| cb3be7d1d4 | |||
| 5474a1f182 | |||
| 1148ddba53 | |||
| 9e945ed355 | |||
| 6342163df2 | |||
| a42fa53a6b | |||
| 099f626586 | |||
| 956538eeb0 | |||
| 555f6b2905 | |||
| d8bc271a21 | |||
| f29561cde2 | |||
| 442abae8ac | |||
| 0ad2207684 | |||
| bfde30deed | |||
| 96fdb83abf | |||
| b7e0072c9c | |||
| 81bc340dda | |||
| ddf3aef7dc | |||
| 1f5a26f3a8 | |||
| 48188ca3f9 |
@@ -1,5 +0,0 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Add support E2B code interpreter tool for FastAPI
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
matrix:
|
||||
node-version: [18, 20]
|
||||
python-version: ["3.11"]
|
||||
os: [macos-latest, windows-latest]
|
||||
os: [macos-latest, windows-latest, ubuntu-22.04]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -62,6 +62,7 @@ jobs:
|
||||
run: pnpm run e2e
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLAMA_CLOUD_API_KEY: ${{ secrets.LLAMA_CLOUD_API_KEY }}
|
||||
working-directory: .
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
|
||||
@@ -46,5 +46,8 @@ e2e/cache
|
||||
# intellij
|
||||
**/.idea
|
||||
|
||||
# Python
|
||||
.mypy_cache/
|
||||
|
||||
# build artifacts
|
||||
create-llama-*.tgz
|
||||
|
||||
+116
@@ -1,5 +1,121 @@
|
||||
# create-llama
|
||||
|
||||
## 0.1.24
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 43474a5: Configure LlamaCloud organization ID for Python
|
||||
- cf11b23: Add Azure code interpreter for Python and TS
|
||||
- fd9fb42: Add Azure OpenAI as model provider
|
||||
- 5c13646: Fix starter questions not working in python backend
|
||||
|
||||
## 0.1.23
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 6bd76fb: Add template for structured extraction
|
||||
|
||||
## 0.1.22
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b0becaa: Add e2e testing for llamacloud datasource
|
||||
- df9cca5: Upgrade pdf viewer
|
||||
|
||||
## 0.1.21
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- bd4714c: Filter private documents for Typescript (Using MetadataFilters) and update to LlamaIndexTS 0.5.7
|
||||
- 58e6c15: Add using LlamaParse for private file uploader
|
||||
- 455ab68: Display files in sources using LlamaCloud indexes.
|
||||
- 23b7357: Use gpt-4o-mini as default model
|
||||
- 0900413: Add suggestions for next questions.
|
||||
|
||||
## 0.1.20
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 624c721: Update to LlamaIndex 0.10.55
|
||||
|
||||
## 0.1.19
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- df96159: Use Qdrant FastEmbed as local embedding provider
|
||||
- 32fb32a: Support upload document files: pdf, docx, txt
|
||||
|
||||
## 0.1.18
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- d1026ea: support Mistral as llm and embedding
|
||||
- a221cfc: Use LlamaParse for all the file types that it supports (if activated)
|
||||
|
||||
## 0.1.17
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9ecd061: Add new template for a multi-agents app
|
||||
|
||||
## 0.1.16
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a0aab03: Add T-System's LLMHUB as a model provider
|
||||
|
||||
## 0.1.15
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 64732f0: Fix the issue of images not showing with the sandbox URL from OpenAI's models
|
||||
- aeb6fef: use llamacloud for chat
|
||||
|
||||
## 0.1.14
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- f2c3389: chore: update to llamaindex 0.4.3
|
||||
- 5093b37: Remove non-working file selectors for Linux
|
||||
|
||||
## 0.1.13
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b3c969d: Add image generator tool
|
||||
|
||||
## 0.1.12
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- aa69014: Fix NextJS for TS 5.2
|
||||
|
||||
## 0.1.11
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 48b96ff: Add DuckDuckGo search tool
|
||||
- 9c9decb: Reuse function tool instances and improve e2b interpreter tool for Python
|
||||
- 02ed277: Add Groq as a model provider
|
||||
- 0748f2e: Remove hard-coded Gemini supported models
|
||||
|
||||
## 0.1.10
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9112d08: Add OpenAPI tool for Typescript
|
||||
- 8f03f8d: Add OLLAMA_REQUEST_TIMEOUT variable to config Ollama timeout (Python)
|
||||
- 8f03f8d: Apply nest_asyncio for llama parse
|
||||
|
||||
## 0.1.9
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a42fa53: Add CSV upload
|
||||
- 563b51d: Fix Vercel streaming (python) to stream data events instantly
|
||||
- d60b3c5: Add E2B code interpreter tool for FastAPI
|
||||
- 956538e: Add OpenAPI action tool for FastAPI
|
||||
|
||||
## 0.1.8
|
||||
|
||||
### Patch Changes
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
# Create LlamaIndex App
|
||||
# Create Llama
|
||||
|
||||
The easiest way to get started with [LlamaIndex](https://www.llamaindex.ai/) is by using `create-llama`. This CLI tool enables you to quickly start building a new LlamaIndex application, with everything set up for you.
|
||||
|
||||
## Get started
|
||||
|
||||
Just run
|
||||
|
||||
```bash
|
||||
npx create-llama@latest
|
||||
```
|
||||
|
||||
to get started, or see below for more options. Once your app is generated, run
|
||||
to get started, or watch this video for a demo session:
|
||||
|
||||
https://github.com/user-attachments/assets/dd3edc36-4453-4416-91c2-d24326c6c167
|
||||
|
||||
Once your app is generated, run
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
@@ -18,16 +24,20 @@ to start the development server. You can then visit [http://localhost:3000](http
|
||||
|
||||
## What you'll get
|
||||
|
||||
- A Next.js-powered front-end using components from [shadcn/ui](https://ui.shadcn.com/). The app is set up as a chat interface that can answer questions about your data (see below)
|
||||
- A Next.js-powered front-end using components from [shadcn/ui](https://ui.shadcn.com/). The app is set up as a chat interface that can answer questions about your data or interact with your agent
|
||||
- Your choice of 3 back-ends:
|
||||
- **Next.js**: if you select this option, you’ll have a full-stack Next.js application that you can deploy to a host like [Vercel](https://vercel.com/) in just a few clicks. This uses [LlamaIndex.TS](https://www.npmjs.com/package/llamaindex), our TypeScript library.
|
||||
- **Express**: if you want a more traditional Node.js application you can generate an Express backend. This also uses LlamaIndex.TS.
|
||||
- **Python FastAPI**: if you select this option, you’ll get a backend powered by the [llama-index python package](https://pypi.org/project/llama-index/), which you can deploy to a service like Render or fly.io.
|
||||
- **Python FastAPI**: if you select this option, you’ll get a backend powered by the [llama-index Python package](https://pypi.org/project/llama-index/), which you can deploy to a service like Render or fly.io.
|
||||
- The back-end has two endpoints (one streaming, the other one non-streaming) that allow you to send the state of your chat and receive additional responses
|
||||
- You add arbitrary data sources to your chat, like local files, websites, or data retrieved from a database.
|
||||
- Turn your chat into an AI agent by adding tools (functions called by the LLM).
|
||||
- The app uses OpenAI by default, so you'll need an OpenAI API key, or you can customize it to use any of the dozens of LLMs we support.
|
||||
|
||||
Here's how it looks like:
|
||||
|
||||
https://github.com/user-attachments/assets/d57af1a1-d99b-4e9c-98d9-4cbd1327eff8
|
||||
|
||||
## Using your data
|
||||
|
||||
You can supply your own data; the app will index it and answer questions. Your generated app will have a folder called `data` (If you're using Express or Python and generate a frontend, it will be `./backend/data`).
|
||||
@@ -54,7 +64,7 @@ Optionally generate a frontend if you've selected the Python or Express back-end
|
||||
|
||||
## Customizing the AI models
|
||||
|
||||
The app will default to OpenAI's `gpt-4-turbo` LLM and `text-embedding-3-large` embedding model.
|
||||
The app will default to OpenAI's `gpt-4o-mini` LLM and `text-embedding-3-large` embedding model.
|
||||
|
||||
If you want to use different OpenAI models, add the `--ask-models` CLI parameter.
|
||||
|
||||
@@ -84,7 +94,7 @@ Need to install the following packages:
|
||||
create-llama@latest
|
||||
Ok to proceed? (y) y
|
||||
✔ What is your project named? … my-app
|
||||
✔ Which template would you like to use? › Chat
|
||||
✔ Which template would you like to use? › Agentic RAG (single agent)
|
||||
✔ Which framework would you like to use? › NextJS
|
||||
✔ Would you like to set up observability? › No
|
||||
✔ Please provide your OpenAI API key (leave blank to skip): …
|
||||
@@ -92,6 +102,7 @@ Ok to proceed? (y) y
|
||||
✔ Would you like to add another data source? › No
|
||||
✔ Would you like to use LlamaParse (improved parser for RAG - requires API key)? … no / yes
|
||||
✔ Would you like to use a vector database? › No, just store the data in the file system
|
||||
✔ Would you like to build an agent using tools? If so, select the tools here, otherwise just press enter › Weather
|
||||
? How would you like to proceed? › - Use arrow-keys. Return to submit.
|
||||
Just generate code (~1 sec)
|
||||
❯ Start in VSCode (~1 sec)
|
||||
|
||||
@@ -151,5 +151,19 @@ export async function createApp({
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
dataSources.some((dataSource) => dataSource.type === "file") &&
|
||||
process.platform === "linux"
|
||||
) {
|
||||
console.log(
|
||||
yellow(
|
||||
`You can add your own data files to ${terminalLink(
|
||||
"data",
|
||||
`file://${root}/data`,
|
||||
)} folder manually.`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
console.log();
|
||||
}
|
||||
|
||||
+12
-3
@@ -17,13 +17,16 @@ const templateFrameworks: TemplateFramework[] = [
|
||||
"express",
|
||||
"fastapi",
|
||||
];
|
||||
const dataSources: string[] = ["--no-files", "--example-file"];
|
||||
const dataSources: string[] = ["--no-files", "--llamacloud"];
|
||||
const templateUIs: TemplateUI[] = ["shadcn", "html"];
|
||||
const templatePostInstallActions: TemplatePostInstallAction[] = [
|
||||
"none",
|
||||
"runApp",
|
||||
];
|
||||
|
||||
const llamaCloudProjectName = "create-llama";
|
||||
const llamaCloudIndexName = "e2e-test";
|
||||
|
||||
for (const templateType of templateTypes) {
|
||||
for (const templateFramework of templateFrameworks) {
|
||||
for (const dataSource of dataSources) {
|
||||
@@ -31,6 +34,10 @@ for (const templateType of templateTypes) {
|
||||
for (const templatePostInstallAction of templatePostInstallActions) {
|
||||
const appType: AppType =
|
||||
templateFramework === "nextjs" ? "" : "--frontend";
|
||||
const userMessage =
|
||||
dataSource !== "--no-files"
|
||||
? "Physical standard for letters"
|
||||
: "Hello";
|
||||
test.describe(`try create-llama ${templateType} ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
|
||||
let port: number;
|
||||
let externalPort: number;
|
||||
@@ -55,6 +62,8 @@ for (const templateType of templateTypes) {
|
||||
port,
|
||||
externalPort,
|
||||
templatePostInstallAction,
|
||||
llamaCloudProjectName,
|
||||
llamaCloudIndexName,
|
||||
);
|
||||
name = result.projectName;
|
||||
appProcess = result.appProcess;
|
||||
@@ -75,7 +84,7 @@ for (const templateType of templateTypes) {
|
||||
}) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await page.fill("form input", "hello");
|
||||
await page.fill("form input", userMessage);
|
||||
const [response] = await Promise.all([
|
||||
page.waitForResponse(
|
||||
(res) => {
|
||||
@@ -106,7 +115,7 @@ for (const templateType of templateTypes) {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
content: userMessage,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
+10
-2
@@ -72,9 +72,13 @@ export async function runCreateLlama(
|
||||
port: number,
|
||||
externalPort: number,
|
||||
postInstallAction: TemplatePostInstallAction,
|
||||
llamaCloudProjectName: string,
|
||||
llamaCloudIndexName: string,
|
||||
): Promise<CreateLlamaResult> {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error("Setting OPENAI_API_KEY is mandatory to run tests");
|
||||
if (!process.env.OPENAI_API_KEY || !process.env.LLAMA_CLOUD_API_KEY) {
|
||||
throw new Error(
|
||||
"Setting the OPENAI_API_KEY and LLAMA_CLOUD_API_KEY is mandatory to run tests",
|
||||
);
|
||||
}
|
||||
const name = [
|
||||
templateType,
|
||||
@@ -110,12 +114,16 @@ export async function runCreateLlama(
|
||||
"--no-llama-parse",
|
||||
"--observability",
|
||||
"none",
|
||||
"--llama-cloud-key",
|
||||
process.env.LLAMA_CLOUD_API_KEY,
|
||||
].join(" ");
|
||||
console.log(`running command '${command}' in ${cwd}`);
|
||||
const appProcess = exec(command, {
|
||||
cwd,
|
||||
env: {
|
||||
...process.env,
|
||||
LLAMA_CLOUD_PROJECT_NAME: llamaCloudProjectName,
|
||||
LLAMA_CLOUD_INDEX_NAME: llamaCloudIndexName,
|
||||
},
|
||||
});
|
||||
appProcess.stderr?.on("data", (data) => {
|
||||
|
||||
@@ -5,9 +5,12 @@ import {
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateFramework,
|
||||
TemplateType,
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
|
||||
import { TSYSTEMS_LLMHUB_API_URL } from "./providers/llmhub";
|
||||
|
||||
export type EnvVar = {
|
||||
name?: string;
|
||||
description?: string;
|
||||
@@ -133,6 +136,31 @@ const getVectorDBEnvs = (
|
||||
"Optional API key for authenticating requests to Qdrant.",
|
||||
},
|
||||
];
|
||||
case "llamacloud":
|
||||
return [
|
||||
{
|
||||
name: "LLAMA_CLOUD_INDEX_NAME",
|
||||
description:
|
||||
"The name of the LlamaCloud index to use (part of the LlamaCloud project).",
|
||||
value: "test",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_PROJECT_NAME",
|
||||
description: "The name of the LlamaCloud project.",
|
||||
value: "Default",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_BASE_URL",
|
||||
description:
|
||||
"The base URL for the LlamaCloud API. Only change this for non-production environments",
|
||||
value: "https://api.cloud.llamaindex.ai",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_ORGANIZATION_ID",
|
||||
description:
|
||||
"The organization ID for the LlamaCloud project (uses default organization if not specified - Python only)",
|
||||
},
|
||||
];
|
||||
case "chroma":
|
||||
const envs = [
|
||||
{
|
||||
@@ -185,6 +213,10 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
description: "Dimension of the embedding model to use.",
|
||||
value: modelConfig.dimensions.toString(),
|
||||
},
|
||||
{
|
||||
name: "CONVERSATION_STARTERS",
|
||||
description: "The questions to help users get started (multi-line).",
|
||||
},
|
||||
...(modelConfig.provider === "openai"
|
||||
? [
|
||||
{
|
||||
@@ -211,6 +243,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "groq"
|
||||
? [
|
||||
{
|
||||
name: "GROQ_API_KEY",
|
||||
description: "The Groq API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "gemini"
|
||||
? [
|
||||
{
|
||||
@@ -229,6 +270,57 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "mistral"
|
||||
? [
|
||||
{
|
||||
name: "MISTRAL_API_KEY",
|
||||
description: "The Mistral API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "azure-openai"
|
||||
? [
|
||||
{
|
||||
name: "AZURE_OPENAI_KEY",
|
||||
description: "The Azure OpenAI key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_ENDPOINT",
|
||||
description: "The Azure OpenAI endpoint to use.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_API_VERSION",
|
||||
description: "The Azure OpenAI API version to use.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_LLM_DEPLOYMENT",
|
||||
description:
|
||||
"The Azure OpenAI deployment to use for LLM deployment.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
|
||||
description:
|
||||
"The Azure OpenAI deployment to use for embedding deployment.",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "t-systems"
|
||||
? [
|
||||
{
|
||||
name: "T_SYSTEMS_LLMHUB_BASE_URL",
|
||||
description:
|
||||
"The base URL for the T-Systems AI Foundation Model API. Eg: http://localhost:11434",
|
||||
value: TSYSTEMS_LLMHUB_API_URL,
|
||||
},
|
||||
{
|
||||
name: "T_SYSTEMS_LLMHUB_API_KEY",
|
||||
description: "API Key for T-System's AI Foundation Model.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
};
|
||||
|
||||
@@ -276,6 +368,12 @@ const getEngineEnvs = (): EnvVar[] => {
|
||||
"The number of similar embeddings to return when retrieving documents.",
|
||||
value: "3",
|
||||
},
|
||||
{
|
||||
name: "STREAM_TIMEOUT",
|
||||
description:
|
||||
"The time in milliseconds to wait for the stream to return a response.",
|
||||
value: "60000",
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
@@ -322,6 +420,36 @@ const getSystemPromptEnv = (tools?: Tool[]): EnvVar => {
|
||||
};
|
||||
};
|
||||
|
||||
const getTemplateEnvs = (template?: TemplateType): EnvVar[] => {
|
||||
if (template === "multiagent") {
|
||||
return [
|
||||
{
|
||||
name: "MESSAGE_QUEUE_PORT",
|
||||
},
|
||||
{
|
||||
name: "CONTROL_PLANE_PORT",
|
||||
},
|
||||
{
|
||||
name: "HUMAN_CONSUMER_PORT",
|
||||
},
|
||||
{
|
||||
name: "AGENT_QUERY_ENGINE_PORT",
|
||||
value: "8003",
|
||||
},
|
||||
{
|
||||
name: "AGENT_QUERY_ENGINE_DESCRIPTION",
|
||||
value: "Query information from the provided data",
|
||||
},
|
||||
{
|
||||
name: "AGENT_DUMMY_PORT",
|
||||
value: "8004",
|
||||
},
|
||||
];
|
||||
} else {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
@@ -330,6 +458,7 @@ export const createBackendEnvFile = async (
|
||||
modelConfig: ModelConfig;
|
||||
framework: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
template?: TemplateType;
|
||||
port?: number;
|
||||
tools?: Tool[];
|
||||
},
|
||||
@@ -350,6 +479,8 @@ export const createBackendEnvFile = async (
|
||||
...getVectorDBEnvs(opts.vectorDb, opts.framework),
|
||||
...getFrameworkEnvs(opts.framework, opts.port),
|
||||
...getToolEnvs(opts.tools),
|
||||
// Add template environment variables
|
||||
...getTemplateEnvs(opts.template),
|
||||
getSystemPromptEnv(opts.tools),
|
||||
];
|
||||
// Render and write env file
|
||||
|
||||
+57
-28
@@ -8,6 +8,7 @@ import { writeLoadersConfig } from "./datasources";
|
||||
import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
|
||||
import { PackageManager } from "./get-pkg-manager";
|
||||
import { installLlamapackProject } from "./llama-pack";
|
||||
import { makeDir } from "./make-dir";
|
||||
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
|
||||
import { installPythonTemplate } from "./python";
|
||||
import { downloadAndExtractRepo } from "./repo";
|
||||
@@ -22,6 +23,31 @@ import {
|
||||
} from "./types";
|
||||
import { installTSTemplate } from "./typescript";
|
||||
|
||||
const checkForGenerateScript = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
llamaCloudKey?: string,
|
||||
useLlamaParse?: boolean,
|
||||
) => {
|
||||
const missingSettings = [];
|
||||
|
||||
if (!modelConfig.isConfigured()) {
|
||||
missingSettings.push("your model provider API key");
|
||||
}
|
||||
|
||||
const llamaCloudApiKey = llamaCloudKey ?? process.env["LLAMA_CLOUD_API_KEY"];
|
||||
const isRequiredLlamaCloudKey = useLlamaParse || vectorDb === "llamacloud";
|
||||
if (isRequiredLlamaCloudKey && !llamaCloudApiKey) {
|
||||
missingSettings.push("your LLAMA_CLOUD_API_KEY");
|
||||
}
|
||||
|
||||
if (vectorDb !== "none" && vectorDb !== "llamacloud") {
|
||||
missingSettings.push("your Vector DB environment variables");
|
||||
}
|
||||
|
||||
return missingSettings;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
async function generateContextData(
|
||||
framework: TemplateFramework,
|
||||
@@ -37,12 +63,15 @@ async function generateContextData(
|
||||
? "poetry run generate"
|
||||
: `${packageManager} run generate`,
|
||||
)}`;
|
||||
const modelConfigured = modelConfig.isConfigured();
|
||||
const llamaCloudKeyConfigured = useLlamaParse
|
||||
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
const hasVectorDb = vectorDb && vectorDb !== "none";
|
||||
if (modelConfigured && llamaCloudKeyConfigured && !hasVectorDb) {
|
||||
|
||||
const missingSettings = checkForGenerateScript(
|
||||
modelConfig,
|
||||
vectorDb,
|
||||
llamaCloudKey,
|
||||
useLlamaParse,
|
||||
);
|
||||
|
||||
if (!missingSettings.length) {
|
||||
// If all the required environment variables are set, run the generate script
|
||||
if (framework === "fastapi") {
|
||||
if (isHavingPoetryLockFile()) {
|
||||
@@ -62,15 +91,8 @@ async function generateContextData(
|
||||
}
|
||||
}
|
||||
|
||||
// generate the message of what to do to run the generate script manually
|
||||
const settings = [];
|
||||
if (!modelConfigured) settings.push("your model provider API key");
|
||||
if (!llamaCloudKeyConfigured) settings.push("your Llama Cloud key");
|
||||
if (hasVectorDb) settings.push("your Vector DB environment variables");
|
||||
const settingsMessage =
|
||||
settings.length > 0 ? `After setting ${settings.join(" and ")}, ` : "";
|
||||
const generateMessage = `run ${runGenerate} to generate the context data.`;
|
||||
console.log(`\n${settingsMessage}${generateMessage}\n\n`);
|
||||
const settingsMessage = `After setting ${missingSettings.join(" and ")}, run ${runGenerate} to generate the context data.`;
|
||||
console.log(`\n${settingsMessage}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,15 +163,22 @@ export const installTemplate = async (
|
||||
// This is a backend, so we need to copy the test data and create the env file.
|
||||
|
||||
// Copy the environment file to the target directory.
|
||||
await createBackendEnvFile(props.root, {
|
||||
modelConfig: props.modelConfig,
|
||||
llamaCloudKey: props.llamaCloudKey,
|
||||
vectorDb: props.vectorDb,
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
tools: props.tools,
|
||||
});
|
||||
if (
|
||||
props.template === "streaming" ||
|
||||
props.template === "multiagent" ||
|
||||
props.template === "extractor"
|
||||
) {
|
||||
await createBackendEnvFile(props.root, {
|
||||
modelConfig: props.modelConfig,
|
||||
llamaCloudKey: props.llamaCloudKey,
|
||||
vectorDb: props.vectorDb,
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
tools: props.tools,
|
||||
template: props.template,
|
||||
});
|
||||
}
|
||||
|
||||
if (props.dataSources.length > 0) {
|
||||
console.log("\nGenerating context data...\n");
|
||||
@@ -172,10 +201,10 @@ export const installTemplate = async (
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool-output directory
|
||||
if (props.tools && props.tools.length > 0) {
|
||||
await fsExtra.mkdir(path.join(props.root, "tool-output"));
|
||||
}
|
||||
// Create outputs directory
|
||||
await makeDir(path.join(props.root, "output/tools"));
|
||||
await makeDir(path.join(props.root, "output/uploaded"));
|
||||
await makeDir(path.join(props.root, "output/llamacloud"));
|
||||
} else {
|
||||
// this is a frontend for a full-stack app, create .env file with model information
|
||||
await createFrontendEnvFile(props.root, {
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams, ModelConfigQuestionsParams } from ".";
|
||||
import { questionHandlers } from "../../questions";
|
||||
|
||||
const ALL_AZURE_OPENAI_CHAT_MODELS: Record<string, { openAIModel: string }> = {
|
||||
"gpt-35-turbo": { openAIModel: "gpt-3.5-turbo" },
|
||||
"gpt-35-turbo-16k": {
|
||||
openAIModel: "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"gpt-4o": { openAIModel: "gpt-4o" },
|
||||
"gpt-4": { openAIModel: "gpt-4" },
|
||||
"gpt-4-32k": { openAIModel: "gpt-4-32k" },
|
||||
"gpt-4-turbo": {
|
||||
openAIModel: "gpt-4-turbo",
|
||||
},
|
||||
"gpt-4-turbo-2024-04-09": {
|
||||
openAIModel: "gpt-4-turbo",
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
openAIModel: "gpt-4-vision-preview",
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
openAIModel: "gpt-4-1106-preview",
|
||||
},
|
||||
"gpt-4o-2024-05-13": {
|
||||
openAIModel: "gpt-4o-2024-05-13",
|
||||
},
|
||||
};
|
||||
|
||||
const ALL_AZURE_OPENAI_EMBEDDING_MODELS: Record<
|
||||
string,
|
||||
{
|
||||
dimensions: number;
|
||||
openAIModel: string;
|
||||
}
|
||||
> = {
|
||||
"text-embedding-ada-002": {
|
||||
dimensions: 1536,
|
||||
openAIModel: "text-embedding-ada-002",
|
||||
},
|
||||
"text-embedding-3-small": {
|
||||
dimensions: 1536,
|
||||
openAIModel: "text-embedding-3-small",
|
||||
},
|
||||
"text-embedding-3-large": {
|
||||
dimensions: 3072,
|
||||
openAIModel: "text-embedding-3-large",
|
||||
},
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4o";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askAzureQuestions({
|
||||
openAiKey,
|
||||
askModels,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey: openAiKey || process.env.AZURE_OPENAI_KEY,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
// the Azure model provider can't be fully configured as endpoint and deployment names have to be configured with env variables
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: getAvailableModelChoices(),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: getAvailableEmbeddingModelChoices(),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = getDimensions(embeddingModel);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
function getAvailableModelChoices() {
|
||||
return Object.keys(ALL_AZURE_OPENAI_CHAT_MODELS).map((key) => ({
|
||||
title: key,
|
||||
value: key,
|
||||
}));
|
||||
}
|
||||
|
||||
function getAvailableEmbeddingModelChoices() {
|
||||
return Object.keys(ALL_AZURE_OPENAI_EMBEDDING_MODELS).map((key) => ({
|
||||
title: key,
|
||||
value: key,
|
||||
}));
|
||||
}
|
||||
|
||||
function getDimensions(modelName: string) {
|
||||
return ALL_AZURE_OPENAI_EMBEDDING_MODELS[modelName].dimensions;
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["llama3-8b", "llama3-70b", "mixtral-8x7b"];
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
|
||||
// Use huggingface embedding models for now as Groq doesn't support embedding models
|
||||
enum HuggingFaceEmbeddingModelType {
|
||||
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
|
||||
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
|
||||
}
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
|
||||
dimensions: 384,
|
||||
},
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
|
||||
dimensions: 768,
|
||||
},
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type GroqQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askGroqQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: GroqQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["GROQ_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.GROQ_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions =
|
||||
EMBEDDING_MODELS[
|
||||
embeddingModel as HuggingFaceEmbeddingModelType
|
||||
].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
+33
-10
@@ -1,9 +1,13 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { questionHandlers } from "../../questions";
|
||||
import { ModelConfig, ModelProvider } from "../types";
|
||||
import { ModelConfig, ModelProvider, TemplateFramework } from "../types";
|
||||
import { askAnthropicQuestions } from "./anthropic";
|
||||
import { askAzureQuestions } from "./azure";
|
||||
import { askGeminiQuestions } from "./gemini";
|
||||
import { askGroqQuestions } from "./groq";
|
||||
import { askLLMHubQuestions } from "./llmhub";
|
||||
import { askMistralQuestions } from "./mistral";
|
||||
import { askOllamaQuestions } from "./ollama";
|
||||
import { askOpenAIQuestions } from "./openai";
|
||||
|
||||
@@ -12,6 +16,7 @@ const DEFAULT_MODEL_PROVIDER = "openai";
|
||||
export type ModelConfigQuestionsParams = {
|
||||
openAiKey?: string;
|
||||
askModels: boolean;
|
||||
framework?: TemplateFramework;
|
||||
};
|
||||
|
||||
export type ModelConfigParams = Omit<ModelConfig, "provider">;
|
||||
@@ -19,23 +24,29 @@ export type ModelConfigParams = Omit<ModelConfig, "provider">;
|
||||
export async function askModelConfig({
|
||||
askModels,
|
||||
openAiKey,
|
||||
framework,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfig> {
|
||||
let modelProvider: ModelProvider = DEFAULT_MODEL_PROVIDER;
|
||||
if (askModels && !ciInfo.isCI) {
|
||||
let choices = [
|
||||
{ title: "OpenAI", value: "openai" },
|
||||
{ title: "Groq", value: "groq" },
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
{ title: "Anthropic", value: "anthropic" },
|
||||
{ title: "Gemini", value: "gemini" },
|
||||
{ title: "Mistral", value: "mistral" },
|
||||
{ title: "AzureOpenAI", value: "azure-openai" },
|
||||
];
|
||||
|
||||
if (framework === "fastapi") {
|
||||
choices.push({ title: "T-Systems", value: "t-systems" });
|
||||
}
|
||||
const { provider } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "provider",
|
||||
message: "Which model provider would you like to use",
|
||||
choices: [
|
||||
{
|
||||
title: "OpenAI",
|
||||
value: "openai",
|
||||
},
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
{ title: "Anthropic", value: "anthropic" },
|
||||
{ title: "Gemini", value: "gemini" },
|
||||
],
|
||||
choices: choices,
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
@@ -48,12 +59,24 @@ export async function askModelConfig({
|
||||
case "ollama":
|
||||
modelConfig = await askOllamaQuestions({ askModels });
|
||||
break;
|
||||
case "groq":
|
||||
modelConfig = await askGroqQuestions({ askModels });
|
||||
break;
|
||||
case "anthropic":
|
||||
modelConfig = await askAnthropicQuestions({ askModels });
|
||||
break;
|
||||
case "gemini":
|
||||
modelConfig = await askGeminiQuestions({ askModels });
|
||||
break;
|
||||
case "mistral":
|
||||
modelConfig = await askMistralQuestions({ askModels });
|
||||
break;
|
||||
case "azure-openai":
|
||||
modelConfig = await askAzureQuestions({ askModels });
|
||||
break;
|
||||
case "t-systems":
|
||||
modelConfig = await askLLMHubQuestions({ askModels });
|
||||
break;
|
||||
default:
|
||||
modelConfig = await askOpenAIQuestions({
|
||||
openAiKey,
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
import ciInfo from "ci-info";
|
||||
import got from "got";
|
||||
import ora from "ora";
|
||||
import { red } from "picocolors";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers } from "../../questions";
|
||||
|
||||
export const TSYSTEMS_LLMHUB_API_URL =
|
||||
"https://llm-server.llmhub.t-systems.net/v2";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-3.5-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
const LLMHUB_MODELS = [
|
||||
"gpt-35-turbo",
|
||||
"gpt-4-32k-1",
|
||||
"gpt-4-32k-canada",
|
||||
"gpt-4-32k-france",
|
||||
"gpt-4-turbo-128k-france",
|
||||
"Llama2-70b-Instruct",
|
||||
"Llama-3-70B-Instruct",
|
||||
"Mixtral-8x7B-Instruct-v0.1",
|
||||
"mistral-large-32k-france",
|
||||
"CodeLlama-2",
|
||||
];
|
||||
const LLMHUB_EMBEDDING_MODELS = [
|
||||
"text-embedding-ada-002",
|
||||
"text-embedding-ada-002-france",
|
||||
"jina-embeddings-v2-base-de",
|
||||
"jina-embeddings-v2-base-code",
|
||||
"text-embedding-bge-m3",
|
||||
];
|
||||
|
||||
type LLMHubQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askLLMHubQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: LLMHubQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["T_SYSTEMS_LLMHUB_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message: askModels
|
||||
? "Please provide your LLMHub API key (or leave blank to use T_SYSTEMS_LLMHUB_API_KEY env variable):"
|
||||
: "Please provide your LLMHub API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
if (askModels && !value) {
|
||||
if (process.env.T_SYSTEMS_LLMHUB_API_KEY) {
|
||||
return true;
|
||||
}
|
||||
return "T_SYSTEMS_LLMHUB_API_KEY env variable is not set - key is required";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.T_SYSTEMS_LLMHUB_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: await getAvailableModelChoices(false, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: await getAvailableModelChoices(true, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = getDimensions(embeddingModel);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
async function getAvailableModelChoices(
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
throw new Error("Need LLMHub key to retrieve model choices");
|
||||
}
|
||||
const isLLMModel = (modelId: string) => {
|
||||
return LLMHUB_MODELS.includes(modelId);
|
||||
};
|
||||
|
||||
const isEmbeddingModel = (modelId: string) => {
|
||||
return LLMHUB_EMBEDDING_MODELS.includes(modelId);
|
||||
};
|
||||
|
||||
const spinner = ora("Fetching available models").start();
|
||||
try {
|
||||
const response = await got(`${TSYSTEMS_LLMHUB_API_URL}/models`, {
|
||||
headers: {
|
||||
Authorization: "Bearer " + apiKey,
|
||||
},
|
||||
timeout: 5000,
|
||||
responseType: "json",
|
||||
});
|
||||
const data: any = await response.body;
|
||||
spinner.stop();
|
||||
return data.data
|
||||
.filter((model: any) =>
|
||||
selectEmbedding ? isEmbeddingModel(model.id) : isLLMModel(model.id),
|
||||
)
|
||||
.map((el: any) => {
|
||||
return {
|
||||
title: el.id,
|
||||
value: el.id,
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
spinner.stop();
|
||||
if ((error as any).response?.statusCode === 401) {
|
||||
console.log(
|
||||
red(
|
||||
"Invalid LLMHub API key provided! Please provide a valid key and try again!",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(red("Request failed: " + error));
|
||||
}
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
function getDimensions(modelName: string) {
|
||||
// Assuming dimensions similar to OpenAI for simplicity. Update if different.
|
||||
return modelName === "text-embedding-004" ? 768 : 1536;
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["mistral-tiny", "mistral-small", "mistral-medium"];
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<string, ModelData> = {
|
||||
"mistral-embed": { dimensions: 1024 },
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type MistralQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askMistralQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: MistralQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["MISTRAL_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Mistral API key (or leave blank to use MISTRAL_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.MISTRAL_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-3.5-turbo";
|
||||
const DEFAULT_MODEL = "gpt-4o-mini";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askOpenAIQuestions({
|
||||
|
||||
+71
-18
@@ -55,11 +55,11 @@ const getAdditionalDependencies = (
|
||||
case "milvus": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-milvus",
|
||||
version: "^0.1.6",
|
||||
version: "^0.1.20",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymilvus",
|
||||
version: "2.3.7",
|
||||
version: "2.4.4",
|
||||
});
|
||||
break;
|
||||
}
|
||||
@@ -118,6 +118,12 @@ const getAdditionalDependencies = (
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
case "llamacloud":
|
||||
dependencies.push({
|
||||
name: "llama-index-indices-managed-llama-cloud",
|
||||
version: "^0.2.7",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -144,7 +150,17 @@ const getAdditionalDependencies = (
|
||||
case "openai":
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "0.2.2",
|
||||
version: "0.2.6",
|
||||
});
|
||||
break;
|
||||
case "groq":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-groq",
|
||||
version: "0.1.4",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-fastembed",
|
||||
version: "^0.1.4",
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
@@ -153,20 +169,50 @@ const getAdditionalDependencies = (
|
||||
version: "0.1.10",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-huggingface",
|
||||
version: "0.2.0",
|
||||
name: "llama-index-embeddings-fastembed",
|
||||
version: "^0.1.4",
|
||||
});
|
||||
break;
|
||||
case "gemini":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-gemini",
|
||||
version: "0.1.7",
|
||||
version: "0.1.10",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-gemini",
|
||||
version: "0.1.6",
|
||||
});
|
||||
break;
|
||||
case "mistral":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-mistralai",
|
||||
version: "0.1.17",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-mistralai",
|
||||
version: "0.1.4",
|
||||
});
|
||||
break;
|
||||
case "azure-openai":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-azure-openai",
|
||||
version: "0.1.10",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-azure-openai",
|
||||
version: "0.1.11",
|
||||
});
|
||||
break;
|
||||
case "t-systems":
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "0.2.2",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-openai-like",
|
||||
version: "0.1.3",
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
return dependencies;
|
||||
@@ -304,20 +350,27 @@ export const installPythonTemplate = async ({
|
||||
cwd: path.join(compPath, "loaders", "python"),
|
||||
});
|
||||
|
||||
// Select and copy engine code based on data sources and tools
|
||||
let engine;
|
||||
tools = tools ?? [];
|
||||
if (dataSources.length > 0 && tools.length === 0) {
|
||||
console.log("\nNo tools selected - use optimized context chat engine\n");
|
||||
engine = "chat";
|
||||
} else {
|
||||
engine = "agent";
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
// Copy settings.py to app
|
||||
await copy("**", path.join(root, "app"), {
|
||||
cwd: path.join(compPath, "settings", "python"),
|
||||
});
|
||||
|
||||
if (template === "streaming") {
|
||||
// For the streaming template only:
|
||||
// Select and copy engine code based on data sources and tools
|
||||
let engine;
|
||||
if (dataSources.length > 0 && (!tools || tools.length === 0)) {
|
||||
console.log("\nNo tools selected - use optimized context chat engine\n");
|
||||
engine = "chat";
|
||||
} else {
|
||||
engine = "agent";
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
});
|
||||
}
|
||||
|
||||
console.log("Adding additional dependencies");
|
||||
|
||||
const addOnDependencies = getAdditionalDependencies(
|
||||
|
||||
+115
-10
@@ -30,7 +30,7 @@ export type ToolDependencies = {
|
||||
|
||||
export const supportedTools: Tool[] = [
|
||||
{
|
||||
display: "Google Search (configuration required after installation)",
|
||||
display: "Google Search",
|
||||
name: "google.GoogleSearchToolSpec",
|
||||
config: {
|
||||
engine:
|
||||
@@ -54,6 +54,29 @@ export const supportedTools: Tool[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
// For python app, we will use a local DuckDuckGo search tool (instead of DuckDuckGo search tool in LlamaHub)
|
||||
// to get the same results as the TS app.
|
||||
display: "DuckDuckGo Search",
|
||||
name: "duckduckgo",
|
||||
dependencies: [
|
||||
{
|
||||
name: "duckduckgo-search",
|
||||
version: "6.1.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "nextjs", "express"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for DuckDuckGo search tool.",
|
||||
value: `You are a DuckDuckGo search agent.
|
||||
You can use the duckduckgo search tool to get information from the web to answer user questions.
|
||||
For better results, you can specify the region parameter to get results from a specific region but it's optional.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Wikipedia",
|
||||
name: "wikipedia.WikipediaToolSpec",
|
||||
@@ -107,13 +130,90 @@ export const supportedTools: Tool[] = [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for code interpreter tool.",
|
||||
value: `You are a Python interpreter.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- The python code runs in a Jupyter notebook. Every time you call \`interpreter\` tool, the python code is executed in a separate cell. It's okay to make multiple calls to \`interpreter\`.
|
||||
- Display visualizations using matplotlib or any other visualization library directly in the notebook. Shouldn't save the visualizations to a file, just return the base64 encoded data.
|
||||
- You can install any pip package (if it exists) if you need to but the usual packages for data analysis are already preinstalled.
|
||||
- You can run any python code you want in a secure environment.
|
||||
- Use absolute url from result to display images or any other media.`,
|
||||
value: `-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "OpenAPI action",
|
||||
name: "openapi_action.OpenAPIActionToolSpec",
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-openapi",
|
||||
version: "0.1.3",
|
||||
},
|
||||
{
|
||||
name: "jsonschema",
|
||||
version: "^4.22.0",
|
||||
},
|
||||
{
|
||||
name: "llama-index-tools-requests",
|
||||
version: "0.1.3",
|
||||
},
|
||||
],
|
||||
config: {
|
||||
openapi_uri: "The URL or file path of the OpenAPI schema",
|
||||
},
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for openapi action tool.",
|
||||
value:
|
||||
"You are an OpenAPI action agent. You help users to make requests to the provided OpenAPI schema.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Image Generator",
|
||||
name: "img_gen",
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "STABILITY_API_KEY",
|
||||
description:
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for image generator tool.",
|
||||
value: `You are an image generator agent. You help users to generate images using the Stability API.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Azure Code Interpreter",
|
||||
name: "azure_code_interpreter.AzureCodeInterpreterToolSpec",
|
||||
supportedFrameworks: ["fastapi", "nextjs", "express"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-azure-code-interpreter",
|
||||
version: "0.2.0",
|
||||
},
|
||||
],
|
||||
envVars: [
|
||||
{
|
||||
name: "AZURE_POOL_MANAGEMENT_ENDPOINT",
|
||||
description:
|
||||
"Please follow this guideline to create and get the pool management endpoint: https://learn.microsoft.com/azure/container-apps/sessions?tabs=azure-cli",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for Azure code interpreter tool.",
|
||||
value: `-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -142,9 +242,15 @@ export const getTools = (toolsName: string[]): Tool[] => {
|
||||
return tools;
|
||||
};
|
||||
|
||||
export const toolRequiresConfig = (tool: Tool): boolean => {
|
||||
const hasConfig = Object.keys(tool.config || {}).length > 0;
|
||||
const hasEmptyEnvVar = tool.envVars?.some((envVar) => !envVar.value) ?? false;
|
||||
return hasConfig || hasEmptyEnvVar;
|
||||
};
|
||||
|
||||
export const toolsRequireConfig = (tools?: Tool[]): boolean => {
|
||||
if (tools) {
|
||||
return tools?.some((tool) => Object.keys(tool.config || {}).length > 0);
|
||||
return tools?.some(toolRequiresConfig);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -159,7 +265,6 @@ export const writeToolsConfig = async (
|
||||
tools: Tool[] = [],
|
||||
type: ConfigFileType = ConfigFileType.YAML,
|
||||
) => {
|
||||
if (tools.length === 0) return; // no tools selected, no config need
|
||||
const configContent: {
|
||||
[key in ToolType]: Record<string, any>;
|
||||
} = {
|
||||
|
||||
+18
-4
@@ -1,7 +1,15 @@
|
||||
import { PackageManager } from "../helpers/get-pkg-manager";
|
||||
import { Tool } from "./tools";
|
||||
|
||||
export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
|
||||
export type ModelProvider =
|
||||
| "openai"
|
||||
| "groq"
|
||||
| "ollama"
|
||||
| "anthropic"
|
||||
| "gemini"
|
||||
| "mistral"
|
||||
| "azure-openai"
|
||||
| "t-systems";
|
||||
export type ModelConfig = {
|
||||
provider: ModelProvider;
|
||||
apiKey?: string;
|
||||
@@ -10,7 +18,12 @@ export type ModelConfig = {
|
||||
dimensions: number;
|
||||
isConfigured(): boolean;
|
||||
};
|
||||
export type TemplateType = "streaming" | "community" | "llamapack";
|
||||
export type TemplateType =
|
||||
| "extractor"
|
||||
| "streaming"
|
||||
| "community"
|
||||
| "llamapack"
|
||||
| "multiagent";
|
||||
export type TemplateFramework = "nextjs" | "express" | "fastapi";
|
||||
export type TemplateUI = "html" | "shadcn";
|
||||
export type TemplateVectorDB =
|
||||
@@ -21,7 +34,8 @@ export type TemplateVectorDB =
|
||||
| "milvus"
|
||||
| "astra"
|
||||
| "qdrant"
|
||||
| "chroma";
|
||||
| "chroma"
|
||||
| "llamacloud";
|
||||
export type TemplatePostInstallAction =
|
||||
| "none"
|
||||
| "VSCode"
|
||||
@@ -31,7 +45,7 @@ export type TemplateDataSource = {
|
||||
type: TemplateDataSourceType;
|
||||
config: TemplateDataSourceConfig;
|
||||
};
|
||||
export type TemplateDataSourceType = "file" | "web" | "db";
|
||||
export type TemplateDataSourceType = "file" | "web" | "db" | "llamacloud";
|
||||
export type TemplateObservability = "none" | "opentelemetry";
|
||||
// Config for both file and folder
|
||||
export type FileSourceConfig = {
|
||||
|
||||
+14
-2
@@ -1,7 +1,7 @@
|
||||
import fs from "fs/promises";
|
||||
import os from "os";
|
||||
import path from "path";
|
||||
import { bold, cyan } from "picocolors";
|
||||
import { bold, cyan, yellow } from "picocolors";
|
||||
import { assetRelocator, copy } from "../helpers/copy";
|
||||
import { callPackageManager } from "../helpers/install";
|
||||
import { templatesDir } from "./dir";
|
||||
@@ -104,8 +104,20 @@ export const installTSTemplate = async ({
|
||||
: path.join("src", "controllers");
|
||||
const enginePath = path.join(root, relativeEngineDestPath, "engine");
|
||||
|
||||
// copy llamaindex code for TS templates
|
||||
await copy("**", path.join(root, relativeEngineDestPath, "llamaindex"), {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "llamaindex", "typescript"),
|
||||
});
|
||||
|
||||
// copy vector db component
|
||||
console.log("\nUsing vector DB:", vectorDb ?? "none", "\n");
|
||||
if (vectorDb === "llamacloud") {
|
||||
console.log(
|
||||
`\nUsing managed index from LlamaCloud. Ensure the ${yellow("LLAMA_CLOUD_* environment variables are set correctly.")}`,
|
||||
);
|
||||
} else {
|
||||
console.log("\nUsing vector DB:", vectorDb ?? "none");
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
|
||||
|
||||
@@ -9,7 +9,7 @@ import prompts from "prompts";
|
||||
import terminalLink from "terminal-link";
|
||||
import checkForUpdate from "update-check";
|
||||
import { createApp } from "./create-app";
|
||||
import { getDataSources } from "./helpers/datasources";
|
||||
import { EXAMPLE_FILE, getDataSources } from "./helpers/datasources";
|
||||
import { getPkgManager } from "./helpers/get-pkg-manager";
|
||||
import { isFolderEmpty } from "./helpers/is-folder-empty";
|
||||
import { initializeGlobalAgent } from "./helpers/proxy";
|
||||
@@ -194,8 +194,16 @@ if (process.argv.includes("--no-llama-parse")) {
|
||||
program.askModels = process.argv.includes("--ask-models");
|
||||
if (process.argv.includes("--no-files")) {
|
||||
program.dataSources = [];
|
||||
} else {
|
||||
} else if (process.argv.includes("--example-file")) {
|
||||
program.dataSources = getDataSources(program.files, program.exampleFile);
|
||||
} else if (process.argv.includes("--llamacloud")) {
|
||||
program.dataSources = [
|
||||
{
|
||||
type: "llamacloud",
|
||||
config: {},
|
||||
},
|
||||
EXAMPLE_FILE,
|
||||
];
|
||||
}
|
||||
|
||||
const packageManager = !!program.useNpm
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "create-llama",
|
||||
"version": "0.1.8",
|
||||
"version": "0.1.24",
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"keywords": [
|
||||
"rag",
|
||||
|
||||
+151
-74
@@ -9,6 +9,7 @@ import {
|
||||
TemplateDataSource,
|
||||
TemplateDataSourceType,
|
||||
TemplateFramework,
|
||||
TemplateType,
|
||||
} from "./helpers";
|
||||
import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
|
||||
import { EXAMPLE_FILE } from "./helpers/datasources";
|
||||
@@ -16,7 +17,11 @@ import { templatesDir } from "./helpers/dir";
|
||||
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
|
||||
import { askModelConfig } from "./helpers/providers";
|
||||
import { getProjectOptions } from "./helpers/repo";
|
||||
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
|
||||
import {
|
||||
supportedTools,
|
||||
toolRequiresConfig,
|
||||
toolsRequireConfig,
|
||||
} from "./helpers/tools";
|
||||
|
||||
export type QuestionArgs = Omit<
|
||||
InstallAppArgs,
|
||||
@@ -118,8 +123,15 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
|
||||
export const getDataSourceChoices = (
|
||||
framework: TemplateFramework,
|
||||
selectedDataSource: TemplateDataSource[],
|
||||
template?: TemplateType,
|
||||
) => {
|
||||
// If LlamaCloud is already selected, don't show any other options
|
||||
if (selectedDataSource.find((s) => s.type === "llamacloud")) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const choices = [];
|
||||
|
||||
if (selectedDataSource.length > 0) {
|
||||
choices.push({
|
||||
title: "No",
|
||||
@@ -127,29 +139,37 @@ export const getDataSourceChoices = (
|
||||
});
|
||||
}
|
||||
if (selectedDataSource === undefined || selectedDataSource.length === 0) {
|
||||
if (template !== "multiagent") {
|
||||
choices.push({
|
||||
title: "No datasource",
|
||||
value: "none",
|
||||
});
|
||||
}
|
||||
choices.push({
|
||||
title: "No data, just a simple chat or agent",
|
||||
value: "none",
|
||||
});
|
||||
choices.push({
|
||||
title: "Use an example PDF",
|
||||
title:
|
||||
process.platform !== "linux"
|
||||
? "Use an example PDF"
|
||||
: "Use an example PDF (you can add your own data files later)",
|
||||
value: "exampleFile",
|
||||
});
|
||||
}
|
||||
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
// Linux has many distros so we won't support file/folder picker for now
|
||||
if (process.platform !== "linux") {
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
if (framework === "fastapi") {
|
||||
choices.push({
|
||||
@@ -161,6 +181,13 @@ export const getDataSourceChoices = (
|
||||
value: "db",
|
||||
});
|
||||
}
|
||||
|
||||
if (!selectedDataSource.length) {
|
||||
choices.push({
|
||||
title: "Use managed index from LlamaCloud",
|
||||
value: "llamacloud",
|
||||
});
|
||||
}
|
||||
return choices;
|
||||
};
|
||||
|
||||
@@ -258,25 +285,27 @@ export const askQuestions = async (
|
||||
},
|
||||
];
|
||||
|
||||
const modelConfigured =
|
||||
!program.llamapack && program.modelConfig.isConfigured();
|
||||
// If using LlamaParse, require LlamaCloud API key
|
||||
const llamaCloudKeyConfigured = program.useLlamaParse
|
||||
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
const hasVectorDb = program.vectorDb && program.vectorDb !== "none";
|
||||
// Can run the app if all tools do not require configuration
|
||||
if (
|
||||
!hasVectorDb &&
|
||||
modelConfigured &&
|
||||
llamaCloudKeyConfigured &&
|
||||
!toolsRequireConfig(program.tools)
|
||||
) {
|
||||
actionChoices.push({
|
||||
title:
|
||||
"Generate code, install dependencies, and run the app (~2 min)",
|
||||
value: "runApp",
|
||||
});
|
||||
if (program.template !== "multiagent") {
|
||||
const modelConfigured =
|
||||
!program.llamapack && program.modelConfig.isConfigured();
|
||||
// If using LlamaParse, require LlamaCloud API key
|
||||
const llamaCloudKeyConfigured = program.useLlamaParse
|
||||
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
const hasVectorDb = program.vectorDb && program.vectorDb !== "none";
|
||||
// Can run the app if all tools do not require configuration
|
||||
if (
|
||||
!hasVectorDb &&
|
||||
modelConfigured &&
|
||||
llamaCloudKeyConfigured &&
|
||||
!toolsRequireConfig(program.tools)
|
||||
) {
|
||||
actionChoices.push({
|
||||
title:
|
||||
"Generate code, install dependencies, and run the app (~2 min)",
|
||||
value: "runApp",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const { action } = await prompts(
|
||||
@@ -308,7 +337,12 @@ export const askQuestions = async (
|
||||
name: "template",
|
||||
message: "Which template would you like to use?",
|
||||
choices: [
|
||||
{ title: "Chat", value: "streaming" },
|
||||
{ title: "Agentic RAG (single agent)", value: "streaming" },
|
||||
{
|
||||
title: "Multi-agent app (using llama-agents)",
|
||||
value: "multiagent",
|
||||
},
|
||||
{ title: "Structured Extractor", value: "extractor" },
|
||||
{
|
||||
title: `Community template from ${styledRepo}`,
|
||||
value: "community",
|
||||
@@ -372,6 +406,10 @@ export const askQuestions = async (
|
||||
return; // early return - no further questions needed for llamapack projects
|
||||
}
|
||||
|
||||
if (program.template === "multiagent" || program.template === "extractor") {
|
||||
// TODO: multi-agents currently only supports FastAPI
|
||||
program.framework = preferences.framework = "fastapi";
|
||||
}
|
||||
if (!program.framework) {
|
||||
if (ciInfo.isCI) {
|
||||
program.framework = getPrefOrDefault("framework");
|
||||
@@ -397,7 +435,10 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (program.framework === "express" || program.framework === "fastapi") {
|
||||
if (
|
||||
(program.framework === "express" || program.framework === "fastapi") &&
|
||||
program.template === "streaming"
|
||||
) {
|
||||
// if a backend-only framework is selected, ask whether we should create a frontend
|
||||
if (program.frontend === undefined) {
|
||||
if (ciInfo.isCI) {
|
||||
@@ -434,7 +475,7 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.observability) {
|
||||
if (!program.observability && program.template === "streaming") {
|
||||
if (ciInfo.isCI) {
|
||||
program.observability = getPrefOrDefault("observability");
|
||||
} else {
|
||||
@@ -461,6 +502,7 @@ export const askQuestions = async (
|
||||
const modelConfig = await askModelConfig({
|
||||
openAiKey,
|
||||
askModels: program.askModels ?? false,
|
||||
framework: program.framework,
|
||||
});
|
||||
program.modelConfig = modelConfig;
|
||||
preferences.modelConfig = modelConfig;
|
||||
@@ -474,6 +516,12 @@ export const askQuestions = async (
|
||||
// continue asking user for data sources if none are initially provided
|
||||
while (true) {
|
||||
const firstQuestion = program.dataSources.length === 0;
|
||||
const choices = getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
program.template,
|
||||
);
|
||||
if (choices.length === 0) break;
|
||||
const { selectedSource } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
@@ -481,10 +529,7 @@ export const askQuestions = async (
|
||||
message: firstQuestion
|
||||
? "Which data source would you like to use?"
|
||||
: "Would you like to add another data source?",
|
||||
choices: getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
),
|
||||
choices,
|
||||
initial: firstQuestion ? 1 : 0,
|
||||
},
|
||||
questionHandlers,
|
||||
@@ -581,51 +626,82 @@ export const askQuestions = async (
|
||||
config: await prompts(dbPrompts, questionHandlers),
|
||||
});
|
||||
}
|
||||
case "llamacloud": {
|
||||
program.dataSources.push({
|
||||
type: "llamacloud",
|
||||
config: {},
|
||||
});
|
||||
program.dataSources.push(EXAMPLE_FILE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Asking for LlamaParse if user selected file or folder data source
|
||||
if (
|
||||
program.dataSources.some((ds) => ds.type === "file") &&
|
||||
program.useLlamaParse === undefined
|
||||
) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
const isUsingLlamaCloud = program.dataSources.some(
|
||||
(ds) => ds.type === "llamacloud",
|
||||
);
|
||||
|
||||
// Ask for LlamaCloud API key
|
||||
if (useLlamaParse && program.llamaCloudKey === undefined) {
|
||||
// Asking for LlamaParse if user selected file data source
|
||||
if (isUsingLlamaCloud) {
|
||||
// default to use LlamaParse if using LlamaCloud
|
||||
program.useLlamaParse = preferences.useLlamaParse = true;
|
||||
} else {
|
||||
if (program.useLlamaParse === undefined) {
|
||||
// if already set useLlamaParse, don't ask again
|
||||
if (program.dataSources.some((ds) => ds.type === "file")) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
preferences.useLlamaParse = useLlamaParse;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ask for LlamaCloud API key when using a LlamaCloud index or LlamaParse
|
||||
if (isUsingLlamaCloud || program.useLlamaParse) {
|
||||
if (!program.llamaCloudKey) {
|
||||
// if already set, don't ask again
|
||||
if (ciInfo.isCI) {
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
// Ask for LlamaCloud API key
|
||||
const { llamaCloudKey } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "llamaCloudKey",
|
||||
message:
|
||||
"Please provide your LlamaIndex Cloud API key (leave blank to skip):",
|
||||
"Please provide your LlamaCloud API key (leave blank to skip):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamaCloudKey = llamaCloudKey;
|
||||
program.llamaCloudKey = preferences.llamaCloudKey =
|
||||
llamaCloudKey || process.env.LLAMA_CLOUD_API_KEY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
if (isUsingLlamaCloud) {
|
||||
// When using a LlamaCloud index, don't ask for vector database and use code in `llamacloud` folder for vector database
|
||||
const vectorDb = "llamacloud";
|
||||
program.vectorDb = vectorDb;
|
||||
preferences.vectorDb = vectorDb;
|
||||
} else if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
if (ciInfo.isCI) {
|
||||
program.vectorDb = getPrefOrDefault("vectorDb");
|
||||
} else {
|
||||
@@ -644,7 +720,8 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.tools) {
|
||||
if (!program.tools && program.template === "streaming") {
|
||||
// TODO: allow to select tools also for multi-agent framework
|
||||
if (ciInfo.isCI) {
|
||||
program.tools = getPrefOrDefault("tools");
|
||||
} else {
|
||||
@@ -652,7 +729,7 @@ export const askQuestions = async (
|
||||
t.supportedFrameworks?.includes(program.framework),
|
||||
);
|
||||
const toolChoices = options.map((tool) => ({
|
||||
title: tool.display,
|
||||
title: `${tool.display}${toolRequiresConfig(tool) ? " (needs configuration)" : ""}`,
|
||||
value: tool.name,
|
||||
}));
|
||||
const { toolsName } = await prompts({
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.engine.tools import ToolFactory
|
||||
from app.engine.index import get_index
|
||||
|
||||
|
||||
def get_chat_engine():
|
||||
def get_chat_engine(filters=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = os.getenv("TOP_K", "3")
|
||||
tools = []
|
||||
@@ -14,7 +14,9 @@ def get_chat_engine():
|
||||
# Add query tool if index exists
|
||||
index = get_index()
|
||||
if index is not None:
|
||||
query_engine = index.as_query_engine(similarity_top_k=int(top_k))
|
||||
query_engine = index.as_query_engine(
|
||||
similarity_top_k=int(top_k), filters=filters
|
||||
)
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
|
||||
tools.append(query_engine_tool)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import importlib
|
||||
|
||||
from cachetools import cached, LRUCache
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
@@ -18,7 +19,6 @@ class ToolFactory:
|
||||
ToolType.LOCAL: "app.engine.tools",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
|
||||
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
|
||||
try:
|
||||
@@ -31,7 +31,7 @@ class ToolFactory:
|
||||
return tool_spec.to_tool_list()
|
||||
else:
|
||||
module = importlib.import_module(f"{source_package}.{tool_name}")
|
||||
tools = getattr(module, "tools")
|
||||
tools = module.get_tools(**config)
|
||||
if not all(isinstance(tool, FunctionTool) for tool in tools):
|
||||
raise ValueError(
|
||||
f"The module {module} does not contain valid tools"
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
region: str = "wt-wt",
|
||||
max_results: int = 10,
|
||||
):
|
||||
"""
|
||||
Use this function to search for any query in DuckDuckGo.
|
||||
Args:
|
||||
query (str): The query to search in DuckDuckGo.
|
||||
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
|
||||
max_results Optional(int): The maximum number of results to be returned. Default is 10.
|
||||
"""
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"duckduckgo_search package is required to use this function."
|
||||
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
|
||||
)
|
||||
|
||||
params = {
|
||||
"keywords": query,
|
||||
"region": region,
|
||||
"max_results": max_results,
|
||||
}
|
||||
results = []
|
||||
with DDGS() as ddg:
|
||||
results = list(ddg.text(**params))
|
||||
return results
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(duckduckgo_search)]
|
||||
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorToolOutput(BaseModel):
|
||||
is_success: bool = Field(
|
||||
...,
|
||||
description="Whether the image generation was successful.",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="The URL of the generated image.",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
None,
|
||||
description="The error message if the image generation failed.",
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneratorTool:
|
||||
_IMG_OUTPUT_FORMAT = "webp"
|
||||
_IMG_OUTPUT_DIR = "output/tool"
|
||||
_IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if not api_key:
|
||||
api_key = os.getenv("STABILITY_API_KEY")
|
||||
self._api_key = api_key
|
||||
self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if self._api_key is None:
|
||||
raise ValueError(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
|
||||
)
|
||||
if self.fileserver_url_prefix is None:
|
||||
raise ValueError("FILESERVER_URL_PREFIX is required.")
|
||||
|
||||
def _prepare_output_dir(self):
|
||||
"""
|
||||
Create the output directory if it doesn't exist
|
||||
"""
|
||||
if not os.path.exists(self._IMG_OUTPUT_DIR):
|
||||
os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
def _save_image(self, image_data: bytes):
|
||||
self._prepare_output_dir()
|
||||
filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
|
||||
output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{self._IMG_OUTPUT_DIR}/{filename}"
|
||||
logger.info(f"Saved image to {output_path}.\nURL: {url}")
|
||||
return url
|
||||
|
||||
def _call_stability_api(self, prompt: str):
|
||||
headers = {
|
||||
"authorization": f"Bearer {self._api_key}",
|
||||
"accept": "image/*",
|
||||
}
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"output_format": self._IMG_OUTPUT_FORMAT,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self._IMG_GEN_API,
|
||||
headers=headers,
|
||||
files={"none": ""},
|
||||
data=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
|
||||
"""
|
||||
Use this tool to generate an image based on the prompt.
|
||||
Args:
|
||||
prompt (str): The prompt to generate the image from.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Call the Stability API
|
||||
response = self._call_stability_api(prompt)
|
||||
|
||||
# Save the image and get the URL
|
||||
image_url = self._save_image(response.content)
|
||||
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=True,
|
||||
image_url=image_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e, exc_info=True)
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(ImageGeneratorTool(**kwargs).generate_image)]
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Tuple, Dict
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
@@ -14,8 +14,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
|
||||
type: str
|
||||
filename: str
|
||||
url: str
|
||||
content: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
|
||||
@@ -26,11 +27,26 @@ class E2BToolOutput(BaseModel):
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
|
||||
output_dir = "tool-output"
|
||||
output_dir = "output/tool"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if api_key is None:
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str, filesever_url_prefix: str):
|
||||
self.api_key = api_key
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
self.interpreter = CodeInterpreter(api_key=api_key)
|
||||
|
||||
def __del__(self):
|
||||
self.interpreter.close()
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
@@ -72,63 +88,56 @@ class E2BCodeInterpreter:
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
base64_data_arr = [result[format] for format in formats]
|
||||
results = [result[format] for format in formats]
|
||||
|
||||
for ext, base64_data in zip(formats, base64_data_arr):
|
||||
if ext and base64_data:
|
||||
result = self.save_to_disk(base64_data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext, filename=filename, url=self.get_file_url(filename)
|
||||
for ext, data in zip(formats, results):
|
||||
match ext:
|
||||
case "png" | "svg" | "jpeg" | "pdf":
|
||||
result = self.save_to_disk(data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
filename=filename,
|
||||
url=self.get_file_url(filename),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
content=data,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error("Error when saving data to disk", error)
|
||||
logger.exception(error, exc_info=True)
|
||||
logger.error("Error when parsing output from E2b interpreter tool", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
with CodeInterpreter(api_key=self.api_key) as interpreter:
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = interpreter.notebook.exec_cell(code)
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
|
||||
|
||||
if exec.error:
|
||||
output = E2BToolOutput(is_error=True, logs=[exec.error])
|
||||
Parameters:
|
||||
code (str): The python code to be executed in a single cell.
|
||||
"""
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = self.interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
logger.error("Error when executing code", exec.error)
|
||||
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(
|
||||
is_error=False, logs=exec.logs, results=results
|
||||
)
|
||||
return output
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
|
||||
return output
|
||||
|
||||
|
||||
def code_interpret(code: str) -> Dict:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.
|
||||
"""
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
interpreter = E2BCodeInterpreter(
|
||||
api_key=api_key, filesever_url_prefix=filesever_url_prefix
|
||||
)
|
||||
output = interpreter.interpret(code)
|
||||
return output.dict()
|
||||
|
||||
|
||||
# Specify as functions tools to be loaded by the ToolFactory
|
||||
tools = [FunctionTool.from_defaults(code_interpret)]
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)]
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||
from llama_index.tools.requests import RequestsToolSpec
|
||||
|
||||
|
||||
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
|
||||
"""
|
||||
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
|
||||
|
||||
openapi_uri: str: The file path or URL to the OpenAPI spec.
|
||||
domain_headers: dict: Whitelist domains and the headers to use.
|
||||
"""
|
||||
|
||||
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
|
||||
# Cached parsed specs by URI
|
||||
_specs: Dict[str, Tuple[Dict, List[str]]] = {}
|
||||
|
||||
def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
|
||||
if domain_headers is None:
|
||||
domain_headers = {}
|
||||
if openapi_uri not in self._specs:
|
||||
openapi_spec, servers = self._load_openapi_spec(openapi_uri)
|
||||
self._specs[openapi_uri] = (openapi_spec, servers)
|
||||
else:
|
||||
openapi_spec, servers = self._specs[openapi_uri]
|
||||
|
||||
# Add the servers to the domain headers if they are not already present
|
||||
for server in servers:
|
||||
if server not in domain_headers:
|
||||
domain_headers[server] = {}
|
||||
|
||||
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
|
||||
RequestsToolSpec.__init__(self, domain_headers)
|
||||
|
||||
@staticmethod
|
||||
def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
|
||||
"""
|
||||
Load an OpenAPI spec from a URI.
|
||||
|
||||
Args:
|
||||
uri (str): A file path or URL to the OpenAPI spec.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
import yaml
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if uri.startswith("http"):
|
||||
import requests
|
||||
|
||||
response = requests.get(uri)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: "
|
||||
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
|
||||
)
|
||||
spec = yaml.safe_load(response.text)
|
||||
elif uri.startswith("file"):
|
||||
filepath = urlparse(uri).path
|
||||
with open(filepath, "r") as file:
|
||||
spec = yaml.safe_load(file)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
|
||||
"Only HTTP and file path are supported."
|
||||
)
|
||||
# Add the servers to the whitelist
|
||||
try:
|
||||
servers = [
|
||||
urlparse(server["url"]).netloc for server in spec.get("servers", [])
|
||||
]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
|
||||
"Could not get `servers` from the spec."
|
||||
) from e
|
||||
return spec, servers
|
||||
@@ -69,4 +69,5 @@ class OpenMeteoWeather:
|
||||
return response.json()
|
||||
|
||||
|
||||
tools = [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
|
||||
@@ -3,7 +3,7 @@ from app.engine.index import get_index
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def get_chat_engine():
|
||||
def get_chat_engine(filters=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = os.getenv("TOP_K", 3)
|
||||
|
||||
@@ -20,4 +20,5 @@ def get_chat_engine():
|
||||
similarity_top_k=int(top_k),
|
||||
system_prompt=system_prompt,
|
||||
chat_mode="condense_plus_context",
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import { BaseToolWithCall, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import {
|
||||
BaseToolWithCall,
|
||||
MetadataFilter,
|
||||
MetadataFilters,
|
||||
OpenAIAgent,
|
||||
QueryEngineTool,
|
||||
} from "llamaindex";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { getDataSource } from "./index";
|
||||
import { STORAGE_CACHE_DIR } from "./shared";
|
||||
import { createTools } from "./tools";
|
||||
|
||||
export async function createChatEngine() {
|
||||
export async function createChatEngine(documentIds?: string[]) {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
// Add a query engine tool if we have a data source
|
||||
@@ -14,10 +19,12 @@ export async function createChatEngine() {
|
||||
if (index) {
|
||||
tools.push(
|
||||
new QueryEngineTool({
|
||||
queryEngine: index.asQueryEngine(),
|
||||
queryEngine: index.asQueryEngine({
|
||||
preFilters: generateFilters(documentIds || []),
|
||||
}),
|
||||
metadata: {
|
||||
name: "data_query_engine",
|
||||
description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
|
||||
description: `A query engine for documents from your data source.`,
|
||||
},
|
||||
}),
|
||||
);
|
||||
@@ -40,3 +47,27 @@ export async function createChatEngine() {
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
function generateFilters(documentIds: string[]): MetadataFilters | undefined {
|
||||
// public documents don't have the "private" field or it's set to "false"
|
||||
const publicDocumentsFilter: MetadataFilter = {
|
||||
key: "private",
|
||||
value: ["true"],
|
||||
operator: "nin",
|
||||
};
|
||||
|
||||
// if no documentIds are provided, only retrieve information from public documents
|
||||
if (!documentIds.length) return { filters: [publicDocumentsFilter] };
|
||||
|
||||
const privateDocumentsFilter: MetadataFilter = {
|
||||
key: "doc_id",
|
||||
value: documentIds,
|
||||
operator: "in",
|
||||
};
|
||||
|
||||
// if documentIds are provided, retrieve information from public and private documents
|
||||
return {
|
||||
filters: [publicDocumentsFilter, privateDocumentsFilter],
|
||||
condition: "or",
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import { search } from "duck-duck-scrape";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
export type DuckDuckGoParameter = {
|
||||
query: string;
|
||||
region?: string;
|
||||
};
|
||||
|
||||
export type DuckDuckGoToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = {
|
||||
name: "duckduckgo",
|
||||
description: "Use this function to search for any query in DuckDuckGo.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
query: {
|
||||
type: "string",
|
||||
description: "The query to search in DuckDuckGo.",
|
||||
},
|
||||
region: {
|
||||
type: "string",
|
||||
description:
|
||||
"Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...",
|
||||
nullable: true,
|
||||
},
|
||||
},
|
||||
required: ["query"],
|
||||
},
|
||||
};
|
||||
|
||||
type DuckDuckGoSearchResult = {
|
||||
title: string;
|
||||
description: string;
|
||||
url: string;
|
||||
};
|
||||
|
||||
export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
|
||||
constructor(params: DuckDuckGoToolParams) {
|
||||
this.metadata = params.metadata ?? DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: DuckDuckGoParameter) {
|
||||
const { query, region } = input;
|
||||
const options = region ? { region } : {};
|
||||
const searchResults = await search(query, options);
|
||||
|
||||
return searchResults.results.map((result) => {
|
||||
return {
|
||||
title: result.title,
|
||||
description: result.description,
|
||||
url: result.url,
|
||||
} as DuckDuckGoSearchResult;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { FormData } from "formdata-node";
|
||||
import fs from "fs";
|
||||
import got from "got";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import path from "node:path";
|
||||
import { Readable } from "stream";
|
||||
|
||||
export type ImgGeneratorParameter = {
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolOutput = {
|
||||
isSuccess: boolean;
|
||||
imageUrl?: string;
|
||||
errorMessage?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = {
|
||||
name: "image_generator",
|
||||
description: `Use this function to generate an image based on the prompt.`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
prompt: {
|
||||
type: "string",
|
||||
description: "The prompt to generate the image",
|
||||
},
|
||||
},
|
||||
required: ["prompt"],
|
||||
},
|
||||
};
|
||||
|
||||
export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> {
|
||||
readonly IMG_OUTPUT_FORMAT = "webp";
|
||||
readonly IMG_OUTPUT_DIR = "output/tool";
|
||||
readonly IMG_GEN_API =
|
||||
"https://api.stability.ai/v2beta/stable-image/generate/core";
|
||||
|
||||
metadata: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
|
||||
constructor(params?: ImgGeneratorToolParams) {
|
||||
this.checkRequiredEnvVars();
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: ImgGeneratorParameter): Promise<ImgGeneratorToolOutput> {
|
||||
return await this.generateImage(input.prompt);
|
||||
}
|
||||
|
||||
private generateImage = async (
|
||||
prompt: string,
|
||||
): Promise<ImgGeneratorToolOutput> => {
|
||||
try {
|
||||
const buffer = await this.promptToImgBuffer(prompt);
|
||||
const imageUrl = this.saveImage(buffer);
|
||||
return { isSuccess: true, imageUrl };
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return {
|
||||
isSuccess: false,
|
||||
errorMessage: "Failed to generate image. Please try again.",
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
private promptToImgBuffer = async (prompt: string) => {
|
||||
const form = new FormData();
|
||||
form.append("prompt", prompt);
|
||||
form.append("output_format", this.IMG_OUTPUT_FORMAT);
|
||||
const buffer = await got
|
||||
.post(this.IMG_GEN_API, {
|
||||
// Not sure why it shows an type error when passing form to body
|
||||
// Although I follow document: https://github.com/sindresorhus/got/blob/main/documentation/2-options.md#body
|
||||
// Tt still works fine, so I make casting to unknown to avoid the typescript warning
|
||||
// Found a similar issue: https://github.com/sindresorhus/got/discussions/1877
|
||||
body: form as unknown as Buffer | Readable | string,
|
||||
headers: {
|
||||
Authorization: `Bearer ${process.env.STABILITY_API_KEY}`,
|
||||
Accept: "image/*",
|
||||
},
|
||||
})
|
||||
.buffer();
|
||||
return buffer;
|
||||
};
|
||||
|
||||
private saveImage = (buffer: Buffer) => {
|
||||
const filename = `${crypto.randomUUID()}.${this.IMG_OUTPUT_FORMAT}`;
|
||||
const outputPath = path.join(this.IMG_OUTPUT_DIR, filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
const url = `${process.env.FILESERVER_URL_PREFIX}/${this.IMG_OUTPUT_DIR}/${filename}`;
|
||||
console.log(`Saved image to ${outputPath}.\nURL: ${url}`);
|
||||
return url;
|
||||
};
|
||||
|
||||
private checkRequiredEnvVars = () => {
|
||||
if (!process.env.STABILITY_API_KEY) {
|
||||
throw new Error(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
);
|
||||
}
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output after generation",
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -1,42 +1,61 @@
|
||||
import { BaseToolWithCall } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
|
||||
import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen";
|
||||
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
|
||||
import { OpenAPIActionTool } from "./openapi-action";
|
||||
import { WeatherTool, WeatherToolParams } from "./weather";
|
||||
|
||||
type ToolCreator = (config: unknown) => BaseToolWithCall;
|
||||
type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
|
||||
|
||||
export async function createTools(toolConfig: {
|
||||
local: Record<string, unknown>;
|
||||
llamahub: any;
|
||||
}): Promise<BaseToolWithCall[]> {
|
||||
// add local tools from the 'tools' folder (if configured)
|
||||
const tools = createLocalTools(toolConfig.local);
|
||||
const tools = await createLocalTools(toolConfig.local);
|
||||
// add tools from LlamaIndexTS (if configured)
|
||||
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
|
||||
return tools;
|
||||
}
|
||||
|
||||
const toolFactory: Record<string, ToolCreator> = {
|
||||
weather: (config: unknown) => {
|
||||
return new WeatherTool(config as WeatherToolParams);
|
||||
weather: async (config: unknown) => {
|
||||
return [new WeatherTool(config as WeatherToolParams)];
|
||||
},
|
||||
interpreter: (config: unknown) => {
|
||||
return new InterpreterTool(config as InterpreterToolParams);
|
||||
interpreter: async (config: unknown) => {
|
||||
return [new InterpreterTool(config as InterpreterToolParams)];
|
||||
},
|
||||
"openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
|
||||
const { openapi_uri, domain_headers } = config as {
|
||||
openapi_uri: string;
|
||||
domain_headers: Record<string, Record<string, string>>;
|
||||
};
|
||||
const openAPIActionTool = new OpenAPIActionTool(
|
||||
openapi_uri,
|
||||
domain_headers,
|
||||
);
|
||||
return await openAPIActionTool.toToolFunctions();
|
||||
},
|
||||
duckduckgo: async (config: unknown) => {
|
||||
return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)];
|
||||
},
|
||||
img_gen: async (config: unknown) => {
|
||||
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
|
||||
},
|
||||
};
|
||||
|
||||
function createLocalTools(
|
||||
async function createLocalTools(
|
||||
localConfig: Record<string, unknown>,
|
||||
): BaseToolWithCall[] {
|
||||
): Promise<BaseToolWithCall[]> {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
Object.keys(localConfig).forEach((key) => {
|
||||
for (const [key, toolConfig] of Object.entries(localConfig)) {
|
||||
if (key in toolFactory) {
|
||||
const toolConfig = localConfig[key];
|
||||
const tool = toolFactory[key](toolConfig);
|
||||
tools.push(tool);
|
||||
const newTools = await toolFactory[key](toolConfig);
|
||||
tools.push(...newTools);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return tools;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ export type InterpreterToolParams = {
|
||||
fileServerURLPrefix?: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolOuput = {
|
||||
export type InterpreterToolOutput = {
|
||||
isError: boolean;
|
||||
logs: Logs;
|
||||
extraResult: InterpreterExtraResult[];
|
||||
@@ -34,8 +34,9 @@ type InterpreterExtraType =
|
||||
|
||||
export type InterpreterExtraResult = {
|
||||
type: InterpreterExtraType;
|
||||
filename: string;
|
||||
url: string;
|
||||
content?: string;
|
||||
filename?: string;
|
||||
url?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
|
||||
@@ -55,7 +56,7 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
|
||||
};
|
||||
|
||||
export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
private readonly outputDir = "tool-output";
|
||||
private readonly outputDir = "output/tool";
|
||||
private apiKey?: string;
|
||||
private fileServerURLPrefix?: string;
|
||||
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
@@ -88,7 +89,7 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
return this.codeInterpreter;
|
||||
}
|
||||
|
||||
public async codeInterpret(code: string): Promise<InterpreterToolOuput> {
|
||||
public async codeInterpret(code: string): Promise<InterpreterToolOutput> {
|
||||
console.log(
|
||||
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
|
||||
);
|
||||
@@ -96,7 +97,7 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
const exec = await interpreter.notebook.execCell(code);
|
||||
if (exec.error) console.error("[Code Interpreter error]", exec.error);
|
||||
const extraResult = await this.getExtraResult(exec.results[0]);
|
||||
const result: InterpreterToolOuput = {
|
||||
const result: InterpreterToolOutput = {
|
||||
isError: !!exec.error,
|
||||
logs: exec.logs,
|
||||
extraResult,
|
||||
@@ -104,12 +105,15 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
return result;
|
||||
}
|
||||
|
||||
async call(input: InterpreterParameter): Promise<InterpreterToolOuput> {
|
||||
async call(input: InterpreterParameter): Promise<InterpreterToolOutput> {
|
||||
const result = await this.codeInterpret(input.code);
|
||||
await this.codeInterpreter?.close();
|
||||
return result;
|
||||
}
|
||||
|
||||
async close() {
|
||||
await this.codeInterpreter?.close();
|
||||
}
|
||||
|
||||
private async getExtraResult(
|
||||
res?: Result,
|
||||
): Promise<InterpreterExtraResult[]> {
|
||||
@@ -118,23 +122,34 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
|
||||
try {
|
||||
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
|
||||
const base64DataArr = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
|
||||
const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
|
||||
|
||||
// save base64 data to file and return the url
|
||||
for (let i = 0; i < formats.length; i++) {
|
||||
const ext = formats[i];
|
||||
const base64Data = base64DataArr[i];
|
||||
if (ext && base64Data) {
|
||||
const { filename } = this.saveToDisk(base64Data, ext);
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
filename,
|
||||
url: this.getFileUrl(filename),
|
||||
});
|
||||
const data = results[i];
|
||||
switch (ext) {
|
||||
case "png":
|
||||
case "jpeg":
|
||||
case "svg":
|
||||
case "pdf":
|
||||
const { filename } = this.saveToDisk(data, ext);
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
filename,
|
||||
url: this.getFileUrl(filename),
|
||||
});
|
||||
break;
|
||||
default:
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
content: data,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error when saving data to disk", error);
|
||||
console.error("Error when parsing e2b response", error);
|
||||
}
|
||||
|
||||
return output;
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
import SwaggerParser from "@apidevtools/swagger-parser";
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import got from "got";
|
||||
import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface DomainHeaders {
|
||||
[key: string]: { [header: string]: string };
|
||||
}
|
||||
|
||||
type Input = {
|
||||
url: string;
|
||||
params: object;
|
||||
};
|
||||
|
||||
type APIInfo = {
|
||||
description: string;
|
||||
title: string;
|
||||
};
|
||||
|
||||
export class OpenAPIActionTool {
|
||||
// cache the loaded specs by URL
|
||||
private static specs: Record<string, any> = {};
|
||||
|
||||
private readonly INVALID_URL_PROMPT =
|
||||
"This url did not include a hostname or scheme. Please determine the complete URL and try again.";
|
||||
|
||||
private createLoadSpecMetaData = (info: APIInfo) => {
|
||||
return {
|
||||
name: "load_openapi_spec",
|
||||
description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
|
||||
};
|
||||
};
|
||||
|
||||
private readonly createMethodCallMetaData = (
|
||||
method: "POST" | "PATCH" | "GET",
|
||||
info: APIInfo,
|
||||
) => {
|
||||
return {
|
||||
name: `${method.toLowerCase()}_request`,
|
||||
description: `Use this to call the ${method} method on the API named ${info.title}`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
url: {
|
||||
type: "string",
|
||||
description: `The url to make the ${method} request against`,
|
||||
},
|
||||
params: {
|
||||
type: "object",
|
||||
description:
|
||||
method === "GET"
|
||||
? "the URL parameters to provide with the get request"
|
||||
: `the key-value pairs to provide with the ${method} request`,
|
||||
},
|
||||
},
|
||||
required: ["url"],
|
||||
},
|
||||
} as ToolMetadata<JSONSchemaType<Input>>;
|
||||
};
|
||||
|
||||
constructor(
|
||||
public openapi_uri: string,
|
||||
public domainHeaders: DomainHeaders = {},
|
||||
) {}
|
||||
|
||||
async loadOpenapiSpec(url: string): Promise<any> {
|
||||
const api = await SwaggerParser.validate(url);
|
||||
return {
|
||||
servers: "servers" in api ? api.servers : "",
|
||||
info: { description: api.info.description, title: api.info.title },
|
||||
endpoints: api.paths,
|
||||
};
|
||||
}
|
||||
|
||||
async getRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const data = await got
|
||||
.get(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
searchParams: input.params as URLSearchParams,
|
||||
})
|
||||
.json();
|
||||
return data as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async postRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.post(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async patchRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.patch(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
public async toToolFunctions() {
|
||||
if (!OpenAPIActionTool.specs[this.openapi_uri]) {
|
||||
console.log(`Loading spec for URL: ${this.openapi_uri}`);
|
||||
const spec = await this.loadOpenapiSpec(this.openapi_uri);
|
||||
OpenAPIActionTool.specs[this.openapi_uri] = spec;
|
||||
}
|
||||
const spec = OpenAPIActionTool.specs[this.openapi_uri];
|
||||
// TODO: read endpoints with parameters from spec and create one tool for each endpoint
|
||||
// For now, we just create a tool for each HTTP method which does not work well for passing parameters
|
||||
return [
|
||||
FunctionTool.from(() => {
|
||||
return spec;
|
||||
}, this.createLoadSpecMetaData(spec.info)),
|
||||
FunctionTool.from(
|
||||
this.getRequest.bind(this),
|
||||
this.createMethodCallMetaData("GET", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.postRequest.bind(this),
|
||||
this.createMethodCallMetaData("POST", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.patchRequest.bind(this),
|
||||
this.createMethodCallMetaData("PATCH", spec.info),
|
||||
),
|
||||
];
|
||||
}
|
||||
|
||||
private validUrl(url: string): boolean {
|
||||
const parsed = new URL(url);
|
||||
return !!parsed.protocol && !!parsed.hostname;
|
||||
}
|
||||
|
||||
private getDomain(url: string): string {
|
||||
const parsed = new URL(url);
|
||||
return parsed.hostname;
|
||||
}
|
||||
|
||||
private getHeadersForUrl(url: string): { [header: string]: string } {
|
||||
const domain = this.getDomain(url);
|
||||
return this.domainHeaders[domain] || {};
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,21 @@
|
||||
import { ContextChatEngine, Settings } from "llamaindex";
|
||||
import { getDataSource } from "./index";
|
||||
|
||||
export async function createChatEngine() {
|
||||
export async function createChatEngine(documentIds?: string[]) {
|
||||
const index = await getDataSource();
|
||||
if (!index) {
|
||||
throw new Error(
|
||||
`StorageContext is empty - call 'npm run generate' to generate the storage first`,
|
||||
);
|
||||
}
|
||||
const retriever = index.asRetriever();
|
||||
retriever.similarityTopK = process.env.TOP_K
|
||||
? parseInt(process.env.TOP_K)
|
||||
: 3;
|
||||
const retriever = index.asRetriever({
|
||||
similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : 3,
|
||||
});
|
||||
|
||||
return new ContextChatEngine({
|
||||
chatModel: Settings.llm,
|
||||
retriever,
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
// disable as a custom system prompt disables the generated context
|
||||
// systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import fs from "fs";
|
||||
import crypto from "node:crypto";
|
||||
import { getExtractors } from "../../engine/loader";
|
||||
|
||||
const MIME_TYPE_TO_EXT: Record<string, string> = {
|
||||
"application/pdf": "pdf",
|
||||
"text/plain": "txt",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
"docx",
|
||||
};
|
||||
|
||||
const UPLOADED_FOLDER = "output/uploaded";
|
||||
|
||||
export async function loadDocuments(fileBuffer: Buffer, mimeType: string) {
|
||||
const extractors = getExtractors();
|
||||
const reader = extractors[MIME_TYPE_TO_EXT[mimeType]];
|
||||
|
||||
if (!reader) {
|
||||
throw new Error(`Unsupported document type: ${mimeType}`);
|
||||
}
|
||||
console.log(`Processing uploaded document of type: ${mimeType}`);
|
||||
return await reader.loadDataAsContent(fileBuffer);
|
||||
}
|
||||
|
||||
export async function saveDocument(fileBuffer: Buffer, mimeType: string) {
|
||||
const fileExt = MIME_TYPE_TO_EXT[mimeType];
|
||||
if (!fileExt) throw new Error(`Unsupported document type: ${mimeType}`);
|
||||
|
||||
const filename = `${crypto.randomUUID()}.${fileExt}`;
|
||||
const filepath = `${UPLOADED_FOLDER}/${filename}`;
|
||||
const fileurl = `${process.env.FILESERVER_URL_PREFIX}/${filepath}`;
|
||||
|
||||
if (!fs.existsSync(UPLOADED_FOLDER)) {
|
||||
fs.mkdirSync(UPLOADED_FOLDER, { recursive: true });
|
||||
}
|
||||
await fs.promises.writeFile(filepath, fileBuffer);
|
||||
|
||||
console.log(`Saved document file to ${filepath}.\nURL: ${fileurl}`);
|
||||
return {
|
||||
filename,
|
||||
filepath,
|
||||
fileurl,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
import {
|
||||
BaseNode,
|
||||
Document,
|
||||
IngestionPipeline,
|
||||
Metadata,
|
||||
Settings,
|
||||
SimpleNodeParser,
|
||||
storageContextFromDefaults,
|
||||
VectorStoreIndex,
|
||||
} from "llamaindex";
|
||||
import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
|
||||
import { getDataSource } from "../../engine";
|
||||
|
||||
export async function runPipeline(documents: Document[], filename: string) {
|
||||
const currentIndex = await getDataSource();
|
||||
|
||||
// Update documents with metadata
|
||||
for (const document of documents) {
|
||||
document.metadata = {
|
||||
...document.metadata,
|
||||
file_name: filename,
|
||||
private: "true", // to separate from other public documents
|
||||
};
|
||||
}
|
||||
|
||||
if (currentIndex instanceof LlamaCloudIndex) {
|
||||
// LlamaCloudIndex processes the documents automatically
|
||||
// so we don't need ingestion pipeline, just insert the documents directly
|
||||
for (const document of documents) {
|
||||
await currentIndex.insert(document);
|
||||
}
|
||||
} else {
|
||||
// Use ingestion pipeline to process the documents into nodes and add them to the vector store
|
||||
const pipeline = new IngestionPipeline({
|
||||
transformations: [
|
||||
new SimpleNodeParser({
|
||||
chunkSize: Settings.chunkSize,
|
||||
chunkOverlap: Settings.chunkOverlap,
|
||||
}),
|
||||
Settings.embedModel,
|
||||
],
|
||||
});
|
||||
const nodes = await pipeline.run({ documents });
|
||||
await addNodesToVectorStore(nodes, currentIndex);
|
||||
}
|
||||
|
||||
return documents.map((document) => document.id_);
|
||||
}
|
||||
|
||||
async function addNodesToVectorStore(
|
||||
nodes: BaseNode<Metadata>[],
|
||||
currentIndex: VectorStoreIndex | null,
|
||||
) {
|
||||
if (currentIndex) {
|
||||
await currentIndex.insertNodes(nodes);
|
||||
} else {
|
||||
// Not using vectordb and haven't generated local index yet
|
||||
const storageContext = await storageContextFromDefaults({
|
||||
persistDir: "./cache",
|
||||
});
|
||||
currentIndex = await VectorStoreIndex.init({ nodes, storageContext });
|
||||
}
|
||||
currentIndex.storageContext.docStore.persist();
|
||||
console.log("Added nodes to the vector store.");
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
import { loadDocuments, saveDocument } from "./helper";
|
||||
import { runPipeline } from "./pipeline";
|
||||
|
||||
export async function uploadDocument(raw: string): Promise<string[]> {
|
||||
const [header, content] = raw.split(",");
|
||||
const mimeType = header.replace("data:", "").replace(";base64", "");
|
||||
const fileBuffer = Buffer.from(content, "base64");
|
||||
const documents = await loadDocuments(fileBuffer, mimeType);
|
||||
const { filename } = await saveDocument(fileBuffer, mimeType);
|
||||
return await runPipeline(documents, filename);
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
import { JSONValue } from "ai";
|
||||
import { MessageContent, MessageContentDetail } from "llamaindex";
|
||||
|
||||
export type DocumentFileType = "csv" | "pdf" | "txt" | "docx";
|
||||
|
||||
export type DocumentFileContent = {
|
||||
type: "ref" | "text";
|
||||
value: string[] | string;
|
||||
};
|
||||
|
||||
export type DocumentFile = {
|
||||
id: string;
|
||||
filename: string;
|
||||
filesize: number;
|
||||
filetype: DocumentFileType;
|
||||
content: DocumentFileContent;
|
||||
};
|
||||
|
||||
type Annotation = {
|
||||
type: string;
|
||||
data: object;
|
||||
};
|
||||
|
||||
export function retrieveDocumentIds(annotations?: JSONValue[]): string[] {
|
||||
if (!annotations) return [];
|
||||
|
||||
const ids: string[] = [];
|
||||
|
||||
for (const annotation of annotations) {
|
||||
const { type, data } = getValidAnnotation(annotation);
|
||||
if (
|
||||
type === "document_file" &&
|
||||
"files" in data &&
|
||||
Array.isArray(data.files)
|
||||
) {
|
||||
const files = data.files as DocumentFile[];
|
||||
for (const file of files) {
|
||||
if (Array.isArray(file.content.value)) {
|
||||
// it's an array, so it's an array of doc IDs
|
||||
for (const id of file.content.value) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
export function convertMessageContent(
|
||||
content: string,
|
||||
annotations?: JSONValue[],
|
||||
): MessageContent {
|
||||
if (!annotations) return content;
|
||||
return [
|
||||
{
|
||||
type: "text",
|
||||
text: content,
|
||||
},
|
||||
...convertAnnotations(annotations),
|
||||
];
|
||||
}
|
||||
|
||||
function convertAnnotations(annotations: JSONValue[]): MessageContentDetail[] {
|
||||
const content: MessageContentDetail[] = [];
|
||||
annotations.forEach((annotation: JSONValue) => {
|
||||
const { type, data } = getValidAnnotation(annotation);
|
||||
// convert image
|
||||
if (type === "image" && "url" in data && typeof data.url === "string") {
|
||||
content.push({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: data.url,
|
||||
},
|
||||
});
|
||||
}
|
||||
// convert the content of files to a text message
|
||||
if (
|
||||
type === "document_file" &&
|
||||
"files" in data &&
|
||||
Array.isArray(data.files)
|
||||
) {
|
||||
// get all CSV files and convert their whole content to one text message
|
||||
// currently CSV files are the only files where we send the whole content - we don't use an index
|
||||
const csvFiles: DocumentFile[] = data.files.filter(
|
||||
(file: DocumentFile) => file.filetype === "csv",
|
||||
);
|
||||
if (csvFiles && csvFiles.length > 0) {
|
||||
const csvContents = csvFiles.map((file: DocumentFile) => {
|
||||
const fileContent = Array.isArray(file.content.value)
|
||||
? file.content.value.join("\n")
|
||||
: file.content.value;
|
||||
return "```csv\n" + fileContent + "\n```";
|
||||
});
|
||||
const text =
|
||||
"Use the following CSV content:\n" + csvContents.join("\n\n");
|
||||
content.push({
|
||||
type: "text",
|
||||
text,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
function getValidAnnotation(annotation: JSONValue): Annotation {
|
||||
if (
|
||||
!(
|
||||
annotation &&
|
||||
typeof annotation === "object" &&
|
||||
"type" in annotation &&
|
||||
typeof annotation.type === "string" &&
|
||||
"data" in annotation &&
|
||||
annotation.data &&
|
||||
typeof annotation.data === "object"
|
||||
)
|
||||
) {
|
||||
throw new Error("Client sent invalid annotation. Missing data and type");
|
||||
}
|
||||
return { type: annotation.type, data: annotation.data };
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
import { StreamData } from "ai";
|
||||
import {
|
||||
CallbackManager,
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
import { LLamaCloudFileService } from "./service";
|
||||
|
||||
export function appendSourceData(
|
||||
data: StreamData,
|
||||
sourceNodes?: NodeWithScore<Metadata>[],
|
||||
) {
|
||||
if (!sourceNodes?.length) return;
|
||||
try {
|
||||
const nodes = sourceNodes.map((node) => ({
|
||||
...node.node.toMutableJSON(),
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
url: getNodeUrl(node.node.metadata),
|
||||
}));
|
||||
data.appendMessageAnnotation({
|
||||
type: "sources",
|
||||
data: {
|
||||
nodes,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error appending source data:", error);
|
||||
}
|
||||
}
|
||||
|
||||
export function appendEventData(data: StreamData, title?: string) {
|
||||
if (!title) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "events",
|
||||
data: {
|
||||
title,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createStreamTimeout(stream: StreamData) {
|
||||
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
|
||||
const t = setTimeout(() => {
|
||||
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
|
||||
stream.close();
|
||||
}, timeout);
|
||||
return t;
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve-end", (data) => {
|
||||
const { nodes, query } = data.detail;
|
||||
appendSourceData(stream, nodes);
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
LLamaCloudFileService.downloadFiles(nodes); // don't await to avoid blocking chat streaming
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
|
||||
function getNodeUrl(metadata: Metadata) {
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
console.warn(
|
||||
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
|
||||
);
|
||||
}
|
||||
const fileName = metadata["file_name"];
|
||||
if (fileName && process.env.FILESERVER_URL_PREFIX) {
|
||||
// file_name exists and file server is configured
|
||||
const pipelineId = metadata["pipeline_id"];
|
||||
if (pipelineId && metadata["private"] == null) {
|
||||
// file is from LlamaCloud and was not ingested locally
|
||||
const name = LLamaCloudFileService.toDownloadedName(pipelineId, fileName);
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/output/llamacloud/${name}`;
|
||||
}
|
||||
const isPrivate = metadata["private"] === "true";
|
||||
const folder = isPrivate ? "output/uploaded" : "data";
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/${folder}/${fileName}`;
|
||||
}
|
||||
// fallback to URL in metadata (e.g. for websites)
|
||||
return metadata["URL"];
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
import { Metadata, NodeWithScore } from "llamaindex";
|
||||
import fs from "node:fs";
|
||||
import https from "node:https";
|
||||
import path from "node:path";
|
||||
|
||||
const LLAMA_CLOUD_OUTPUT_DIR = "output/llamacloud";
|
||||
const LLAMA_CLOUD_BASE_URL = "https://cloud.llamaindex.ai/api/v1";
|
||||
const FILE_DELIMITER = "$"; // delimiter between pipelineId and filename
|
||||
|
||||
interface LlamaCloudFile {
|
||||
name: string;
|
||||
file_id: string;
|
||||
project_id: string;
|
||||
}
|
||||
|
||||
export class LLamaCloudFileService {
|
||||
public static async downloadFiles(nodes: NodeWithScore<Metadata>[]) {
|
||||
const files = this.nodesToDownloadFiles(nodes);
|
||||
if (!files.length) return;
|
||||
console.log("Downloading files from LlamaCloud...");
|
||||
for (const file of files) {
|
||||
await this.downloadFile(file.pipelineId, file.fileName);
|
||||
}
|
||||
}
|
||||
|
||||
public static toDownloadedName(pipelineId: string, fileName: string) {
|
||||
return `${pipelineId}${FILE_DELIMITER}${fileName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* This function will return an array of unique files to download from LlamaCloud
|
||||
* We only download files that are uploaded directly in LlamaCloud datasources (don't have `private` in metadata)
|
||||
* Files are uploaded directly in LlamaCloud datasources don't have `private` in metadata (public docs)
|
||||
* Files are uploaded from local via `generate` command will have `private=false` (public docs)
|
||||
* Files are uploaded from local via `/chat/upload` endpoint will have `private=true` (private docs)
|
||||
*
|
||||
* @param nodes
|
||||
* @returns list of unique files to download
|
||||
*/
|
||||
private static nodesToDownloadFiles(nodes: NodeWithScore<Metadata>[]) {
|
||||
const downloadFiles: Array<{
|
||||
pipelineId: string;
|
||||
fileName: string;
|
||||
}> = [];
|
||||
for (const node of nodes) {
|
||||
const isLocalFile = node.node.metadata["private"] != null;
|
||||
const pipelineId = node.node.metadata["pipeline_id"];
|
||||
const fileName = node.node.metadata["file_name"];
|
||||
if (isLocalFile || !pipelineId || !fileName) continue;
|
||||
const isDuplicate = downloadFiles.some(
|
||||
(f) => f.pipelineId === pipelineId && f.fileName === fileName,
|
||||
);
|
||||
if (!isDuplicate) {
|
||||
downloadFiles.push({ pipelineId, fileName });
|
||||
}
|
||||
}
|
||||
return downloadFiles;
|
||||
}
|
||||
|
||||
private static async downloadFile(pipelineId: string, fileName: string) {
|
||||
try {
|
||||
const downloadedName = this.toDownloadedName(pipelineId, fileName);
|
||||
const downloadedPath = path.join(LLAMA_CLOUD_OUTPUT_DIR, downloadedName);
|
||||
|
||||
// Check if file already exists
|
||||
if (fs.existsSync(downloadedPath)) return;
|
||||
|
||||
const urlToDownload = await this.getFileUrlByName(pipelineId, fileName);
|
||||
if (!urlToDownload) throw new Error("File not found in LlamaCloud");
|
||||
|
||||
const file = fs.createWriteStream(downloadedPath);
|
||||
https
|
||||
.get(urlToDownload, (response) => {
|
||||
response.pipe(file);
|
||||
file.on("finish", () => {
|
||||
file.close(() => {
|
||||
console.log("File downloaded successfully");
|
||||
});
|
||||
});
|
||||
})
|
||||
.on("error", (err) => {
|
||||
fs.unlink(downloadedPath, () => {
|
||||
console.error("Error downloading file:", err);
|
||||
throw err;
|
||||
});
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(`Error downloading file from LlamaCloud: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
private static async getFileUrlByName(
|
||||
pipelineId: string,
|
||||
name: string,
|
||||
): Promise<string | null> {
|
||||
const files = await this.getAllFiles(pipelineId);
|
||||
const file = files.find((file) => file.name === name);
|
||||
if (!file) return null;
|
||||
return await this.getFileUrlById(file.project_id, file.file_id);
|
||||
}
|
||||
|
||||
private static async getFileUrlById(
|
||||
projectId: string,
|
||||
fileId: string,
|
||||
): Promise<string> {
|
||||
const url = `${LLAMA_CLOUD_BASE_URL}/files/${fileId}/content?project_id=${projectId}`;
|
||||
const headers = {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
|
||||
};
|
||||
const response = await fetch(url, { method: "GET", headers });
|
||||
const data = (await response.json()) as { url: string };
|
||||
return data.url;
|
||||
}
|
||||
|
||||
private static async getAllFiles(
|
||||
pipelineId: string,
|
||||
): Promise<LlamaCloudFile[]> {
|
||||
const url = `${LLAMA_CLOUD_BASE_URL}/pipelines/${pipelineId}/files`;
|
||||
const headers = {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
|
||||
};
|
||||
const response = await fetch(url, { method: "GET", headers });
|
||||
const data = await response.json();
|
||||
return data;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
import {
|
||||
StreamData,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer,
|
||||
trimStartOfStreamHelper,
|
||||
type AIStreamCallbacksAndOptions,
|
||||
} from "ai";
|
||||
import { ChatMessage, EngineResponse } from "llamaindex";
|
||||
import { generateNextQuestions } from "./suggestion";
|
||||
|
||||
export function LlamaIndexStream(
|
||||
response: AsyncIterable<EngineResponse>,
|
||||
data: StreamData,
|
||||
chatHistory: ChatMessage[],
|
||||
opts?: {
|
||||
callbacks?: AIStreamCallbacksAndOptions;
|
||||
},
|
||||
): ReadableStream<Uint8Array> {
|
||||
return createParser(response, data, chatHistory)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer());
|
||||
}
|
||||
|
||||
function createParser(
|
||||
res: AsyncIterable<EngineResponse>,
|
||||
data: StreamData,
|
||||
chatHistory: ChatMessage[],
|
||||
) {
|
||||
const it = res[Symbol.asyncIterator]();
|
||||
const trimStartOfStream = trimStartOfStreamHelper();
|
||||
let llmTextResponse = "";
|
||||
|
||||
return new ReadableStream<string>({
|
||||
async pull(controller): Promise<void> {
|
||||
const { value, done } = await it.next();
|
||||
if (done) {
|
||||
controller.close();
|
||||
// LLM stream is done, generate the next questions with a new LLM call
|
||||
chatHistory.push({ role: "assistant", content: llmTextResponse });
|
||||
const questions: string[] = await generateNextQuestions(chatHistory);
|
||||
if (questions.length > 0) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "suggested_questions",
|
||||
data: questions,
|
||||
});
|
||||
}
|
||||
data.close();
|
||||
return;
|
||||
}
|
||||
const text = trimStartOfStream(value.delta ?? "");
|
||||
if (text) {
|
||||
llmTextResponse += text;
|
||||
controller.enqueue(text);
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
import { ChatMessage, Settings } from "llamaindex";
|
||||
|
||||
const NEXT_QUESTION_PROMPT_TEMPLATE = `You're a helpful assistant! Your task is to suggest the next question that user might ask.
|
||||
Here is the conversation history
|
||||
---------------------
|
||||
$conversation
|
||||
---------------------
|
||||
Given the conversation history, please give me $number_of_questions questions that you might ask next!
|
||||
Your answer should be wrapped in three sticks which follows the following format:
|
||||
\`\`\`
|
||||
<question 1>
|
||||
<question 2>\`\`\`
|
||||
`;
|
||||
const N_QUESTIONS_TO_GENERATE = 3;
|
||||
|
||||
export async function generateNextQuestions(
|
||||
conversation: ChatMessage[],
|
||||
numberOfQuestions: number = N_QUESTIONS_TO_GENERATE,
|
||||
) {
|
||||
const llm = Settings.llm;
|
||||
|
||||
// Format conversation
|
||||
const conversationText = conversation
|
||||
.map((message) => `${message.role}: ${message.content}`)
|
||||
.join("\n");
|
||||
const message = NEXT_QUESTION_PROMPT_TEMPLATE.replace(
|
||||
"$conversation",
|
||||
conversationText,
|
||||
).replace("$number_of_questions", numberOfQuestions.toString());
|
||||
|
||||
try {
|
||||
const response = await llm.complete({ prompt: message });
|
||||
const questions = extractQuestions(response.text);
|
||||
return questions;
|
||||
} catch (error) {
|
||||
console.error("Error: ", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: instead of parsing the LLM's result we can use structured predict, once LITS supports it
|
||||
function extractQuestions(text: string): string[] {
|
||||
// Extract the text inside the triple backticks
|
||||
// @ts-ignore
|
||||
const contentMatch = text.match(/```(.*?)```/s);
|
||||
const content = contentMatch ? contentMatch[1] : "";
|
||||
|
||||
// Split the content by newlines to get each question
|
||||
const questions = content
|
||||
.split("\n")
|
||||
.map((question) => question.trim())
|
||||
.filter((question) => question !== "");
|
||||
|
||||
return questions;
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
@@ -23,25 +24,46 @@ def llama_parse_parser():
|
||||
"LLAMA_CLOUD_API_KEY environment variable is not set. "
|
||||
"Please set it in .env file or in your shell environment then run again!"
|
||||
)
|
||||
parser = LlamaParse(result_type="markdown", verbose=True, language="en")
|
||||
parser = LlamaParse(
|
||||
result_type="markdown",
|
||||
verbose=True,
|
||||
language="en",
|
||||
ignore_errors=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
||||
|
||||
parser = llama_parse_parser()
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
try:
|
||||
file_extractor = None
|
||||
if config.use_llama_parse:
|
||||
# LlamaParse is async first,
|
||||
# so we need to use nest_asyncio to run it in sync mode
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
file_extractor = llama_parse_extractor()
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
raise_on_error=True,
|
||||
file_extractor=file_extractor,
|
||||
)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
except ValueError as e:
|
||||
import sys, traceback
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import { SimpleDirectoryReader } from "llamaindex";
|
||||
import {
|
||||
FILE_EXT_TO_READER,
|
||||
SimpleDirectoryReader,
|
||||
} from "llamaindex/readers/SimpleDirectoryReader";
|
||||
|
||||
export const DATA_DIR = "./data";
|
||||
|
||||
export function getExtractors() {
|
||||
return FILE_EXT_TO_READER;
|
||||
}
|
||||
|
||||
export async function getDocuments() {
|
||||
return await new SimpleDirectoryReader().loadData({
|
||||
directoryPath: DATA_DIR,
|
||||
|
||||
@@ -1,19 +1,30 @@
|
||||
import { LlamaParseReader } from "llamaindex/readers/LlamaParseReader";
|
||||
import {
|
||||
FILE_EXT_TO_READER,
|
||||
LlamaParseReader,
|
||||
SimpleDirectoryReader,
|
||||
} from "llamaindex";
|
||||
} from "llamaindex/readers/SimpleDirectoryReader";
|
||||
|
||||
export const DATA_DIR = "./data";
|
||||
|
||||
export function getExtractors() {
|
||||
const llamaParseParser = new LlamaParseReader({ resultType: "markdown" });
|
||||
const extractors = FILE_EXT_TO_READER;
|
||||
// Change all the supported extractors to LlamaParse
|
||||
// except for .txt, it doesn't need to be parsed
|
||||
for (const key in extractors) {
|
||||
if (key === "txt") {
|
||||
continue;
|
||||
}
|
||||
extractors[key] = llamaParseParser;
|
||||
}
|
||||
return extractors;
|
||||
}
|
||||
|
||||
export async function getDocuments() {
|
||||
const reader = new SimpleDirectoryReader();
|
||||
// Load PDFs using LlamaParseReader
|
||||
const extractors = getExtractors();
|
||||
return await reader.loadData({
|
||||
directoryPath: DATA_DIR,
|
||||
fileExtToReader: {
|
||||
...FILE_EXT_TO_READER,
|
||||
pdf: new LlamaParseReader({ resultType: "markdown" }),
|
||||
},
|
||||
fileExtToReader: extractors,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.core.settings import Settings
|
||||
from typing import Dict
|
||||
import os
|
||||
|
||||
DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
|
||||
|
||||
class TSIEmbedding(OpenAIEmbedding):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._query_engine = self._text_engine = self.model_name
|
||||
|
||||
def llm_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
model = os.getenv("MODEL", DEFAULT_MODEL)
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"temperature": float(temperature),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def embedding_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
|
||||
|
||||
model = os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
|
||||
dimension = os.getenv("EMBEDDING_DIM", DEFAULT_EMBEDDING_DIM)
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model_name": model,
|
||||
"dimension": int(dimension) if dimension is not None else None,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
}
|
||||
return config
|
||||
|
||||
def init_llmhub():
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
|
||||
llm_configs = llm_config_from_env()
|
||||
embedding_configs = embedding_config_from_env()
|
||||
|
||||
Settings.embed_model = TSIEmbedding(**embedding_configs)
|
||||
Settings.llm = OpenAILike(
|
||||
**llm_configs,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=False,
|
||||
context_window=4096,
|
||||
)
|
||||
@@ -0,0 +1,172 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
match model_provider:
|
||||
case "openai":
|
||||
init_openai()
|
||||
case "groq":
|
||||
init_groq()
|
||||
case "ollama":
|
||||
init_ollama()
|
||||
case "anthropic":
|
||||
init_anthropic()
|
||||
case "gemini":
|
||||
init_gemini()
|
||||
case "mistral":
|
||||
init_mistral()
|
||||
case "azure-openai":
|
||||
init_azure_openai()
|
||||
case "t-systems":
|
||||
from .llmhub import init_llmhub
|
||||
|
||||
init_llmhub()
|
||||
case _:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
|
||||
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
request_timeout = float(
|
||||
os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
|
||||
)
|
||||
Settings.embed_model = OllamaEmbedding(
|
||||
base_url=base_url,
|
||||
model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
)
|
||||
Settings.llm = Ollama(
|
||||
base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
|
||||
)
|
||||
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
|
||||
def init_azure_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.llms.azure_openai import AzureOpenAI
|
||||
|
||||
llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
|
||||
embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
|
||||
azure_config = {
|
||||
"api_key": os.environ["AZURE_OPENAI_KEY"],
|
||||
"azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
"api_version": os.getenv("AZURE_OPENAI_API_VERSION")
|
||||
or os.getenv("OPENAI_API_VERSION"),
|
||||
}
|
||||
|
||||
Settings.llm = AzureOpenAI(
|
||||
model=os.getenv("MODEL"),
|
||||
max_tokens=int(max_tokens) if max_tokens is not None else None,
|
||||
temperature=float(temperature),
|
||||
deployment_name=llm_deployment,
|
||||
**azure_config,
|
||||
)
|
||||
|
||||
Settings.embed_model = AzureOpenAIEmbedding(
|
||||
model=os.getenv("EMBEDDING_MODEL"),
|
||||
dimensions=int(dimensions) if dimensions is not None else None,
|
||||
deployment_name=embedding_deployment,
|
||||
**azure_config,
|
||||
)
|
||||
|
||||
|
||||
def init_fastembed():
|
||||
"""
|
||||
Use Qdrant Fastembed as the local embedding provider.
|
||||
"""
|
||||
from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
||||
|
||||
embed_model_map: Dict[str, str] = {
|
||||
# Small and multilingual
|
||||
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
# Large and multilingual
|
||||
"paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
|
||||
}
|
||||
|
||||
# This will download the model automatically if it is not already downloaded
|
||||
Settings.embed_model = FastEmbedEmbedding(
|
||||
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
)
|
||||
|
||||
|
||||
def init_groq():
|
||||
from llama_index.llms.groq import Groq
|
||||
|
||||
model_map: Dict[str, str] = {
|
||||
"llama3-8b": "llama3-8b-8192",
|
||||
"llama3-70b": "llama3-70b-8192",
|
||||
"mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
}
|
||||
|
||||
Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
|
||||
# Groq does not provide embeddings, so we use FastEmbed instead
|
||||
init_fastembed()
|
||||
|
||||
|
||||
def init_anthropic():
|
||||
from llama_index.llms.anthropic import Anthropic
|
||||
|
||||
model_map: Dict[str, str] = {
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-2.1": "claude-2.1",
|
||||
"claude-instant-1.2": "claude-instant-1.2",
|
||||
}
|
||||
|
||||
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
|
||||
# Anthropic does not provide embeddings, so we use FastEmbed instead
|
||||
init_fastembed()
|
||||
|
||||
|
||||
def init_gemini():
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.llms.gemini import Gemini
|
||||
|
||||
model_name = f"models/{os.getenv('MODEL')}"
|
||||
embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
|
||||
|
||||
Settings.llm = Gemini(model=model_name)
|
||||
Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
||||
|
||||
|
||||
def init_mistral():
|
||||
from llama_index.embeddings.mistralai import MistralAIEmbedding
|
||||
from llama_index.llms.mistralai import MistralAI
|
||||
|
||||
Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
||||
Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
@@ -1,5 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Message } from "./chat-messages";
|
||||
|
||||
export interface ChatInputProps {
|
||||
/** The current value of the input */
|
||||
input?: string;
|
||||
@@ -12,7 +14,8 @@ export interface ChatInputProps {
|
||||
/** Form submission handler to automatically reset input and append a user message */
|
||||
handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
|
||||
isLoading: boolean;
|
||||
multiModal?: boolean;
|
||||
messages: Message[];
|
||||
setInput?: (input: string) => void;
|
||||
}
|
||||
|
||||
export default function ChatInput(props: ChatInputProps) {
|
||||
|
||||
@@ -19,8 +19,12 @@ export default function ChatMessages({
|
||||
isLoading?: boolean;
|
||||
stop?: () => void;
|
||||
reload?: () => void;
|
||||
append?: (
|
||||
message: Message | Omit<Message, "id">,
|
||||
) => Promise<string | null | undefined>;
|
||||
}) {
|
||||
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (scrollableChatContainerRef.current) {
|
||||
@@ -31,14 +35,14 @@ export default function ChatMessages({
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages.length]);
|
||||
}, [messages.length, lastMessage]);
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl">
|
||||
<div
|
||||
className="flex flex-col gap-5 divide-y h-[50vh] overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
<div
|
||||
className="flex-1 w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
<div className="flex flex-col gap-5 divide-y">
|
||||
{messages.map((m: Message) => (
|
||||
<ChatItem key={m.id} {...m} />
|
||||
))}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
export interface ChatConfig {
|
||||
backend?: string;
|
||||
starterQuestions?: string[];
|
||||
}
|
||||
|
||||
export function useClientConfig(): ChatConfig {
|
||||
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
|
||||
const [config, setConfig] = useState<ChatConfig>();
|
||||
|
||||
const backendOrigin = useMemo(() => {
|
||||
return chatAPI ? new URL(chatAPI).origin : "";
|
||||
}, [chatAPI]);
|
||||
|
||||
const configAPI = `${backendOrigin}/api/chat/config`;
|
||||
|
||||
useEffect(() => {
|
||||
fetch(configAPI)
|
||||
.then((response) => response.json())
|
||||
.then((data) => setConfig({ ...data, chatAPI }))
|
||||
.catch((error) => console.error("Error fetching config", error));
|
||||
}, [chatAPI, configAPI]);
|
||||
|
||||
return {
|
||||
backend: backendOrigin,
|
||||
starterQuestions: config?.starterQuestions,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
from app.settings import init_settings
|
||||
from app.engine.loaders import get_documents
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
|
||||
|
||||
if name is None or project_name is None or api_key is None:
|
||||
raise ValueError(
|
||||
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
|
||||
documents = get_documents()
|
||||
|
||||
# Set private=false to mark the document as public (required for filtering)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
|
||||
LlamaCloudIndex.from_documents(
|
||||
documents=documents,
|
||||
name=name,
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization_id=organization_id
|
||||
)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
import os
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
def get_index():
|
||||
name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
|
||||
project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
|
||||
api_key = os.getenv("LLAMA_CLOUD_API_KEY")
|
||||
base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
|
||||
organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
|
||||
|
||||
if name is None or project_name is None or api_key is None:
|
||||
raise ValueError(
|
||||
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
|
||||
index = LlamaCloudIndex(
|
||||
name=name,
|
||||
project_name=project_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization_id=organization_id
|
||||
)
|
||||
|
||||
return index
|
||||
@@ -2,14 +2,14 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from llama_index.core.indices import (
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
@@ -21,6 +21,9 @@ def generate_datasource():
|
||||
storage_dir = os.environ.get("STORAGE_DIR", "storage")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
# Set private=false to mark the document as public (required for filtering)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import * as dotenv from "dotenv";
|
||||
import { LlamaCloudIndex } from "llamaindex";
|
||||
import { getDataSource } from "./index";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
dotenv.config();
|
||||
|
||||
async function loadAndIndex() {
|
||||
const documents = await getDocuments();
|
||||
// Set private=false to mark the document as public (required for filtering)
|
||||
for (const document of documents) {
|
||||
document.metadata = {
|
||||
...document.metadata,
|
||||
private: "false",
|
||||
};
|
||||
}
|
||||
await getDataSource();
|
||||
await LlamaCloudIndex.fromDocuments({
|
||||
documents,
|
||||
name: process.env.LLAMA_CLOUD_INDEX_NAME!,
|
||||
projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
|
||||
apiKey: process.env.LLAMA_CLOUD_API_KEY,
|
||||
baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
|
||||
});
|
||||
console.log(`Successfully created embeddings!`);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
checkRequiredEnvVars();
|
||||
initSettings();
|
||||
await loadAndIndex();
|
||||
console.log("Finished generating storage.");
|
||||
})();
|
||||
@@ -0,0 +1,13 @@
|
||||
import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const index = new LlamaCloudIndex({
|
||||
name: process.env.LLAMA_CLOUD_INDEX_NAME!,
|
||||
projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
|
||||
apiKey: process.env.LLAMA_CLOUD_API_KEY,
|
||||
baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
|
||||
});
|
||||
return index;
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
const REQUIRED_ENV_VARS = [
|
||||
"LLAMA_CLOUD_INDEX_NAME",
|
||||
"LLAMA_CLOUD_PROJECT_NAME",
|
||||
"LLAMA_CLOUD_API_KEY",
|
||||
];
|
||||
|
||||
export function checkRequiredEnvVars() {
|
||||
const missingEnvVars = REQUIRED_ENV_VARS.filter((envVar) => {
|
||||
return !process.env[envVar];
|
||||
});
|
||||
|
||||
if (missingEnvVars.length > 0) {
|
||||
console.log(
|
||||
`The following environment variables are required but missing: ${missingEnvVars.join(
|
||||
", ",
|
||||
)}`,
|
||||
);
|
||||
throw new Error(
|
||||
`Missing environment variables: ${missingEnvVars.join(", ")}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,11 @@ async function generateDatasource() {
|
||||
persistDir: STORAGE_CACHE_DIR,
|
||||
});
|
||||
const documents = await getDocuments();
|
||||
// Set private=false to mark the document as public (required for filtering)
|
||||
documents.forEach((doc) => {
|
||||
doc.metadata["private"] = "false";
|
||||
});
|
||||
|
||||
await VectorStoreIndex.fromDocuments(documents, {
|
||||
storageContext,
|
||||
});
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
This is a [LlamaIndex](https://www.llamaindex.ai/) project using [FastAPI](https://fastapi.tiangolo.com/) bootstrapped with [`create-llama`](https://github.com/run-llama/LlamaIndexTS/tree/main/packages/create-llama) featuring [structured extraction](https://docs.llamaindex.ai/en/stable/examples/structured_outputs/structured_outputs/?h=structured+output).
|
||||
|
||||
## Getting Started
|
||||
|
||||
First, setup the environment with poetry:
|
||||
|
||||
> **_Note:_** This step is not needed if you are using the dev-container.
|
||||
|
||||
```shell
|
||||
poetry install
|
||||
poetry shell
|
||||
```
|
||||
|
||||
Then check the parameters that have been pre-configured in the `.env` file in this directory. (E.g. you might need to configure an `OPENAI_API_KEY` if you're using OpenAI as model provider).
|
||||
|
||||
Second, generate the embeddings of the documents in the `./data` directory (if this folder exists - otherwise, skip this step):
|
||||
|
||||
```shell
|
||||
poetry run generate
|
||||
```
|
||||
|
||||
Third, run the API in one command:
|
||||
|
||||
```shell
|
||||
poetry run python main.py
|
||||
```
|
||||
|
||||
The example provides the `/api/extractor/query` API endpoint.
|
||||
|
||||
This query endpoint returns structured data in the format of the [Output](./app/api/routers/output.py) class. Modify this class to change the output format.
|
||||
|
||||
You can test the endpoint with the following curl request:
|
||||
|
||||
```shell
|
||||
curl --location 'localhost:8000/api/extractor/query' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{ "query": "What is the maximum weight for a parcel?" }'
|
||||
```
|
||||
|
||||
Which will return a response that the RAG pipeline is confident about the answer.
|
||||
|
||||
Try
|
||||
|
||||
```shell
|
||||
curl --location 'localhost:8000/api/extractor/query' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{ "query": "What is the weather today?" }'
|
||||
```
|
||||
|
||||
To retrieve a response with low confidence since the question is not related to the provided document in the `./data` directory.
|
||||
|
||||
You can start editing the API endpoint by modifying [`extractor.py`](./app/api/routers/extractor.py). The endpoints auto-update as you save the file.
|
||||
|
||||
Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
|
||||
|
||||
The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
|
||||
|
||||
```
|
||||
ENVIRONMENT=prod python main.py
|
||||
```
|
||||
|
||||
## Learn More
|
||||
|
||||
To learn more about LlamaIndex, take a look at the following resources:
|
||||
|
||||
- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex.
|
||||
|
||||
You can check out [the LlamaIndex GitHub repository](https://github.com/run-llama/llama_index) - your feedback and contributions are welcome!
|
||||
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from llama_index.core.settings import Settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.routers.output import Output
|
||||
from app.engine.index import get_index
|
||||
|
||||
extractor_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class RequestData(BaseModel):
|
||||
query: str
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{"query": "What's the maximum weight for a parcel?"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@r.post("/query")
|
||||
async def query_request(
|
||||
data: RequestData,
|
||||
):
|
||||
# Create a query engine using that returns responses in the format of the Output class
|
||||
query_engine = get_query_engine(Output)
|
||||
|
||||
response = await query_engine.aquery(data.query)
|
||||
|
||||
output_data = response.response.dict()
|
||||
return Output(**output_data)
|
||||
|
||||
|
||||
def get_query_engine(output_cls: BaseModel):
|
||||
top_k = os.getenv("TOP_K", 3)
|
||||
|
||||
index = get_index()
|
||||
if index is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(
|
||||
"StorageContext is empty - call 'poetry run generate' to generate the storage first"
|
||||
),
|
||||
)
|
||||
|
||||
sllm = Settings.llm.as_structured_llm(output_cls)
|
||||
|
||||
return index.as_query_engine(
|
||||
similarity_top_k=int(top_k),
|
||||
llm=sllm,
|
||||
response_mode="tree_summarize",
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
import logging
|
||||
from llama_index.core.schema import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class Output(BaseModel):
|
||||
response: str = Field(..., description="The answer to the question.")
|
||||
page_numbers: List[int] = Field(
|
||||
...,
|
||||
description="The page numbers of the sources used to answer this question. Do not include a page number if the context is irrelevant.",
|
||||
)
|
||||
confidence: float = Field(
|
||||
...,
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Confidence value between 0-1 of the correctness of the result.",
|
||||
)
|
||||
confidence_explanation: str = Field(
|
||||
..., description="Explanation for the confidence score"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"response": "This is an example answer.",
|
||||
"page_numbers": [1, 2, 3],
|
||||
"confidence": 0.85,
|
||||
"confidence_explanation": "This is an explanation for the confidence score.",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from app.api.routers.extractor import extractor_router
|
||||
from app.settings import init_settings
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
init_settings()
|
||||
|
||||
environment = os.getenv("ENVIRONMENT", "dev") # Default to 'development' if not set
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
if environment == "dev":
|
||||
logger.warning("Running in development mode - allowing CORS for all origins")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Redirect to documentation page when accessing base URL
|
||||
@app.get("/")
|
||||
async def redirect_to_docs():
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
app.include_router(extractor_router, prefix="/api/extractor")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app_host = os.getenv("APP_HOST", "0.0.0.0")
|
||||
app_port = int(os.getenv("APP_PORT", "8000"))
|
||||
reload = True if environment == "dev" else False
|
||||
|
||||
uvicorn.run(app="main:app", host=app_host, port=app_port, reload=reload)
|
||||
@@ -0,0 +1,21 @@
|
||||
[tool.poetry]
|
||||
name = "app"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Marcus Schiesser <mail@marcusschiesser.de>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
generate = "app.engine.generate:generate_datasource"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11,<3.12"
|
||||
fastapi = "^0.109.1"
|
||||
uvicorn = { extras = ["standard"], version = "^0.23.2" }
|
||||
python-dotenv = "^1.0.0"
|
||||
llama-index = "^0.10.58"
|
||||
cachetools = "^5.3.3"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,50 @@
|
||||
This is a [LlamaIndex](https://www.llamaindex.ai/) project using [FastAPI](https://fastapi.tiangolo.com/) bootstrapped with [`create-llama`](https://github.com/run-llama/LlamaIndexTS/tree/main/packages/create-llama).
|
||||
|
||||
## Getting Started
|
||||
|
||||
First, setup the environment with poetry:
|
||||
|
||||
> **_Note:_** This step is not needed if you are using the dev-container.
|
||||
|
||||
```shell
|
||||
poetry install
|
||||
poetry shell
|
||||
```
|
||||
|
||||
Then check the parameters that have been pre-configured in the `.env` file in this directory. (E.g. you might need to configure an `OPENAI_API_KEY` if you're using OpenAI as model provider).
|
||||
|
||||
Second, generate the embeddings of the documents in the `./data` directory (if this folder exists - otherwise, skip this step):
|
||||
|
||||
```shell
|
||||
poetry run generate
|
||||
```
|
||||
|
||||
Third, run all the services in one command:
|
||||
|
||||
```shell
|
||||
poetry run python main.py
|
||||
```
|
||||
|
||||
You can monitor and test the agent services with `llama-agents` monitor TUI:
|
||||
|
||||
```shell
|
||||
poetry run llama-agents monitor --control-plane-url http://127.0.0.1:8001
|
||||
```
|
||||
|
||||
## Services:
|
||||
|
||||
- Message queue (port 8000): To exchange the message between services
|
||||
- Control plane (port 8001): A gateway to manage the tasks and services.
|
||||
- Human consumer (port 8002): To handle result when the task is completed.
|
||||
- Agent service `query_engine` (port 8003): Agent that can query information from the configured LlamaIndex index.
|
||||
- Agent service `dummy_agent` (port 8004): A dummy agent that does nothing. Good starting point to add more agents.
|
||||
|
||||
The ports listed above are set by default, but you can change them in the `.env` file.
|
||||
|
||||
## Learn More
|
||||
|
||||
To learn more about LlamaIndex, take a look at the following resources:
|
||||
|
||||
- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex.
|
||||
|
||||
You can check out [the LlamaIndex GitHub repository](https://github.com/run-llama/llama_index) - your feedback and contributions are welcome!
|
||||
@@ -0,0 +1,33 @@
|
||||
from llama_agents import AgentService, SimpleMessageQueue
|
||||
from llama_index.core.agent import FunctionCallingAgentWorker
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from llama_index.core.settings import Settings
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
DEFAULT_DUMMY_AGENT_DESCRIPTION = "I'm a dummy agent which does nothing."
|
||||
|
||||
|
||||
def dummy_function():
|
||||
"""
|
||||
This function does nothing.
|
||||
"""
|
||||
return ""
|
||||
|
||||
|
||||
def init_dummy_agent(message_queue: SimpleMessageQueue) -> AgentService:
|
||||
agent = FunctionCallingAgentWorker(
|
||||
tools=[FunctionTool.from_defaults(fn=dummy_function)],
|
||||
llm=Settings.llm,
|
||||
prefix_messages=[],
|
||||
).as_agent()
|
||||
|
||||
return AgentService(
|
||||
service_name="dummy_agent",
|
||||
agent=agent,
|
||||
message_queue=message_queue.client,
|
||||
description=load_from_env("AGENT_DUMMY_DESCRIPTION", throw_error=False)
|
||||
or DEFAULT_DUMMY_AGENT_DESCRIPTION,
|
||||
host=load_from_env("AGENT_DUMMY_HOST", throw_error=False) or "127.0.0.1",
|
||||
port=int(load_from_env("AGENT_DUMMY_PORT")),
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from llama_agents import AgentService, SimpleMessageQueue
|
||||
from llama_index.core.agent import FunctionCallingAgentWorker
|
||||
from llama_index.core.tools import QueryEngineTool, ToolMetadata
|
||||
from llama_index.core.settings import Settings
|
||||
from app.engine.index import get_index
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION = (
|
||||
"Used to answer the questions using the provided context data."
|
||||
)
|
||||
|
||||
|
||||
def get_query_engine_tool() -> QueryEngineTool:
|
||||
"""
|
||||
Provide an agent worker that can be used to query the index.
|
||||
"""
|
||||
index = get_index()
|
||||
if index is None:
|
||||
raise ValueError("Index not found. Please create an index first.")
|
||||
query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3)))
|
||||
return QueryEngineTool(
|
||||
query_engine=query_engine,
|
||||
metadata=ToolMetadata(
|
||||
name="context_data",
|
||||
description="""
|
||||
Provide the provided context information.
|
||||
Use a detailed plain text question as input to the tool.
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def init_query_engine_agent(
|
||||
message_queue: SimpleMessageQueue,
|
||||
) -> AgentService:
|
||||
"""
|
||||
Initialize the agent service.
|
||||
"""
|
||||
agent = FunctionCallingAgentWorker(
|
||||
tools=[get_query_engine_tool()], llm=Settings.llm, prefix_messages=[]
|
||||
).as_agent()
|
||||
return AgentService(
|
||||
service_name="context_query_agent",
|
||||
agent=agent,
|
||||
message_queue=message_queue.client,
|
||||
description=load_from_env("AGENT_QUERY_ENGINE_DESCRIPTION", throw_error=False)
|
||||
or DEFAULT_QUERY_ENGINE_AGENT_DESCRIPTION,
|
||||
host=load_from_env("AGENT_QUERY_ENGINE_HOST", throw_error=False) or "127.0.0.1",
|
||||
port=int(load_from_env("AGENT_QUERY_ENGINE_PORT")),
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_agents import AgentOrchestrator, ControlPlaneServer
|
||||
from app.core.message_queue import message_queue
|
||||
from app.utils import load_from_env
|
||||
|
||||
|
||||
control_plane_host = (
|
||||
load_from_env("CONTROL_PLANE_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
control_plane_port = load_from_env("CONTROL_PLANE_PORT", throw_error=False) or "8001"
|
||||
|
||||
|
||||
# setup control plane
|
||||
control_plane = ControlPlaneServer(
|
||||
message_queue=message_queue,
|
||||
orchestrator=AgentOrchestrator(llm=OpenAI()),
|
||||
host=control_plane_host,
|
||||
port=int(control_plane_port) if control_plane_port else None,
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from llama_agents import SimpleMessageQueue
|
||||
from app.utils import load_from_env
|
||||
|
||||
message_queue_host = (
|
||||
load_from_env("MESSAGE_QUEUE_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
message_queue_port = load_from_env("MESSAGE_QUEUE_PORT", throw_error=False) or "8000"
|
||||
|
||||
message_queue = SimpleMessageQueue(
|
||||
host=message_queue_host,
|
||||
port=int(message_queue_port) if message_queue_port else None,
|
||||
)
|
||||
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from typing import Dict, Optional
|
||||
from llama_agents import CallableMessageConsumer, QueueMessage
|
||||
from llama_agents.message_queues.base import BaseMessageQueue
|
||||
from llama_agents.message_consumers.base import BaseMessageQueueConsumer
|
||||
from llama_agents.message_consumers.remote import RemoteMessageConsumer
|
||||
from app.utils import load_from_env
|
||||
from app.core.message_queue import message_queue
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class TaskResultService:
|
||||
def __init__(
|
||||
self,
|
||||
message_queue: BaseMessageQueue,
|
||||
name: str = "human",
|
||||
host: str = "127.0.0.1",
|
||||
port: Optional[int] = 8002,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self._message_queue = message_queue
|
||||
|
||||
# app
|
||||
self._app = FastAPI()
|
||||
self._app.add_api_route(
|
||||
"/", self.home, methods=["GET"], tags=["Human Consumer"]
|
||||
)
|
||||
self._app.add_api_route(
|
||||
"/process_message",
|
||||
self.process_message,
|
||||
methods=["POST"],
|
||||
tags=["Human Consumer"],
|
||||
)
|
||||
|
||||
@property
|
||||
def message_queue(self) -> BaseMessageQueue:
|
||||
return self._message_queue
|
||||
|
||||
def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
|
||||
if remote:
|
||||
return RemoteMessageConsumer(
|
||||
url=(
|
||||
f"http://{self.host}:{self.port}/process_message"
|
||||
if self.port
|
||||
else f"http://{self.host}/process_message"
|
||||
),
|
||||
message_type=self.name,
|
||||
)
|
||||
|
||||
return CallableMessageConsumer(
|
||||
message_type=self.name,
|
||||
handler=self.process_message,
|
||||
)
|
||||
|
||||
async def process_message(self, message: QueueMessage) -> None:
|
||||
Path("task_results").mkdir(exist_ok=True)
|
||||
with open("task_results/task_results.json", "+a") as f:
|
||||
json.dump(message.model_dump(), f)
|
||||
f.write("\n")
|
||||
|
||||
async def home(self) -> Dict[str, str]:
|
||||
return {"message": "hello, human."}
|
||||
|
||||
async def register_to_message_queue(self) -> None:
|
||||
"""Register to the message queue."""
|
||||
await self.message_queue.register_consumer(self.as_consumer(remote=True))
|
||||
|
||||
|
||||
human_consumer_host = (
|
||||
load_from_env("HUMAN_CONSUMER_HOST", throw_error=False) or "127.0.0.1"
|
||||
)
|
||||
human_consumer_port = load_from_env("HUMAN_CONSUMER_PORT", throw_error=False) or "8002"
|
||||
|
||||
|
||||
human_consumer_server = TaskResultService(
|
||||
message_queue=message_queue,
|
||||
host=human_consumer_host,
|
||||
port=int(human_consumer_port) if human_consumer_port else None,
|
||||
name="human",
|
||||
)
|
||||
@@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
|
||||
def load_from_env(var: str, throw_error: bool = True) -> str:
|
||||
res = os.getenv(var)
|
||||
if res is None and throw_error:
|
||||
raise ValueError(f"Missing environment variable: {var}")
|
||||
return res
|
||||
@@ -0,0 +1,27 @@
|
||||
from dotenv import load_dotenv
|
||||
from app.settings import init_settings
|
||||
|
||||
load_dotenv()
|
||||
init_settings()
|
||||
|
||||
from llama_agents import ServerLauncher
|
||||
from app.core.message_queue import message_queue
|
||||
from app.core.control_plane import control_plane
|
||||
from app.core.task_result import human_consumer_server
|
||||
from app.agents.query_engine.agent import init_query_engine_agent
|
||||
from app.agents.dummy.agent import init_dummy_agent
|
||||
|
||||
agents = [
|
||||
init_query_engine_agent(message_queue),
|
||||
init_dummy_agent(message_queue),
|
||||
]
|
||||
|
||||
launcher = ServerLauncher(
|
||||
agents,
|
||||
control_plane,
|
||||
message_queue,
|
||||
additional_consumers=[human_consumer_server.as_consumer()],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
launcher.launch_servers()
|
||||
@@ -0,0 +1,20 @@
|
||||
[tool.poetry]
|
||||
name = "app"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Marcus Schiesser <mail@marcusschiesser.de>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
generate = "app.engine.generate:generate_datasource"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
llama-agents = "^0.0.3"
|
||||
llama-index-agent-openai = "^0.2.7"
|
||||
llama-index-embeddings-openai = "^0.1.10"
|
||||
llama-index-llms-openai = "^0.1.23"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -3,5 +3,8 @@
|
||||
"rules": {
|
||||
"max-params": ["error", 4],
|
||||
"prefer-const": "error"
|
||||
},
|
||||
"parserOptions": {
|
||||
"sourceType": "module"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
.env
|
||||
node_modules/
|
||||
|
||||
tool-output/
|
||||
output/
|
||||
|
||||
@@ -14,7 +14,7 @@ const prodCorsOrigin = process.env["PROD_CORS_ORIGIN"];
|
||||
|
||||
initObservability();
|
||||
|
||||
app.use(express.json());
|
||||
app.use(express.json({ limit: "50mb" }));
|
||||
|
||||
if (isDevelopment) {
|
||||
console.warn("Running in development mode - allowing CORS for all origins");
|
||||
@@ -32,7 +32,7 @@ if (isDevelopment) {
|
||||
}
|
||||
|
||||
app.use("/api/files/data", express.static("data"));
|
||||
app.use("/api/files/tool-output", express.static("tool-output"));
|
||||
app.use("/api/files/output", express.static("output"));
|
||||
app.use(express.text());
|
||||
|
||||
app.get("/", (req: Request, res: Response) => {
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
{
|
||||
"name": "llama-index-express-streaming",
|
||||
"version": "1.0.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": "./index.js",
|
||||
"types": "./index.d.ts",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"scripts": {
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write .",
|
||||
"build": "tsup index.ts --format cjs --dts",
|
||||
"build": "tsup index.ts --format esm --dts",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "concurrently \"tsup index.ts --format cjs --dts --watch\" \"nodemon -q dist/index.js\""
|
||||
"dev": "concurrently \"tsup index.ts --format esm --dts --watch\" \"nodemon --watch dist/index.js\""
|
||||
},
|
||||
"dependencies": {
|
||||
"ai": "^3.0.21",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.3.1",
|
||||
"duck-duck-scrape": "^2.2.5",
|
||||
"express": "^4.18.2",
|
||||
"llamaindex": "0.3.13",
|
||||
"llamaindex": "0.5.8",
|
||||
"pdf2json": "3.0.5",
|
||||
"ajv": "^8.12.0",
|
||||
"@e2b/code-interpreter": "^0.0.5"
|
||||
"@e2b/code-interpreter": "^0.0.5",
|
||||
"got": "^14.4.1",
|
||||
"@apidevtools/swagger-parser": "^10.1.0",
|
||||
"formdata-node": "^6.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/cors": "^2.8.16",
|
||||
@@ -30,7 +39,7 @@
|
||||
"prettier": "^3.2.5",
|
||||
"prettier-plugin-organize-imports": "^3.2.4",
|
||||
"tsx": "^4.7.2",
|
||||
"tsup": "^8.0.1",
|
||||
"tsup": "8.1.0",
|
||||
"typescript": "^5.3.2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { Request, Response } from "express";
|
||||
|
||||
export const chatConfig = async (_req: Request, res: Response) => {
|
||||
let starterQuestions = undefined;
|
||||
if (
|
||||
process.env.CONVERSATION_STARTERS &&
|
||||
process.env.CONVERSATION_STARTERS.trim()
|
||||
) {
|
||||
starterQuestions = process.env.CONVERSATION_STARTERS.trim().split("\n");
|
||||
}
|
||||
return res.status(200).json({
|
||||
starterQuestions,
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,12 @@
|
||||
import { Request, Response } from "express";
|
||||
import { uploadDocument } from "./llamaindex/documents/upload";
|
||||
|
||||
export const chatUpload = async (req: Request, res: Response) => {
|
||||
const { base64 }: { base64: string } = req.body;
|
||||
if (!base64) {
|
||||
return res.status(400).json({
|
||||
error: "base64 is required in the request body",
|
||||
});
|
||||
}
|
||||
return res.status(200).json(await uploadDocument(base64));
|
||||
};
|
||||
@@ -1,32 +1,23 @@
|
||||
import { Message, StreamData, streamToResponse } from "ai";
|
||||
import { JSONValue, Message, StreamData, streamToResponse } from "ai";
|
||||
import { Request, Response } from "express";
|
||||
import { ChatMessage, MessageContent, Settings } from "llamaindex";
|
||||
import { ChatMessage, Settings } from "llamaindex";
|
||||
import { createChatEngine } from "./engine/chat";
|
||||
import { LlamaIndexStream } from "./llamaindex-stream";
|
||||
import { createCallbackManager } from "./stream-helper";
|
||||
|
||||
const convertMessageContent = (
|
||||
textMessage: string,
|
||||
imageUrl: string | undefined,
|
||||
): MessageContent => {
|
||||
if (!imageUrl) return textMessage;
|
||||
return [
|
||||
{
|
||||
type: "text",
|
||||
text: textMessage,
|
||||
},
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: imageUrl,
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
import {
|
||||
convertMessageContent,
|
||||
retrieveDocumentIds,
|
||||
} from "./llamaindex/streaming/annotations";
|
||||
import {
|
||||
createCallbackManager,
|
||||
createStreamTimeout,
|
||||
} from "./llamaindex/streaming/events";
|
||||
import { LlamaIndexStream } from "./llamaindex/streaming/stream";
|
||||
|
||||
export const chat = async (req: Request, res: Response) => {
|
||||
// Init Vercel AI StreamData and timeout
|
||||
const vercelStreamData = new StreamData();
|
||||
const streamTimeout = createStreamTimeout(vercelStreamData);
|
||||
try {
|
||||
const { messages, data }: { messages: Message[]; data: any } = req.body;
|
||||
const { messages }: { messages: Message[] } = req.body;
|
||||
const userMessage = messages.pop();
|
||||
if (!messages || !userMessage || userMessage.role !== "user") {
|
||||
return res.status(400).json({
|
||||
@@ -35,17 +26,34 @@ export const chat = async (req: Request, res: Response) => {
|
||||
});
|
||||
}
|
||||
|
||||
const chatEngine = await createChatEngine();
|
||||
let annotations = userMessage.annotations;
|
||||
if (!annotations) {
|
||||
// the user didn't send any new annotations with the last message
|
||||
// so use the annotations from the last user message that has annotations
|
||||
// REASON: GPT4 doesn't consider MessageContentDetail from previous messages, only strings
|
||||
annotations = messages
|
||||
.slice()
|
||||
.reverse()
|
||||
.find(
|
||||
(message) => message.role === "user" && message.annotations,
|
||||
)?.annotations;
|
||||
}
|
||||
|
||||
// retrieve document Ids from the annotations of all messages (if any) and create chat engine with index
|
||||
const allAnnotations: JSONValue[] = [...messages, userMessage].flatMap(
|
||||
(message) => {
|
||||
return message.annotations ?? [];
|
||||
},
|
||||
);
|
||||
const ids = retrieveDocumentIds(allAnnotations);
|
||||
const chatEngine = await createChatEngine(ids);
|
||||
|
||||
// Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
|
||||
const userMessageContent = convertMessageContent(
|
||||
userMessage.content,
|
||||
data?.imageUrl,
|
||||
annotations,
|
||||
);
|
||||
|
||||
// Init Vercel AI StreamData
|
||||
const vercelStreamData = new StreamData();
|
||||
|
||||
// Setup callbacks
|
||||
const callbackManager = createCallbackManager(vercelStreamData);
|
||||
|
||||
@@ -59,11 +67,11 @@ export const chat = async (req: Request, res: Response) => {
|
||||
});
|
||||
|
||||
// Return a stream, which can be consumed by the Vercel/AI client
|
||||
const stream = LlamaIndexStream(response, vercelStreamData, {
|
||||
parserOptions: {
|
||||
image_url: data?.imageUrl,
|
||||
},
|
||||
});
|
||||
const stream = LlamaIndexStream(
|
||||
response,
|
||||
vercelStreamData,
|
||||
messages as ChatMessage[],
|
||||
);
|
||||
|
||||
return streamToResponse(stream, res, {}, vercelStreamData);
|
||||
} catch (error) {
|
||||
@@ -71,5 +79,7 @@ export const chat = async (req: Request, res: Response) => {
|
||||
return res.status(500).json({
|
||||
detail: (error as Error).message,
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(streamTimeout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Settings, SimpleChatEngine } from "llamaindex";
|
||||
|
||||
export async function createChatEngine() {
|
||||
export async function createChatEngine(documentIds?: string[]) {
|
||||
return new SimpleChatEngine({
|
||||
llm: Settings.llm,
|
||||
});
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import {
|
||||
ALL_AVAILABLE_MISTRAL_MODELS,
|
||||
Anthropic,
|
||||
GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_MODEL,
|
||||
Gemini,
|
||||
GeminiEmbedding,
|
||||
Groq,
|
||||
MistralAI,
|
||||
MistralAIEmbedding,
|
||||
MistralAIEmbeddingModelType,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
@@ -28,12 +33,21 @@ export const initSettings = async () => {
|
||||
case "ollama":
|
||||
initOllama();
|
||||
break;
|
||||
case "groq":
|
||||
initGroq();
|
||||
break;
|
||||
case "anthropic":
|
||||
initAnthropic();
|
||||
break;
|
||||
case "gemini":
|
||||
initGemini();
|
||||
break;
|
||||
case "mistral":
|
||||
initMistralAI();
|
||||
break;
|
||||
case "azure-openai":
|
||||
initAzureOpenAI();
|
||||
break;
|
||||
default:
|
||||
initOpenAI();
|
||||
break;
|
||||
@@ -44,8 +58,10 @@ export const initSettings = async () => {
|
||||
|
||||
function initOpenAI() {
|
||||
Settings.llm = new OpenAI({
|
||||
model: process.env.MODEL ?? "gpt-3.5-turbo",
|
||||
maxTokens: 512,
|
||||
model: process.env.MODEL ?? "gpt-4o-mini",
|
||||
maxTokens: process.env.LLM_MAX_TOKENS
|
||||
? Number(process.env.LLM_MAX_TOKENS)
|
||||
: undefined,
|
||||
});
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
@@ -55,11 +71,57 @@ function initOpenAI() {
|
||||
});
|
||||
}
|
||||
|
||||
function initAzureOpenAI() {
|
||||
// Map Azure OpenAI model names to OpenAI model names (only for TS)
|
||||
const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
|
||||
"gpt-35-turbo": "gpt-3.5-turbo",
|
||||
"gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4": "gpt-4",
|
||||
"gpt-4-32k": "gpt-4-32k",
|
||||
"gpt-4-turbo": "gpt-4-turbo",
|
||||
"gpt-4-turbo-2024-04-09": "gpt-4-turbo",
|
||||
"gpt-4-vision-preview": "gpt-4-vision-preview",
|
||||
"gpt-4-1106-preview": "gpt-4-1106-preview",
|
||||
"gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
|
||||
};
|
||||
|
||||
const azureConfig = {
|
||||
apiKey: process.env.AZURE_OPENAI_KEY,
|
||||
endpoint: process.env.AZURE_OPENAI_ENDPOINT,
|
||||
apiVersion:
|
||||
process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
|
||||
};
|
||||
|
||||
Settings.llm = new OpenAI({
|
||||
model:
|
||||
AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
|
||||
"gpt-3.5-turbo",
|
||||
maxTokens: process.env.LLM_MAX_TOKENS
|
||||
? Number(process.env.LLM_MAX_TOKENS)
|
||||
: undefined,
|
||||
azure: {
|
||||
...azureConfig,
|
||||
deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
|
||||
},
|
||||
});
|
||||
|
||||
Settings.embedModel = new OpenAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL,
|
||||
dimensions: process.env.EMBEDDING_DIM
|
||||
? parseInt(process.env.EMBEDDING_DIM)
|
||||
: undefined,
|
||||
azure: {
|
||||
...azureConfig,
|
||||
deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function initOllama() {
|
||||
const config = {
|
||||
host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
|
||||
};
|
||||
|
||||
Settings.llm = new Ollama({
|
||||
model: process.env.MODEL ?? "",
|
||||
config,
|
||||
@@ -70,6 +132,27 @@ function initOllama() {
|
||||
});
|
||||
}
|
||||
|
||||
function initGroq() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
|
||||
};
|
||||
|
||||
const modelMap: Record<string, string> = {
|
||||
"llama3-8b": "llama3-8b-8192",
|
||||
"llama3-70b": "llama3-70b-8192",
|
||||
"mixtral-8x7b": "mixtral-8x7b-32768",
|
||||
};
|
||||
|
||||
Settings.llm = new Groq({
|
||||
model: modelMap[process.env.MODEL!],
|
||||
});
|
||||
|
||||
Settings.embedModel = new HuggingFaceEmbedding({
|
||||
modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
|
||||
});
|
||||
}
|
||||
|
||||
function initAnthropic() {
|
||||
const embedModelMap: Record<string, string> = {
|
||||
"all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
|
||||
@@ -91,3 +174,12 @@ function initGemini() {
|
||||
model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
|
||||
});
|
||||
}
|
||||
|
||||
function initMistralAI() {
|
||||
Settings.llm = new MistralAI({
|
||||
model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
|
||||
});
|
||||
Settings.embedModel = new MistralAIEmbedding({
|
||||
model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import {
|
||||
StreamData,
|
||||
createCallbacksTransformer,
|
||||
createStreamDataTransformer,
|
||||
trimStartOfStreamHelper,
|
||||
type AIStreamCallbacksAndOptions,
|
||||
} from "ai";
|
||||
import {
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
Response,
|
||||
ToolCallLLMMessageOptions,
|
||||
} from "llamaindex";
|
||||
|
||||
import { AgentStreamChatResponse } from "llamaindex/agent/base";
|
||||
import { appendImageData, appendSourceData } from "./stream-helper";
|
||||
|
||||
type LlamaIndexResponse =
|
||||
| AgentStreamChatResponse<ToolCallLLMMessageOptions>
|
||||
| Response;
|
||||
|
||||
type ParserOptions = {
|
||||
image_url?: string;
|
||||
};
|
||||
|
||||
function createParser(
|
||||
res: AsyncIterable<LlamaIndexResponse>,
|
||||
data: StreamData,
|
||||
opts?: ParserOptions,
|
||||
) {
|
||||
const it = res[Symbol.asyncIterator]();
|
||||
const trimStartOfStream = trimStartOfStreamHelper();
|
||||
|
||||
let sourceNodes: NodeWithScore<Metadata>[] | undefined;
|
||||
return new ReadableStream<string>({
|
||||
start() {
|
||||
appendImageData(data, opts?.image_url);
|
||||
},
|
||||
async pull(controller): Promise<void> {
|
||||
const { value, done } = await it.next();
|
||||
if (done) {
|
||||
if (sourceNodes) {
|
||||
appendSourceData(data, sourceNodes);
|
||||
}
|
||||
controller.close();
|
||||
data.close();
|
||||
return;
|
||||
}
|
||||
|
||||
let delta;
|
||||
if (value instanceof Response) {
|
||||
// handle Response type
|
||||
if (value.sourceNodes) {
|
||||
// get source nodes from the first response
|
||||
sourceNodes = value.sourceNodes;
|
||||
}
|
||||
delta = value.response ?? "";
|
||||
} else {
|
||||
// handle other types
|
||||
delta = value.response.delta;
|
||||
}
|
||||
const text = trimStartOfStream(delta ?? "");
|
||||
if (text) {
|
||||
controller.enqueue(text);
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function LlamaIndexStream(
|
||||
response: AsyncIterable<LlamaIndexResponse>,
|
||||
data: StreamData,
|
||||
opts?: {
|
||||
callbacks?: AIStreamCallbacksAndOptions;
|
||||
parserOptions?: ParserOptions;
|
||||
},
|
||||
): ReadableStream<Uint8Array> {
|
||||
return createParser(response, data, opts?.parserOptions)
|
||||
.pipeThrough(createCallbacksTransformer(opts?.callbacks))
|
||||
.pipeThrough(createStreamDataTransformer());
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
import { StreamData } from "ai";
|
||||
import {
|
||||
CallbackManager,
|
||||
Metadata,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
|
||||
export function appendImageData(data: StreamData, imageUrl?: string) {
|
||||
if (!imageUrl) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "image",
|
||||
data: {
|
||||
url: imageUrl,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendSourceData(
|
||||
data: StreamData,
|
||||
sourceNodes?: NodeWithScore<Metadata>[],
|
||||
) {
|
||||
if (!sourceNodes?.length) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "sources",
|
||||
data: {
|
||||
nodes: sourceNodes.map((node) => ({
|
||||
...node.node.toMutableJSON(),
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
})),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendEventData(data: StreamData, title?: string) {
|
||||
if (!title) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "events",
|
||||
data: {
|
||||
title,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve", (data) => {
|
||||
const { nodes, query } = data.detail;
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.payload.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail.payload;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
import express, { Router } from "express";
|
||||
import { chatConfig } from "../controllers/chat-config.controller";
|
||||
import { chatRequest } from "../controllers/chat-request.controller";
|
||||
import { chatUpload } from "../controllers/chat-upload.controller";
|
||||
import { chat } from "../controllers/chat.controller";
|
||||
import { initSettings } from "../controllers/engine/settings";
|
||||
|
||||
@@ -8,5 +10,7 @@ const llmRouter: Router = express.Router();
|
||||
initSettings();
|
||||
llmRouter.route("/").post(chat);
|
||||
llmRouter.route("/request").post(chatRequest);
|
||||
llmRouter.route("/config").get(chatConfig);
|
||||
llmRouter.route("/upload").post(chatUpload);
|
||||
|
||||
export default llmRouter;
|
||||
|
||||
@@ -1,154 +1,127 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Optional, Dict, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from app.engine import get_chat_engine
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
|
||||
from llama_index.core.chat_engine.types import BaseChatEngine, NodeWithScore
|
||||
from llama_index.core.llms import MessageRole
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import (
|
||||
ChatConfig,
|
||||
ChatData,
|
||||
Message,
|
||||
Result,
|
||||
SourceNodes,
|
||||
)
|
||||
from app.api.routers.vercel_response import VercelStreamResponse
|
||||
from app.api.routers.messaging import EventCallbackHandler
|
||||
from aiostream import stream
|
||||
from app.api.services.llama_cloud import LLamaCloudFileService
|
||||
from app.engine import get_chat_engine
|
||||
|
||||
chat_router = r = APIRouter()
|
||||
|
||||
|
||||
class _Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class _ChatData(BaseModel):
|
||||
messages: List[_Message]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class _SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=source_node.node.metadata,
|
||||
score=source_node.score,
|
||||
text=source_node.node.text, # type: ignore
|
||||
def process_response_nodes(
|
||||
nodes: List[NodeWithScore],
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
"""
|
||||
Start background tasks on the source nodes if needed.
|
||||
"""
|
||||
files_to_download = SourceNodes.get_download_files(nodes)
|
||||
for file in files_to_download:
|
||||
background_tasks.add_task(
|
||||
LLamaCloudFileService.download_llamacloud_pipeline_file, file
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
|
||||
class _Result(BaseModel):
|
||||
result: _Message
|
||||
nodes: List[_SourceNodes]
|
||||
|
||||
|
||||
async def parse_chat_data(data: _ChatData) -> Tuple[str, List[ChatMessage]]:
|
||||
# check preconditions and get last message
|
||||
if len(data.messages) == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No messages provided",
|
||||
)
|
||||
last_message = data.messages.pop()
|
||||
if last_message.role != MessageRole.USER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Last message must be from user",
|
||||
)
|
||||
# convert messages coming from the request to type ChatMessage
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=m.role,
|
||||
content=m.content,
|
||||
)
|
||||
for m in data.messages
|
||||
]
|
||||
return last_message.content, messages
|
||||
|
||||
|
||||
# streaming endpoint - delete if not needed
|
||||
@r.post("")
|
||||
async def chat(
|
||||
request: Request,
|
||||
data: _ChatData,
|
||||
data: ChatData,
|
||||
background_tasks: BackgroundTasks,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
):
|
||||
last_message_content, messages = await parse_chat_data(data)
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
try:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
doc_ids = data.get_chat_document_ids()
|
||||
filters = generate_filters(doc_ids)
|
||||
logger.info("Creating chat engine with filters", filters.dict())
|
||||
chat_engine = get_chat_engine(filters=filters)
|
||||
|
||||
event_handler = EventCallbackHandler()
|
||||
chat_engine.callback_manager.handlers.append(event_handler) # type: ignore
|
||||
|
||||
response = await chat_engine.astream_chat(last_message_content, messages)
|
||||
process_response_nodes(response.source_nodes, background_tasks)
|
||||
|
||||
return VercelStreamResponse(request, event_handler, response, data)
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat engine", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error in chat engine: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
|
||||
if len(doc_ids) > 0:
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
MetadataFilter(
|
||||
key="doc_id",
|
||||
value=doc_ids,
|
||||
operator="in", # type: ignore
|
||||
),
|
||||
],
|
||||
condition="or", # type: ignore
|
||||
)
|
||||
else:
|
||||
filters = MetadataFilters(
|
||||
# Use the "NIN" - "not in" operator to include all public documents (don't have the private key set)
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="private",
|
||||
value=["true"],
|
||||
operator="nin", # type: ignore
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def content_generator():
|
||||
# Yield the text response
|
||||
async def _text_generator():
|
||||
async for token in response.async_response_gen():
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_text_generator(), _event_generator())
|
||||
async with combine.stream() as streamer:
|
||||
async for item in streamer:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield item
|
||||
|
||||
# Yield the source nodes
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
_SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return VercelStreamResponse(content=content_generator())
|
||||
return filters
|
||||
|
||||
|
||||
# non-streaming endpoint - delete if not needed
|
||||
@r.post("/request")
|
||||
async def chat_request(
|
||||
data: _ChatData,
|
||||
data: ChatData,
|
||||
chat_engine: BaseChatEngine = Depends(get_chat_engine),
|
||||
) -> _Result:
|
||||
last_message_content, messages = await parse_chat_data(data)
|
||||
) -> Result:
|
||||
last_message_content = data.get_last_message_content()
|
||||
messages = data.get_history_messages()
|
||||
|
||||
response = await chat_engine.achat(last_message_content, messages)
|
||||
return _Result(
|
||||
result=_Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=_SourceNodes.from_source_nodes(response.source_nodes),
|
||||
return Result(
|
||||
result=Message(role=MessageRole.ASSISTANT, content=response.response),
|
||||
nodes=SourceNodes.from_source_nodes(response.source_nodes),
|
||||
)
|
||||
|
||||
|
||||
@r.get("/config")
|
||||
async def chat_config() -> ChatConfig:
|
||||
starter_questions = None
|
||||
conversation_starters = os.getenv("CONVERSATION_STARTERS")
|
||||
if conversation_starters and conversation_starters.strip():
|
||||
starter_questions = conversation_starters.strip().split("\n")
|
||||
return ChatConfig(starter_questions=starter_questions)
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
type: Literal["text", "ref"]
|
||||
# If the file is pure text then the value is be a string
|
||||
# otherwise, it's a list of document IDs
|
||||
value: str | List[str]
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
id: str
|
||||
content: FileContent
|
||||
filename: str
|
||||
filesize: int
|
||||
filetype: str
|
||||
|
||||
|
||||
class AnnotationFileData(BaseModel):
|
||||
files: List[File] = Field(
|
||||
default=[],
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"csvFiles": [
|
||||
{
|
||||
"content": "Name, Age\nAlice, 25\nBob, 30",
|
||||
"filename": "example.csv",
|
||||
"filesize": 123,
|
||||
"id": "123",
|
||||
"type": "text/csv",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
alias_generator = to_camel
|
||||
|
||||
|
||||
class Annotation(BaseModel):
|
||||
type: str
|
||||
data: AnnotationFileData | List[str]
|
||||
|
||||
def to_content(self) -> str | None:
|
||||
if self.type == "document_file":
|
||||
# We only support generating context content for CSV files for now
|
||||
csv_files = [file for file in self.data.files if file.filetype == "csv"]
|
||||
if len(csv_files) > 0:
|
||||
return "Use data from following CSV raw content\n" + "\n".join(
|
||||
[f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files]
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"The annotation {self.type} is not supported for generating context content"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: MessageRole
|
||||
content: str
|
||||
annotations: List[Annotation] | None = None
|
||||
|
||||
|
||||
class ChatData(BaseModel):
|
||||
messages: List[Message]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What standards for letters exist?",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@validator("messages")
|
||||
def messages_must_not_be_empty(cls, v):
|
||||
if len(v) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
return v
|
||||
|
||||
def get_last_message_content(self) -> str:
|
||||
"""
|
||||
Get the content of the last message along with the data content if available.
|
||||
Fallback to use data content from previous messages
|
||||
"""
|
||||
if len(self.messages) == 0:
|
||||
raise ValueError("There is not any message in the chat")
|
||||
last_message = self.messages[-1]
|
||||
message_content = last_message.content
|
||||
for message in reversed(self.messages):
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
annotation_contents = filter(
|
||||
None,
|
||||
[annotation.to_content() for annotation in message.annotations],
|
||||
)
|
||||
if not annotation_contents:
|
||||
continue
|
||||
annotation_text = "\n".join(annotation_contents)
|
||||
message_content = f"{message_content}\n{annotation_text}"
|
||||
break
|
||||
return message_content
|
||||
|
||||
def get_history_messages(self) -> List[ChatMessage]:
|
||||
"""
|
||||
Get the history messages
|
||||
"""
|
||||
return [
|
||||
ChatMessage(role=message.role, content=message.content)
|
||||
for message in self.messages[:-1]
|
||||
]
|
||||
|
||||
def is_last_message_from_user(self) -> bool:
|
||||
return self.messages[-1].role == MessageRole.USER
|
||||
|
||||
def get_chat_document_ids(self) -> List[str]:
|
||||
"""
|
||||
Get the document IDs from the chat messages
|
||||
"""
|
||||
document_ids: List[str] = []
|
||||
for message in self.messages:
|
||||
if message.role == MessageRole.USER and message.annotations is not None:
|
||||
for annotation in message.annotations:
|
||||
if (
|
||||
annotation.type == "document_file"
|
||||
and annotation.data.files is not None
|
||||
):
|
||||
for fi in annotation.data.files:
|
||||
if fi.content.type == "ref":
|
||||
document_ids += fi.content.value
|
||||
return list(set(document_ids))
|
||||
|
||||
|
||||
class LlamaCloudFile(BaseModel):
|
||||
file_name: str
|
||||
pipeline_id: str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LlamaCloudFile):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.file_name == other.file_name and self.pipeline_id == other.pipeline_id
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.file_name, self.pipeline_id))
|
||||
|
||||
|
||||
class SourceNodes(BaseModel):
|
||||
id: str
|
||||
metadata: Dict[str, Any]
|
||||
score: Optional[float]
|
||||
text: str
|
||||
url: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_source_node(cls, source_node: NodeWithScore):
|
||||
metadata = source_node.node.metadata
|
||||
url = cls.get_url_from_metadata(metadata)
|
||||
|
||||
return cls(
|
||||
id=source_node.node.node_id,
|
||||
metadata=metadata,
|
||||
score=source_node.score,
|
||||
text=source_node.node.text, # type: ignore
|
||||
url=url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str:
|
||||
url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not url_prefix:
|
||||
logger.warning(
|
||||
"Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server"
|
||||
)
|
||||
file_name = metadata.get("file_name")
|
||||
if file_name and url_prefix:
|
||||
# file_name exists and file server is configured
|
||||
pipeline_id = metadata.get("pipeline_id")
|
||||
if pipeline_id and metadata.get("private") is None:
|
||||
# file is from LlamaCloud and was not ingested locally
|
||||
file_name = f"{pipeline_id}${file_name}"
|
||||
return f"{url_prefix}/output/llamacloud/{file_name}"
|
||||
is_private = metadata.get("private", "false") == "true"
|
||||
if is_private:
|
||||
return f"{url_prefix}/output/uploaded/{file_name}"
|
||||
return f"{url_prefix}/data/{file_name}"
|
||||
else:
|
||||
# fallback to URL in metadata (e.g. for websites)
|
||||
return metadata.get("URL")
|
||||
|
||||
@classmethod
|
||||
def from_source_nodes(cls, source_nodes: List[NodeWithScore]):
|
||||
return [cls.from_source_node(node) for node in source_nodes]
|
||||
|
||||
@staticmethod
|
||||
def get_download_files(nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
|
||||
source_nodes = SourceNodes.from_source_nodes(nodes)
|
||||
llama_cloud_files = [
|
||||
LlamaCloudFile(
|
||||
file_name=node.metadata.get("file_name"),
|
||||
pipeline_id=node.metadata.get("pipeline_id"),
|
||||
)
|
||||
for node in source_nodes
|
||||
if (
|
||||
node.metadata.get("private")
|
||||
is None # Only download files are from LlamaCloud and were not ingested locally
|
||||
and node.metadata.get("pipeline_id") is not None
|
||||
and node.metadata.get("file_name") is not None
|
||||
)
|
||||
]
|
||||
# Remove duplicates and return
|
||||
return set(llama_cloud_files)
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Message
|
||||
nodes: List[SourceNodes]
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
starter_questions: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of starter questions",
|
||||
serialization_alias="starterQuestions"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"starterQuestions": [
|
||||
"What standards for letters exist?",
|
||||
"What are the requirements for a letter to be considered a letter?",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.services.file import PrivateFileService
|
||||
|
||||
file_upload_router = r = APIRouter()
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class FileUploadRequest(BaseModel):
|
||||
base64: str
|
||||
|
||||
|
||||
@r.post("")
|
||||
def upload_file(request: FileUploadRequest) -> List[str]:
|
||||
try:
|
||||
logger.info("Processing file")
|
||||
return PrivateFileService.process_file(request.base64)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Error processing file")
|
||||
@@ -1,6 +1,13 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from aiostream import stream
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
|
||||
from app.api.routers.events import EventCallbackHandler
|
||||
from app.api.routers.models import ChatData, Message, SourceNodes
|
||||
from app.api.services.suggestion import NextQuestionSuggestion
|
||||
|
||||
|
||||
class VercelStreamResponse(StreamingResponse):
|
||||
@@ -22,8 +29,81 @@ class VercelStreamResponse(StreamingResponse):
|
||||
data_str = json.dumps(data)
|
||||
return f"{cls.DATA_PREFIX}[{data_str}]\n"
|
||||
|
||||
def __init__(self, content: Any, **kwargs):
|
||||
super().__init__(
|
||||
content=content,
|
||||
**kwargs,
|
||||
def __init__(
|
||||
self,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
content = VercelStreamResponse.content_generator(
|
||||
request, event_handler, response, chat_data
|
||||
)
|
||||
super().__init__(content=content)
|
||||
|
||||
@classmethod
|
||||
async def content_generator(
|
||||
cls,
|
||||
request: Request,
|
||||
event_handler: EventCallbackHandler,
|
||||
response: StreamingAgentChatResponse,
|
||||
chat_data: ChatData,
|
||||
):
|
||||
# Yield the text response
|
||||
async def _chat_response_generator():
|
||||
final_response = ""
|
||||
async for token in response.async_response_gen():
|
||||
final_response += token
|
||||
yield VercelStreamResponse.convert_text(token)
|
||||
|
||||
# Generate questions that user might interested to
|
||||
conversation = chat_data.messages + [
|
||||
Message(role="assistant", content=final_response)
|
||||
]
|
||||
questions = await NextQuestionSuggestion.suggest_next_questions(
|
||||
conversation
|
||||
)
|
||||
if len(questions) > 0:
|
||||
yield VercelStreamResponse.convert_data(
|
||||
{
|
||||
"type": "suggested_questions",
|
||||
"data": questions,
|
||||
}
|
||||
)
|
||||
|
||||
# the text_generator is the leading stream, once it's finished, also finish the event stream
|
||||
event_handler.is_done = True
|
||||
|
||||
# Yield the source nodes
|
||||
yield cls.convert_data(
|
||||
{
|
||||
"type": "sources",
|
||||
"data": {
|
||||
"nodes": [
|
||||
SourceNodes.from_source_node(node).dict()
|
||||
for node in response.source_nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Yield the events from the event handler
|
||||
async def _event_generator():
|
||||
async for event in event_handler.async_event_gen():
|
||||
event_response = event.to_response()
|
||||
if event_response is not None:
|
||||
yield VercelStreamResponse.convert_data(event_response)
|
||||
|
||||
combine = stream.merge(_chat_response_generator(), _event_generator())
|
||||
is_stream_started = False
|
||||
async with combine.stream() as streamer:
|
||||
async for output in streamer:
|
||||
if not is_stream_started:
|
||||
is_stream_started = True
|
||||
# Stream a blank message to start the stream
|
||||
yield VercelStreamResponse.convert_text("")
|
||||
|
||||
yield output
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from app.engine.index import get_index
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.core.readers.file.base import (
|
||||
default_file_metadata_func,
|
||||
)
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||
from llama_index.readers.file import FlatReader
|
||||
|
||||
|
||||
def get_llamaparse_parser():
|
||||
from app.engine.loaders import load_configs
|
||||
from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
|
||||
|
||||
config = load_configs()
|
||||
file_loader_config = FileLoaderConfig(**config["file"])
|
||||
if file_loader_config.use_llama_parse:
|
||||
return llama_parse_parser()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def default_file_loaders_map():
|
||||
default_loaders = get_file_loaders_map()
|
||||
default_loaders[".txt"] = FlatReader
|
||||
return default_loaders
|
||||
|
||||
|
||||
class PrivateFileService:
|
||||
PRIVATE_STORE_PATH = "output/uploaded"
|
||||
|
||||
@staticmethod
|
||||
def preprocess_base64_file(base64_content: str) -> tuple:
|
||||
header, data = base64_content.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":", 1)[1]
|
||||
extension = mimetypes.guess_extension(mime_type)
|
||||
# File data as bytes
|
||||
return base64.b64decode(data), extension
|
||||
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_data, extension) -> List[Document]:
|
||||
# Store file to the private directory
|
||||
os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
|
||||
|
||||
# random file name
|
||||
file_name = f"{uuid4().hex}{extension}"
|
||||
file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
|
||||
|
||||
# write file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_data)
|
||||
|
||||
# Load file to documents
|
||||
# If LlamaParse is enabled, use it to parse the file
|
||||
# Otherwise, use the default file loaders
|
||||
reader = get_llamaparse_parser()
|
||||
if reader is None:
|
||||
reader_cls = default_file_loaders_map().get(extension)
|
||||
if reader_cls is None:
|
||||
raise ValueError(f"File extension {extension} is not supported")
|
||||
reader = reader_cls()
|
||||
documents = reader.load_data(file_path)
|
||||
# Add custom metadata
|
||||
for doc in documents:
|
||||
doc.metadata["file_name"] = file_name
|
||||
doc.metadata["private"] = "true"
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def process_file(base64_content: str) -> List[str]:
|
||||
file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
|
||||
documents = PrivateFileService.store_and_parse_file(file_data, extension)
|
||||
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
current_index = get_index()
|
||||
|
||||
# Insert the documents into the index
|
||||
if isinstance(current_index, LlamaCloudIndex):
|
||||
# LlamaCloudIndex is a managed index so we don't need to process the nodes
|
||||
# just insert the documents
|
||||
for doc in documents:
|
||||
current_index.insert(doc)
|
||||
else:
|
||||
# Only process nodes, no store the index
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
if current_index is None:
|
||||
current_index = VectorStoreIndex(nodes=nodes)
|
||||
else:
|
||||
current_index.insert_nodes(nodes=nodes)
|
||||
current_index.storage_context.persist(
|
||||
persist_dir=os.environ.get("STORAGE_DIR", "storage")
|
||||
)
|
||||
|
||||
# Return the document ids
|
||||
return [doc.doc_id for doc in documents]
|
||||
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from app.api.routers.models import LlamaCloudFile
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LLamaCloudFileService:
|
||||
LLAMA_CLOUD_URL = "https://cloud.llamaindex.ai/api/v1"
|
||||
LOCAL_STORE_PATH = "output/llamacloud"
|
||||
|
||||
DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
|
||||
|
||||
@classmethod
|
||||
def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _get_file_detail(cls, project_id: str, file_id: str) -> Dict[str, Any]:
|
||||
url = f"{cls.LLAMA_CLOUD_URL}/files/{file_id}/content?project_id={project_id}"
|
||||
return cls._make_request(url)
|
||||
|
||||
@classmethod
|
||||
def _download_file(cls, url: str, local_file_path: str):
|
||||
logger.info(f"Downloading file to {local_file_path}")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
|
||||
# Download the file
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_file_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
|
||||
@classmethod
|
||||
def download_llamacloud_pipeline_file(
|
||||
cls,
|
||||
file: LlamaCloudFile,
|
||||
force_download: bool = False,
|
||||
):
|
||||
file_name = file.file_name
|
||||
pipeline_id = file.pipeline_id
|
||||
|
||||
# Check is the file already exists
|
||||
downloaded_file_path = cls.get_file_path(file_name, pipeline_id)
|
||||
if os.path.exists(downloaded_file_path) and not force_download:
|
||||
logger.debug(f"File {file_name} already exists in local storage")
|
||||
return
|
||||
try:
|
||||
logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
|
||||
files = cls._get_files(pipeline_id)
|
||||
if not files or not isinstance(files, list):
|
||||
raise Exception("No files found in LlamaCloud")
|
||||
for file_entry in files:
|
||||
if file_entry["name"] == file_name:
|
||||
file_id = file_entry["file_id"]
|
||||
project_id = file_entry["project_id"]
|
||||
file_detail = cls._get_file_detail(project_id, file_id)
|
||||
cls._download_file(file_detail["url"], downloaded_file_path)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.info(f"Error fetching file from LlamaCloud: {error}")
|
||||
|
||||
@classmethod
|
||||
def get_file_name(cls, name: str, pipeline_id: str) -> str:
|
||||
return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)
|
||||
|
||||
@classmethod
|
||||
def get_file_path(cls, name: str, pipeline_id: str) -> str:
|
||||
return os.path.join(cls.LOCAL_STORE_PATH, cls.get_file_name(name, pipeline_id))
|
||||
|
||||
@staticmethod
|
||||
def _make_request(
|
||||
url: str, data=None, headers: Optional[Dict] = None, method: str = "get"
|
||||
):
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Authorization": f'Bearer {os.getenv("LLAMA_CLOUD_API_KEY")}',
|
||||
}
|
||||
response = requests.request(method, url, headers=headers, data=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user