mirror of
https://github.com/run-llama/create-llama.git
synced 2026-07-03 08:24:39 -04:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c3215ccc7b | |||
| 18ca18123f | |||
| 5ecb0c9fb7 | |||
| 7e45f604e6 | |||
| bbacf0f199 | |||
| c0c6df80c7 | |||
| 3b39a12ad6 |
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Use ingestion pipeline for Python
|
||||
@@ -0,0 +1,5 @@
|
||||
---
|
||||
"create-llama": patch
|
||||
---
|
||||
|
||||
Display events (e.g. retrieving nodes) per chat message
|
||||
@@ -1,6 +0,0 @@
|
||||
# coderabbit.yml
|
||||
reviews:
|
||||
path_instructions:
|
||||
- path: "templates/**"
|
||||
instructions: |
|
||||
For files under the `templates` folder, do not report 'Missing Dependencies Detected' errors.
|
||||
@@ -17,9 +17,7 @@ jobs:
|
||||
matrix:
|
||||
node-version: [18, 20]
|
||||
python-version: ["3.11"]
|
||||
os: [macos-latest, windows-latest, ubuntu-22.04]
|
||||
frameworks: ["nextjs", "express", "fastapi"]
|
||||
datasources: ["--no-files", "--example-file"]
|
||||
os: [macos-latest, windows-latest]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -28,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@@ -64,9 +62,6 @@ jobs:
|
||||
run: pnpm run e2e
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLAMA_CLOUD_API_KEY: ${{ secrets.LLAMA_CLOUD_API_KEY }}
|
||||
FRAMEWORK: ${{ matrix.frameworks }}
|
||||
DATASOURCE: ${{ matrix.datasources }}
|
||||
working-directory: .
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
|
||||
@@ -30,13 +30,3 @@ jobs:
|
||||
|
||||
- name: Run Prettier
|
||||
run: pnpm run format
|
||||
|
||||
- name: Run Python format check
|
||||
uses: chartboost/ruff-action@v1
|
||||
with:
|
||||
args: "format --check"
|
||||
|
||||
- name: Run Python lint
|
||||
uses: chartboost/ruff-action@v1
|
||||
with:
|
||||
args: "check"
|
||||
|
||||
@@ -46,8 +46,5 @@ e2e/cache
|
||||
# intellij
|
||||
**/.idea
|
||||
|
||||
# Python
|
||||
.mypy_cache/
|
||||
|
||||
# build artifacts
|
||||
create-llama-*.tgz
|
||||
|
||||
-366
@@ -1,371 +1,5 @@
|
||||
# create-llama
|
||||
|
||||
## 0.2.10
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- cb8d535: Fix only produces one agent event
|
||||
|
||||
## 0.2.9
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 0213fe0: Update dependencies for vector stores and add e2e test to ensure that they work as expected.
|
||||
|
||||
## 0.2.8
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 0031e67: Bump llama-index to 0.11.11 for the multi-agent template
|
||||
|
||||
## 0.2.7
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 505b8e9: bump: use latest ai package version
|
||||
- cf3ec97: Dynamically select model for Groq
|
||||
- 8c1087f: feat: enhance style for markdown
|
||||
|
||||
## 0.2.6
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- adc40cf: fix: vercel ai update crash sending annotations
|
||||
|
||||
## 0.2.5
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 38a8be8: fix: filter in mongo vector store
|
||||
|
||||
## 0.2.4
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 917e862: Fix errors in building the frontend
|
||||
|
||||
## 0.2.3
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b6da3c2: Ensure the generation script always works
|
||||
|
||||
## 0.2.2
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 8105c5c: Add env config for next questions feature
|
||||
|
||||
## 0.2.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 6a409cb: Bump web and database reader packages
|
||||
|
||||
## 0.2.0
|
||||
|
||||
### Minor Changes
|
||||
|
||||
- 435109f: Add multi-agents template based on workflows
|
||||
|
||||
## 0.1.44
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- bedde2b: Change metadata filters to use already existing documents in LlamaCloud Index
|
||||
- 5cd12fa: Use one callback manager per request
|
||||
- 5cd12fa: Bump llama_index version to 0.11.1
|
||||
- fd4abb3: Fix to use filename for uploaded documents in NextJS
|
||||
- 2f8feab: Simplify CLI interface
|
||||
|
||||
## 0.1.43
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 4fa2b76: feat: implement citation for TS
|
||||
|
||||
## 0.1.42
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 8f670a9: Allow relative URL in documents
|
||||
|
||||
## 0.1.41
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 57e7638: Use the retrieval defaults from LlamaCloud
|
||||
|
||||
## 0.1.40
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 8ce4a85: Add UI for extractor template
|
||||
|
||||
## 0.1.39
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 3fb93c7: Use LlamaCloud pipeline for data ingestion in TS (private file uploads and generate script)
|
||||
|
||||
## 0.1.38
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- bd5e39a: Fix error that files in sub folders of 'data' are not displayed
|
||||
|
||||
## 0.1.37
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9fd832c: Add in-text citation references
|
||||
|
||||
## 0.1.36
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 2b7a5d8: Fix: private file upload not working in Python without LlamaCloud
|
||||
|
||||
## 0.1.35
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 81ef7f0: Use LlamaCloud pipeline for data ingestion (private file uploads and generate script)
|
||||
|
||||
## 0.1.34
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- c49a5e1: Add error handling for generating the next question
|
||||
- c49a5e1: Fix wrong api key variable in Azure OpenAI provider
|
||||
|
||||
## 0.1.33
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- d746c75: Add Weaviate vector store (Typescript)
|
||||
|
||||
## 0.1.32
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 3ec5163: Add Weaviate vector database support (Python)
|
||||
|
||||
## 0.1.31
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 04a9c71: Cluster nodes by document
|
||||
|
||||
## 0.1.30
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 09e3022: Add support for LlamaTrace (Python)
|
||||
- c06ec4f: Fix imports for MongoDB
|
||||
- b6dd7a9: Always send chat data when submit message
|
||||
|
||||
## 0.1.29
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 8890e27: Let user change indexes in LlamaCloud projects
|
||||
|
||||
## 0.1.28
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9a09e8c: Fix Vercel deployment
|
||||
|
||||
## 0.1.27
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- c5c7eee: Make components reusable for chat-llamaindex
|
||||
|
||||
## 0.1.26
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- f43399c: Add metadatafilters to context chat engine (Typescript)
|
||||
|
||||
## 0.1.25
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- c67daeb: fix: missing set private to false for default generate.py
|
||||
|
||||
## 0.1.24
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 43474a5: Configure LlamaCloud organization ID for Python
|
||||
- cf11b23: Add Azure code interpreter for Python and TS
|
||||
- fd9fb42: Add Azure OpenAI as model provider
|
||||
- 5c13646: Fix starter questions not working in python backend
|
||||
|
||||
## 0.1.23
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 6bd76fb: Add template for structured extraction
|
||||
|
||||
## 0.1.22
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b0becaa: Add e2e testing for llamacloud datasource
|
||||
- df9cca5: Upgrade pdf viewer
|
||||
|
||||
## 0.1.21
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- bd4714c: Filter private documents for Typescript (Using MetadataFilters) and update to LlamaIndexTS 0.5.7
|
||||
- 58e6c15: Add using LlamaParse for private file uploader
|
||||
- 455ab68: Display files in sources using LlamaCloud indexes.
|
||||
- 23b7357: Use gpt-4o-mini as default model
|
||||
- 0900413: Add suggestions for next questions.
|
||||
|
||||
## 0.1.20
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 624c721: Update to LlamaIndex 0.10.55
|
||||
|
||||
## 0.1.19
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- df96159: Use Qdrant FastEmbed as local embedding provider
|
||||
- 32fb32a: Support upload document files: pdf, docx, txt
|
||||
|
||||
## 0.1.18
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- d1026ea: support Mistral as llm and embedding
|
||||
- a221cfc: Use LlamaParse for all the file types that it supports (if activated)
|
||||
|
||||
## 0.1.17
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9ecd061: Add new template for a multi-agents app
|
||||
|
||||
## 0.1.16
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a0aab03: Add T-System's LLMHUB as a model provider
|
||||
|
||||
## 0.1.15
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 64732f0: Fix the issue of images not showing with the sandbox URL from OpenAI's models
|
||||
- aeb6fef: use llamacloud for chat
|
||||
|
||||
## 0.1.14
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- f2c3389: chore: update to llamaindex 0.4.3
|
||||
- 5093b37: Remove non-working file selectors for Linux
|
||||
|
||||
## 0.1.13
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- b3c969d: Add image generator tool
|
||||
|
||||
## 0.1.12
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- aa69014: Fix NextJS for TS 5.2
|
||||
|
||||
## 0.1.11
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 48b96ff: Add DuckDuckGo search tool
|
||||
- 9c9decb: Reuse function tool instances and improve e2b interpreter tool for Python
|
||||
- 02ed277: Add Groq as a model provider
|
||||
- 0748f2e: Remove hard-coded Gemini supported models
|
||||
|
||||
## 0.1.10
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 9112d08: Add OpenAPI tool for Typescript
|
||||
- 8f03f8d: Add OLLAMA_REQUEST_TIMEOUT variable to config Ollama timeout (Python)
|
||||
- 8f03f8d: Apply nest_asyncio for llama parse
|
||||
|
||||
## 0.1.9
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- a42fa53: Add CSV upload
|
||||
- 563b51d: Fix Vercel streaming (python) to stream data events instantly
|
||||
- d60b3c5: Add E2B code interpreter tool for FastAPI
|
||||
- 956538e: Add OpenAPI action tool for FastAPI
|
||||
|
||||
## 0.1.8
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- cd50a33: Add interpreter tool for TS using e2b.dev
|
||||
|
||||
## 0.1.7
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 260d37a: Add system prompt env variable for TS
|
||||
- bbd5b8d: Fix postgres connection leaking issue
|
||||
- bb53425: Support HTTP proxies by setting the GLOBAL_AGENT_HTTP_PROXY env variable
|
||||
- 69c2e16: Fix streaming for Express
|
||||
- 7873bfb: Update Ollama provider to run with the base URL from the environment variable
|
||||
|
||||
## 0.1.6
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 56537a1: Display PDF files in source nodes
|
||||
|
||||
## 0.1.5
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 84db798: feat: support display latex in chat markdown
|
||||
|
||||
## 0.1.4
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 0bc8e75: Use ingestion pipeline for dedicated vector stores (Python only)
|
||||
- cb1001d: Add ChromaDB vector store
|
||||
|
||||
## 0.1.3
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 416073d: Directly import vector stores to work with NextJS
|
||||
|
||||
## 0.1.2
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 056e376: Add support for displaying tool outputs (including weather widget as example)
|
||||
|
||||
## 0.1.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- 7bd3ed5: Support Anthropic and Gemini as model providers
|
||||
- 7bd3ed5: Support new agents from LITS 0.3
|
||||
- cfb5257: Display events (e.g. retrieving nodes) per chat message
|
||||
|
||||
## 0.1.0
|
||||
|
||||
### Minor Changes
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
# Create Llama
|
||||
# Create LlamaIndex App
|
||||
|
||||
The easiest way to get started with [LlamaIndex](https://www.llamaindex.ai/) is by using `create-llama`. This CLI tool enables you to quickly start building a new LlamaIndex application, with everything set up for you.
|
||||
|
||||
## Get started
|
||||
|
||||
Just run
|
||||
|
||||
```bash
|
||||
npx create-llama@latest
|
||||
```
|
||||
|
||||
to get started, or watch this video for a demo session:
|
||||
|
||||
https://github.com/user-attachments/assets/dd3edc36-4453-4416-91c2-d24326c6c167
|
||||
|
||||
Once your app is generated, run
|
||||
to get started, or see below for more options. Once your app is generated, run
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
@@ -24,20 +18,16 @@ to start the development server. You can then visit [http://localhost:3000](http
|
||||
|
||||
## What you'll get
|
||||
|
||||
- A Next.js-powered front-end using components from [shadcn/ui](https://ui.shadcn.com/). The app is set up as a chat interface that can answer questions about your data or interact with your agent
|
||||
- A Next.js-powered front-end using components from [shadcn/ui](https://ui.shadcn.com/). The app is set up as a chat interface that can answer questions about your data (see below)
|
||||
- Your choice of 3 back-ends:
|
||||
- **Next.js**: if you select this option, you’ll have a full-stack Next.js application that you can deploy to a host like [Vercel](https://vercel.com/) in just a few clicks. This uses [LlamaIndex.TS](https://www.npmjs.com/package/llamaindex), our TypeScript library.
|
||||
- **Express**: if you want a more traditional Node.js application you can generate an Express backend. This also uses LlamaIndex.TS.
|
||||
- **Python FastAPI**: if you select this option, you’ll get a backend powered by the [llama-index Python package](https://pypi.org/project/llama-index/), which you can deploy to a service like Render or fly.io.
|
||||
- **Python FastAPI**: if you select this option, you’ll get a backend powered by the [llama-index python package](https://pypi.org/project/llama-index/), which you can deploy to a service like Render or fly.io.
|
||||
- The back-end has two endpoints (one streaming, the other one non-streaming) that allow you to send the state of your chat and receive additional responses
|
||||
- You add arbitrary data sources to your chat, like local files, websites, or data retrieved from a database.
|
||||
- Turn your chat into an AI agent by adding tools (functions called by the LLM).
|
||||
- The app uses OpenAI by default, so you'll need an OpenAI API key, or you can customize it to use any of the dozens of LLMs we support.
|
||||
|
||||
Here's how it looks like:
|
||||
|
||||
https://github.com/user-attachments/assets/d57af1a1-d99b-4e9c-98d9-4cbd1327eff8
|
||||
|
||||
## Using your data
|
||||
|
||||
You can supply your own data; the app will index it and answer questions. Your generated app will have a folder called `data` (If you're using Express or Python and generate a frontend, it will be `./backend/data`).
|
||||
@@ -64,7 +54,7 @@ Optionally generate a frontend if you've selected the Python or Express back-end
|
||||
|
||||
## Customizing the AI models
|
||||
|
||||
The app will default to OpenAI's `gpt-4o-mini` LLM and `text-embedding-3-large` embedding model.
|
||||
The app will default to OpenAI's `gpt-4-turbo` LLM and `text-embedding-3-large` embedding model.
|
||||
|
||||
If you want to use different OpenAI models, add the `--ask-models` CLI parameter.
|
||||
|
||||
@@ -94,7 +84,7 @@ Need to install the following packages:
|
||||
create-llama@latest
|
||||
Ok to proceed? (y) y
|
||||
✔ What is your project named? … my-app
|
||||
✔ Which template would you like to use? › Agentic RAG (e.g. chat with docs)
|
||||
✔ Which template would you like to use? › Chat
|
||||
✔ Which framework would you like to use? › NextJS
|
||||
✔ Would you like to set up observability? › No
|
||||
✔ Please provide your OpenAI API key (leave blank to skip): …
|
||||
@@ -102,7 +92,6 @@ Ok to proceed? (y) y
|
||||
✔ Would you like to add another data source? › No
|
||||
✔ Would you like to use LlamaParse (improved parser for RAG - requires API key)? … no / yes
|
||||
✔ Would you like to use a vector database? › No, just store the data in the file system
|
||||
✔ Would you like to build an agent using tools? If so, select the tools here, otherwise just press enter › Weather
|
||||
? How would you like to proceed? › - Use arrow-keys. Return to submit.
|
||||
Just generate code (~1 sec)
|
||||
❯ Start in VSCode (~1 sec)
|
||||
|
||||
+6
-34
@@ -9,7 +9,7 @@ import { makeDir } from "./helpers/make-dir";
|
||||
|
||||
import fs from "fs";
|
||||
import terminalLink from "terminal-link";
|
||||
import type { InstallTemplateArgs, TemplateObservability } from "./helpers";
|
||||
import type { InstallTemplateArgs } from "./helpers";
|
||||
import { installTemplate } from "./helpers";
|
||||
import { writeDevcontainer } from "./helpers/devcontainer";
|
||||
import { templatesDir } from "./helpers/dir";
|
||||
@@ -142,42 +142,14 @@ export async function createApp({
|
||||
)} and learn how to get started.`,
|
||||
);
|
||||
|
||||
outputObservability(args.observability);
|
||||
|
||||
if (
|
||||
dataSources.some((dataSource) => dataSource.type === "file") &&
|
||||
process.platform === "linux"
|
||||
) {
|
||||
if (args.observability === "opentelemetry") {
|
||||
console.log(
|
||||
yellow(
|
||||
`You can add your own data files to ${terminalLink(
|
||||
"data",
|
||||
`file://${root}/data`,
|
||||
)} folder manually.`,
|
||||
),
|
||||
`\n${yellow("Observability")}: Visit the ${terminalLink(
|
||||
"documentation",
|
||||
"https://traceloop.com/docs/openllmetry/integrations",
|
||||
)} to set up the environment variables and start seeing execution traces.`,
|
||||
);
|
||||
}
|
||||
|
||||
console.log();
|
||||
}
|
||||
|
||||
function outputObservability(observability?: TemplateObservability) {
|
||||
switch (observability) {
|
||||
case "traceloop":
|
||||
console.log(
|
||||
`\n${yellow("Observability")}: Visit the ${terminalLink(
|
||||
"documentation",
|
||||
"https://traceloop.com/docs/openllmetry/integrations",
|
||||
)} to set up the environment variables and start seeing execution traces.`,
|
||||
);
|
||||
break;
|
||||
case "llamatrace":
|
||||
console.log(
|
||||
`\n${yellow("Observability")}: LlamaTrace has been configured for your project. Visit the ${terminalLink(
|
||||
"LlamaTrace dashboard",
|
||||
"https://llamatrace.com/login",
|
||||
)} to view your traces and monitor your application.`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { ChildProcess } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import type {
|
||||
TemplateFramework,
|
||||
TemplatePostInstallAction,
|
||||
TemplateType,
|
||||
TemplateUI,
|
||||
} from "../helpers";
|
||||
import { createTestDir, runCreateLlama, type AppType } from "./utils";
|
||||
|
||||
const templateTypes: TemplateType[] = ["streaming"];
|
||||
const templateFrameworks: TemplateFramework[] = [
|
||||
"nextjs",
|
||||
"express",
|
||||
"fastapi",
|
||||
];
|
||||
const dataSources: string[] = ["--no-files", "--example-file"];
|
||||
const templateUIs: TemplateUI[] = ["shadcn", "html"];
|
||||
const templatePostInstallActions: TemplatePostInstallAction[] = [
|
||||
"none",
|
||||
"runApp",
|
||||
];
|
||||
|
||||
for (const templateType of templateTypes) {
|
||||
for (const templateFramework of templateFrameworks) {
|
||||
for (const dataSource of dataSources) {
|
||||
for (const templateUI of templateUIs) {
|
||||
for (const templatePostInstallAction of templatePostInstallActions) {
|
||||
const appType: AppType =
|
||||
templateFramework === "nextjs" ? "" : "--frontend";
|
||||
test.describe(`try create-llama ${templateType} ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
|
||||
let port: number;
|
||||
let externalPort: number;
|
||||
let cwd: string;
|
||||
let name: string;
|
||||
let appProcess: ChildProcess;
|
||||
// Only test without using vector db for now
|
||||
const vectorDb = "none";
|
||||
|
||||
test.beforeAll(async () => {
|
||||
port = Math.floor(Math.random() * 10000) + 10000;
|
||||
externalPort = port + 1;
|
||||
cwd = await createTestDir();
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
templateType,
|
||||
templateFramework,
|
||||
dataSource,
|
||||
templateUI,
|
||||
vectorDb,
|
||||
appType,
|
||||
port,
|
||||
externalPort,
|
||||
templatePostInstallAction,
|
||||
);
|
||||
name = result.projectName;
|
||||
appProcess = result.appProcess;
|
||||
});
|
||||
|
||||
test("App folder should exist", async () => {
|
||||
const dirExists = fs.existsSync(path.join(cwd, name));
|
||||
expect(dirExists).toBeTruthy();
|
||||
});
|
||||
test("Frontend should have a title", async ({ page }) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await expect(page.getByText("Built by LlamaIndex")).toBeVisible();
|
||||
});
|
||||
|
||||
test("Frontend should be able to submit a message and receive a response", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await page.fill("form input", "hello");
|
||||
const [response] = await Promise.all([
|
||||
page.waitForResponse(
|
||||
(res) => {
|
||||
return (
|
||||
res.url().includes("/api/chat") && res.status() === 200
|
||||
);
|
||||
},
|
||||
{
|
||||
timeout: 1000 * 60,
|
||||
},
|
||||
),
|
||||
page.click("form button[type=submit]"),
|
||||
]);
|
||||
const text = await response.text();
|
||||
console.log("AI response when submitting message: ", text);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
});
|
||||
|
||||
test("Backend frameworks should response when calling non-streaming chat API", async ({
|
||||
request,
|
||||
}) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
test.skip(templateFramework === "nextjs");
|
||||
const response = await request.post(
|
||||
`http://localhost:${externalPort}/api/chat/request`,
|
||||
{
|
||||
data: {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
);
|
||||
const text = await response.text();
|
||||
console.log("AI response when calling API: ", text);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
});
|
||||
|
||||
// clean processes
|
||||
test.afterAll(async () => {
|
||||
appProcess?.kill();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { ChildProcess } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { TemplateFramework } from "../helpers";
|
||||
import { createTestDir, runCreateLlama } from "./utils";
|
||||
|
||||
const templateFramework: TemplateFramework = process.env.FRAMEWORK
|
||||
? (process.env.FRAMEWORK as TemplateFramework)
|
||||
: "fastapi";
|
||||
const dataSource: string = process.env.DATASOURCE
|
||||
? process.env.DATASOURCE
|
||||
: "--example-file";
|
||||
|
||||
// The extractor template currently only works with FastAPI and files (and not on Windows)
|
||||
if (
|
||||
process.platform !== "win32" &&
|
||||
templateFramework !== "nextjs" &&
|
||||
templateFramework !== "express" &&
|
||||
dataSource !== "--no-files"
|
||||
) {
|
||||
test.describe("Test extractor template", async () => {
|
||||
let frontendPort: number;
|
||||
let backendPort: number;
|
||||
let name: string;
|
||||
let appProcess: ChildProcess;
|
||||
let cwd: string;
|
||||
|
||||
// Create extractor app
|
||||
test.beforeAll(async () => {
|
||||
cwd = await createTestDir();
|
||||
frontendPort = Math.floor(Math.random() * 10000) + 10000;
|
||||
backendPort = frontendPort + 1;
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
"extractor",
|
||||
"fastapi",
|
||||
"--example-file",
|
||||
"none",
|
||||
frontendPort,
|
||||
backendPort,
|
||||
"runApp",
|
||||
);
|
||||
name = result.projectName;
|
||||
appProcess = result.appProcess;
|
||||
});
|
||||
|
||||
test.afterAll(async () => {
|
||||
appProcess.kill();
|
||||
});
|
||||
|
||||
test("App folder should exist", async () => {
|
||||
const dirExists = fs.existsSync(path.join(cwd, name));
|
||||
expect(dirExists).toBeTruthy();
|
||||
});
|
||||
test("Frontend should have a title", async ({ page }) => {
|
||||
await page.goto(`http://localhost:${frontendPort}`);
|
||||
await expect(page.getByText("Built by LlamaIndex")).toBeVisible({
|
||||
timeout: 2000 * 60,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { ChildProcess } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import type {
|
||||
TemplateFramework,
|
||||
TemplatePostInstallAction,
|
||||
TemplateUI,
|
||||
} from "../helpers";
|
||||
import { createTestDir, runCreateLlama, type AppType } from "./utils";
|
||||
|
||||
const templateFramework: TemplateFramework = "fastapi";
|
||||
const dataSource: string = "--example-file";
|
||||
const templateUI: TemplateUI = "shadcn";
|
||||
const templatePostInstallAction: TemplatePostInstallAction = "runApp";
|
||||
const appType: AppType = "--frontend";
|
||||
const userMessage = "Write a blog post about physical standards for letters";
|
||||
|
||||
test.describe(`Test multiagent template ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
|
||||
test.skip(
|
||||
process.platform !== "linux" ||
|
||||
process.env.FRAMEWORK !== "fastapi" ||
|
||||
process.env.DATASOURCE === "--no-files",
|
||||
"The multiagent template currently only works with FastAPI and files. We also only run on Linux to speed up tests.",
|
||||
);
|
||||
let port: number;
|
||||
let externalPort: number;
|
||||
let cwd: string;
|
||||
let name: string;
|
||||
let appProcess: ChildProcess;
|
||||
// Only test without using vector db for now
|
||||
const vectorDb = "none";
|
||||
|
||||
test.beforeAll(async () => {
|
||||
port = Math.floor(Math.random() * 10000) + 10000;
|
||||
externalPort = port + 1;
|
||||
cwd = await createTestDir();
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
"multiagent",
|
||||
templateFramework,
|
||||
dataSource,
|
||||
vectorDb,
|
||||
port,
|
||||
externalPort,
|
||||
templatePostInstallAction,
|
||||
templateUI,
|
||||
appType,
|
||||
);
|
||||
name = result.projectName;
|
||||
appProcess = result.appProcess;
|
||||
});
|
||||
|
||||
test("App folder should exist", async () => {
|
||||
const dirExists = fs.existsSync(path.join(cwd, name));
|
||||
expect(dirExists).toBeTruthy();
|
||||
});
|
||||
|
||||
test("Frontend should have a title", async ({ page }) => {
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await expect(page.getByText("Built by LlamaIndex")).toBeVisible();
|
||||
});
|
||||
|
||||
test("Frontend should be able to submit a message and receive the start of a streamed response", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await page.fill("form input", userMessage);
|
||||
|
||||
const responsePromise = page.waitForResponse((res) =>
|
||||
res.url().includes("/api/chat"),
|
||||
);
|
||||
|
||||
await page.click("form button[type=submit]");
|
||||
|
||||
const response = await responsePromise;
|
||||
expect(response.ok()).toBeTruthy();
|
||||
});
|
||||
|
||||
// clean processes
|
||||
test.afterAll(async () => {
|
||||
appProcess?.kill();
|
||||
});
|
||||
});
|
||||
@@ -1,139 +0,0 @@
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { exec } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import util from "util";
|
||||
import { TemplateFramework, TemplateVectorDB } from "../helpers/types";
|
||||
import { createTestDir, runCreateLlama } from "./utils";
|
||||
|
||||
const execAsync = util.promisify(exec);
|
||||
|
||||
const templateFramework: TemplateFramework = process.env.FRAMEWORK
|
||||
? (process.env.FRAMEWORK as TemplateFramework)
|
||||
: "fastapi";
|
||||
const dataSource: string = process.env.DATASOURCE
|
||||
? process.env.DATASOURCE
|
||||
: "--example-file";
|
||||
|
||||
if (
|
||||
templateFramework == "fastapi" && // test is only relevant for fastapi
|
||||
process.version.startsWith("v20.") && // XXX: Only run for Node.js version 20 (CI matrix will trigger other versions)
|
||||
dataSource === "--example-file" // XXX: this test provides its own data source - only trigger it on one data source (usually the CI matrix will trigger multiple data sources)
|
||||
) {
|
||||
// vectorDBs, tools, and data source combinations to test
|
||||
const vectorDbs: TemplateVectorDB[] = [
|
||||
"mongo",
|
||||
"pg",
|
||||
"pinecone",
|
||||
"milvus",
|
||||
"astra",
|
||||
"qdrant",
|
||||
"chroma",
|
||||
"weaviate",
|
||||
];
|
||||
|
||||
const toolOptions = [
|
||||
"wikipedia.WikipediaToolSpec",
|
||||
"google.GoogleSearchToolSpec",
|
||||
];
|
||||
|
||||
const dataSources = [
|
||||
"--example-file",
|
||||
"--web-source https://www.example.com",
|
||||
"--db-source mysql+pymysql://user:pass@localhost:3306/mydb",
|
||||
];
|
||||
|
||||
test.describe("Test resolve python dependencies", () => {
|
||||
for (const vectorDb of vectorDbs) {
|
||||
for (const tool of toolOptions) {
|
||||
for (const dataSource of dataSources) {
|
||||
const dataSourceType = dataSource.split(" ")[0];
|
||||
const optionDescription = `vectorDb: ${vectorDb}, tools: ${tool}, dataSource: ${dataSourceType}`;
|
||||
|
||||
test(`options: ${optionDescription}`, async () => {
|
||||
const cwd = await createTestDir();
|
||||
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
"streaming",
|
||||
"fastapi",
|
||||
dataSource,
|
||||
vectorDb,
|
||||
3000, // port
|
||||
8000, // externalPort
|
||||
"none", // postInstallAction
|
||||
undefined, // ui
|
||||
"--no-frontend", // appType
|
||||
undefined, // llamaCloudProjectName
|
||||
undefined, // llamaCloudIndexName
|
||||
tool,
|
||||
);
|
||||
const name = result.projectName;
|
||||
|
||||
// Check if the app folder exists
|
||||
const dirExists = fs.existsSync(path.join(cwd, name));
|
||||
expect(dirExists).toBeTruthy();
|
||||
|
||||
// Check if pyproject.toml exists
|
||||
const pyprojectPath = path.join(cwd, name, "pyproject.toml");
|
||||
const pyprojectExists = fs.existsSync(pyprojectPath);
|
||||
expect(pyprojectExists).toBeTruthy();
|
||||
|
||||
// Run poetry lock
|
||||
try {
|
||||
const { stdout, stderr } = await execAsync(
|
||||
"poetry config virtualenvs.in-project true && poetry lock --no-update",
|
||||
{
|
||||
cwd: path.join(cwd, name),
|
||||
},
|
||||
);
|
||||
console.log("poetry lock stdout:", stdout);
|
||||
console.error("poetry lock stderr:", stderr);
|
||||
} catch (error) {
|
||||
console.error("Error running poetry lock:", error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Check if poetry.lock file was created
|
||||
const poetryLockExists = fs.existsSync(
|
||||
path.join(cwd, name, "poetry.lock"),
|
||||
);
|
||||
expect(poetryLockExists).toBeTruthy();
|
||||
|
||||
// Verify that specific dependencies are in pyproject.toml
|
||||
const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
|
||||
if (vectorDb !== "none") {
|
||||
if (vectorDb === "pg") {
|
||||
expect(pyprojectContent).toContain(
|
||||
"llama-index-vector-stores-postgres",
|
||||
);
|
||||
} else {
|
||||
expect(pyprojectContent).toContain(
|
||||
`llama-index-vector-stores-${vectorDb}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
if (tool !== "none") {
|
||||
if (tool === "wikipedia.WikipediaToolSpec") {
|
||||
expect(pyprojectContent).toContain("wikipedia");
|
||||
}
|
||||
if (tool === "google.GoogleSearchToolSpec") {
|
||||
expect(pyprojectContent).toContain("google");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for data source specific dependencies
|
||||
if (dataSource.includes("--web-source")) {
|
||||
expect(pyprojectContent).toContain("llama-index-readers-web");
|
||||
}
|
||||
if (dataSource.includes("--db-source")) {
|
||||
expect(pyprojectContent).toContain(
|
||||
"llama-index-readers-database ",
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { exec } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import util from "util";
|
||||
import { TemplateFramework, TemplateVectorDB } from "../helpers/types";
|
||||
import { createTestDir, runCreateLlama } from "./utils";
|
||||
|
||||
const execAsync = util.promisify(exec);
|
||||
|
||||
const templateFramework: TemplateFramework = process.env.FRAMEWORK
|
||||
? (process.env.FRAMEWORK as TemplateFramework)
|
||||
: "nextjs";
|
||||
const dataSource: string = process.env.DATASOURCE
|
||||
? process.env.DATASOURCE
|
||||
: "--example-file";
|
||||
|
||||
if (
|
||||
templateFramework == "nextjs" ||
|
||||
templateFramework == "express" // test is only relevant for TS projects
|
||||
) {
|
||||
// vectorDBs combinations to test
|
||||
const vectorDbs: TemplateVectorDB[] = [
|
||||
"mongo",
|
||||
"pg",
|
||||
"qdrant",
|
||||
"pinecone",
|
||||
"milvus",
|
||||
"astra",
|
||||
"chroma",
|
||||
"llamacloud",
|
||||
"weaviate",
|
||||
];
|
||||
|
||||
test.describe("Test resolve TS dependencies", () => {
|
||||
for (const vectorDb of vectorDbs) {
|
||||
const optionDescription = `vectorDb: ${vectorDb}, dataSource: ${dataSource}`;
|
||||
|
||||
test(`options: ${optionDescription}`, async () => {
|
||||
const cwd = await createTestDir();
|
||||
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
"streaming",
|
||||
templateFramework,
|
||||
dataSource,
|
||||
vectorDb,
|
||||
3000, // port
|
||||
8000, // externalPort
|
||||
"none", // postInstallAction
|
||||
undefined, // ui
|
||||
templateFramework === "nextjs" ? "" : "--no-frontend", // appType
|
||||
undefined, // llamaCloudProjectName
|
||||
undefined, // llamaCloudIndexName
|
||||
);
|
||||
const name = result.projectName;
|
||||
|
||||
// Check if the app folder exists
|
||||
const appDir = path.join(cwd, name);
|
||||
const dirExists = fs.existsSync(appDir);
|
||||
expect(dirExists).toBeTruthy();
|
||||
|
||||
// Install dependencies using pnpm
|
||||
try {
|
||||
const { stderr: installStderr } = await execAsync(
|
||||
"pnpm install --prefer-offline",
|
||||
{
|
||||
cwd: appDir,
|
||||
},
|
||||
);
|
||||
expect(installStderr).toBeFalsy();
|
||||
} catch (error) {
|
||||
console.error("Error installing dependencies:", error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Run tsc type check and capture the output
|
||||
try {
|
||||
const { stdout, stderr } = await execAsync(
|
||||
"pnpm exec tsc -b --diagnostics",
|
||||
{
|
||||
cwd: appDir,
|
||||
},
|
||||
);
|
||||
// Check if there's any error output
|
||||
expect(stderr).toBeFalsy();
|
||||
|
||||
// Log the stdout for debugging purposes
|
||||
console.log("TypeScript type-check output:", stdout);
|
||||
} catch (error) {
|
||||
console.error("Error running tsc:", error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { expect, test } from "@playwright/test";
|
||||
import { ChildProcess } from "child_process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import type {
|
||||
TemplateFramework,
|
||||
TemplatePostInstallAction,
|
||||
TemplateUI,
|
||||
} from "../helpers";
|
||||
import { createTestDir, runCreateLlama, type AppType } from "./utils";
|
||||
|
||||
const templateFramework: TemplateFramework = process.env.FRAMEWORK
|
||||
? (process.env.FRAMEWORK as TemplateFramework)
|
||||
: "fastapi";
|
||||
const dataSource: string = process.env.DATASOURCE
|
||||
? process.env.DATASOURCE
|
||||
: "--example-file";
|
||||
const templateUI: TemplateUI = "shadcn";
|
||||
const templatePostInstallAction: TemplatePostInstallAction = "runApp";
|
||||
|
||||
const llamaCloudProjectName = "create-llama";
|
||||
const llamaCloudIndexName = "e2e-test";
|
||||
|
||||
const appType: AppType = templateFramework === "nextjs" ? "" : "--frontend";
|
||||
const userMessage =
|
||||
dataSource !== "--no-files" ? "Physical standard for letters" : "Hello";
|
||||
|
||||
test.describe(`Test streaming template ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
|
||||
let port: number;
|
||||
let externalPort: number;
|
||||
let cwd: string;
|
||||
let name: string;
|
||||
let appProcess: ChildProcess;
|
||||
// Only test without using vector db for now
|
||||
const vectorDb = "none";
|
||||
|
||||
test.beforeAll(async () => {
|
||||
port = Math.floor(Math.random() * 10000) + 10000;
|
||||
externalPort = port + 1;
|
||||
cwd = await createTestDir();
|
||||
const result = await runCreateLlama(
|
||||
cwd,
|
||||
"streaming",
|
||||
templateFramework,
|
||||
dataSource,
|
||||
vectorDb,
|
||||
port,
|
||||
externalPort,
|
||||
templatePostInstallAction,
|
||||
templateUI,
|
||||
appType,
|
||||
llamaCloudProjectName,
|
||||
llamaCloudIndexName,
|
||||
);
|
||||
name = result.projectName;
|
||||
appProcess = result.appProcess;
|
||||
});
|
||||
|
||||
test("App folder should exist", async () => {
|
||||
const dirExists = fs.existsSync(path.join(cwd, name));
|
||||
expect(dirExists).toBeTruthy();
|
||||
});
|
||||
test("Frontend should have a title", async ({ page }) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await expect(page.getByText("Built by LlamaIndex")).toBeVisible();
|
||||
});
|
||||
|
||||
test("Frontend should be able to submit a message and receive a response", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
await page.goto(`http://localhost:${port}`);
|
||||
await page.fill("form input", userMessage);
|
||||
const [response] = await Promise.all([
|
||||
page.waitForResponse(
|
||||
(res) => {
|
||||
return res.url().includes("/api/chat") && res.status() === 200;
|
||||
},
|
||||
{
|
||||
timeout: 1000 * 60,
|
||||
},
|
||||
),
|
||||
page.click("form button[type=submit]"),
|
||||
]);
|
||||
const text = await response.text();
|
||||
console.log("AI response when submitting message: ", text);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
});
|
||||
|
||||
test("Backend frameworks should response when calling non-streaming chat API", async ({
|
||||
request,
|
||||
}) => {
|
||||
test.skip(templatePostInstallAction !== "runApp");
|
||||
test.skip(templateFramework === "nextjs");
|
||||
const response = await request.post(
|
||||
`http://localhost:${externalPort}/api/chat/request`,
|
||||
{
|
||||
data: {
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: userMessage,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
);
|
||||
const text = await response.text();
|
||||
console.log("AI response when calling API: ", text);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
});
|
||||
|
||||
// clean processes
|
||||
test.afterAll(async () => {
|
||||
appProcess?.kill();
|
||||
});
|
||||
});
|
||||
+73
-96
@@ -18,59 +18,86 @@ export type CreateLlamaResult = {
|
||||
appProcess: ChildProcess;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export async function checkAppHasStarted(
|
||||
frontend: boolean,
|
||||
framework: TemplateFramework,
|
||||
port: number,
|
||||
externalPort: number,
|
||||
timeout: number,
|
||||
) {
|
||||
if (frontend) {
|
||||
await Promise.all([
|
||||
waitPort({
|
||||
host: "localhost",
|
||||
port: port,
|
||||
timeout,
|
||||
}),
|
||||
waitPort({
|
||||
host: "localhost",
|
||||
port: externalPort,
|
||||
timeout,
|
||||
}),
|
||||
]).catch((err) => {
|
||||
console.error(err);
|
||||
throw err;
|
||||
});
|
||||
} else {
|
||||
let wPort: number;
|
||||
if (framework === "nextjs") {
|
||||
wPort = port;
|
||||
} else {
|
||||
wPort = externalPort;
|
||||
}
|
||||
await waitPort({
|
||||
host: "localhost",
|
||||
port: wPort,
|
||||
timeout,
|
||||
}).catch((err) => {
|
||||
console.error(err);
|
||||
throw err;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export async function runCreateLlama(
|
||||
cwd: string,
|
||||
templateType: TemplateType,
|
||||
templateFramework: TemplateFramework,
|
||||
dataSource: string,
|
||||
templateUI: TemplateUI,
|
||||
vectorDb: TemplateVectorDB,
|
||||
appType: AppType,
|
||||
port: number,
|
||||
externalPort: number,
|
||||
postInstallAction: TemplatePostInstallAction,
|
||||
templateUI?: TemplateUI,
|
||||
appType?: AppType,
|
||||
llamaCloudProjectName?: string,
|
||||
llamaCloudIndexName?: string,
|
||||
tools?: string,
|
||||
): Promise<CreateLlamaResult> {
|
||||
if (!process.env.OPENAI_API_KEY || !process.env.LLAMA_CLOUD_API_KEY) {
|
||||
throw new Error(
|
||||
"Setting the OPENAI_API_KEY and LLAMA_CLOUD_API_KEY is mandatory to run tests",
|
||||
);
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error("Setting OPENAI_API_KEY is mandatory to run tests");
|
||||
}
|
||||
const name = [
|
||||
templateType,
|
||||
templateFramework,
|
||||
dataSource.split(" ")[0],
|
||||
dataSource,
|
||||
templateUI,
|
||||
appType,
|
||||
].join("-");
|
||||
|
||||
// Handle different data source types
|
||||
let dataSourceArgs = [];
|
||||
if (dataSource.includes("--web-source" || "--db-source")) {
|
||||
const webSource = dataSource.split(" ")[1];
|
||||
dataSourceArgs.push("--web-source", webSource);
|
||||
} else if (dataSource.includes("--db-source")) {
|
||||
const dbSource = dataSource.split(" ")[1];
|
||||
dataSourceArgs.push("--db-source", dbSource);
|
||||
} else {
|
||||
dataSourceArgs.push(dataSource);
|
||||
}
|
||||
|
||||
const commandArgs = [
|
||||
const command = [
|
||||
"create-llama",
|
||||
name,
|
||||
"--template",
|
||||
templateType,
|
||||
"--framework",
|
||||
templateFramework,
|
||||
...dataSourceArgs,
|
||||
dataSource,
|
||||
"--ui",
|
||||
templateUI,
|
||||
"--vector-db",
|
||||
vectorDb,
|
||||
"--open-ai-key",
|
||||
process.env.OPENAI_API_KEY,
|
||||
appType,
|
||||
"--use-pnpm",
|
||||
"--port",
|
||||
port,
|
||||
@@ -79,37 +106,24 @@ export async function runCreateLlama(
|
||||
"--post-install-action",
|
||||
postInstallAction,
|
||||
"--tools",
|
||||
tools ?? "none",
|
||||
"none",
|
||||
"--no-llama-parse",
|
||||
"--observability",
|
||||
"none",
|
||||
"--llama-cloud-key",
|
||||
process.env.LLAMA_CLOUD_API_KEY,
|
||||
];
|
||||
|
||||
if (templateUI) {
|
||||
commandArgs.push("--ui", templateUI);
|
||||
}
|
||||
if (appType) {
|
||||
commandArgs.push(appType);
|
||||
}
|
||||
|
||||
const command = commandArgs.join(" ");
|
||||
].join(" ");
|
||||
console.log(`running command '${command}' in ${cwd}`);
|
||||
const appProcess = exec(command, {
|
||||
cwd,
|
||||
env: {
|
||||
...process.env,
|
||||
LLAMA_CLOUD_PROJECT_NAME: llamaCloudProjectName,
|
||||
LLAMA_CLOUD_INDEX_NAME: llamaCloudIndexName,
|
||||
},
|
||||
});
|
||||
appProcess.stderr?.on("data", (data) => {
|
||||
console.error(data.toString());
|
||||
console.log(data.toString());
|
||||
});
|
||||
appProcess.on("exit", (code) => {
|
||||
if (code !== 0 && code !== null) {
|
||||
throw new Error(`create-llama command failed with exit code ${code}`);
|
||||
throw new Error(`create-llama command was failed!`);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -120,12 +134,25 @@ export async function runCreateLlama(
|
||||
templateFramework,
|
||||
port,
|
||||
externalPort,
|
||||
1000 * 60 * 5,
|
||||
);
|
||||
} else if (postInstallAction === "dependencies") {
|
||||
await waitForProcess(appProcess, 1000 * 60); // wait 1 min for dependencies to be resolved
|
||||
} else {
|
||||
// wait 10 seconds for create-llama to exit
|
||||
await waitForProcess(appProcess, 1000 * 10);
|
||||
// wait create-llama to exit
|
||||
// we don't test install dependencies for now, so just set timeout for 10 seconds
|
||||
await new Promise((resolve, reject) => {
|
||||
const timeout = setTimeout(() => {
|
||||
reject(new Error("create-llama timeout error"));
|
||||
}, 1000 * 10);
|
||||
appProcess.on("exit", (code) => {
|
||||
if (code !== 0 && code !== null) {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("create-llama command was failed!"));
|
||||
} else {
|
||||
clearTimeout(timeout);
|
||||
resolve(undefined);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -139,53 +166,3 @@ export async function createTestDir() {
|
||||
await mkdir(cwd, { recursive: true });
|
||||
return cwd;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
async function checkAppHasStarted(
|
||||
frontend: boolean,
|
||||
framework: TemplateFramework,
|
||||
port: number,
|
||||
externalPort: number,
|
||||
) {
|
||||
const portsToWait = frontend
|
||||
? [port, externalPort]
|
||||
: [framework === "nextjs" ? port : externalPort];
|
||||
await waitPorts(portsToWait);
|
||||
}
|
||||
|
||||
async function waitPorts(ports: number[]): Promise<void> {
|
||||
const waitForPort = async (port: number): Promise<void> => {
|
||||
await waitPort({
|
||||
host: "localhost",
|
||||
port: port,
|
||||
// wait max. 5 mins for start up of app
|
||||
timeout: 1000 * 60 * 5,
|
||||
});
|
||||
};
|
||||
try {
|
||||
await Promise.all(ports.map(waitForPort));
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForProcess(
|
||||
process: ChildProcess,
|
||||
timeoutMs: number,
|
||||
): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const timeout = setTimeout(() => {
|
||||
reject(new Error("Process timeout error"));
|
||||
}, timeoutMs);
|
||||
|
||||
process.on("exit", (code) => {
|
||||
clearTimeout(timeout);
|
||||
if (code !== 0 && code !== null) {
|
||||
reject(new Error("Process exited with non-zero code"));
|
||||
} else {
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
+60
-52
@@ -36,66 +36,74 @@ export async function writeLoadersConfig(
|
||||
dataSources: TemplateDataSource[],
|
||||
useLlamaParse?: boolean,
|
||||
) {
|
||||
const loaderConfig: Record<string, any> = {};
|
||||
|
||||
// Always set file loader config
|
||||
loaderConfig.file = createFileLoaderConfig(useLlamaParse);
|
||||
|
||||
if (dataSources.length === 0) return; // no datasources, no config needed
|
||||
const loaderConfig = new Document({});
|
||||
// Web loader config
|
||||
if (dataSources.some((ds) => ds.type === "web")) {
|
||||
loaderConfig.web = createWebLoaderConfig(dataSources);
|
||||
const webLoaderConfig = new Document({});
|
||||
|
||||
// Create config for browser driver arguments
|
||||
const driverArgNodeValue = webLoaderConfig.createNode([
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
]);
|
||||
driverArgNodeValue.commentBefore =
|
||||
" The arguments to pass to the webdriver. E.g.: add --headless to run in headless mode";
|
||||
webLoaderConfig.set("driver_arguments", driverArgNodeValue);
|
||||
|
||||
// Create config for urls
|
||||
const urlConfigs = dataSources
|
||||
.filter((ds) => ds.type === "web")
|
||||
.map((ds) => {
|
||||
const dsConfig = ds.config as WebSourceConfig;
|
||||
return {
|
||||
base_url: dsConfig.baseUrl,
|
||||
prefix: dsConfig.prefix,
|
||||
depth: dsConfig.depth,
|
||||
};
|
||||
});
|
||||
const urlConfigNode = webLoaderConfig.createNode(urlConfigs);
|
||||
urlConfigNode.commentBefore = ` base_url: The URL to start crawling with
|
||||
prefix: Only crawl URLs matching the specified prefix
|
||||
depth: The maximum depth for BFS traversal
|
||||
You can add more websites by adding more entries (don't forget the - prefix from YAML)`;
|
||||
webLoaderConfig.set("urls", urlConfigNode);
|
||||
|
||||
// Add web config to the loaders config
|
||||
loaderConfig.set("web", webLoaderConfig);
|
||||
}
|
||||
|
||||
// File loader config
|
||||
if (dataSources.some((ds) => ds.type === "file")) {
|
||||
// Add documentation to web loader config
|
||||
const node = loaderConfig.createNode({
|
||||
use_llama_parse: useLlamaParse,
|
||||
});
|
||||
node.commentBefore = ` use_llama_parse: Use LlamaParse if \`true\`. Needs a \`LLAMA_CLOUD_API_KEY\` from https://cloud.llamaindex.ai set as environment variable`;
|
||||
loaderConfig.set("file", node);
|
||||
}
|
||||
|
||||
// DB loader config
|
||||
const dbLoaders = dataSources.filter((ds) => ds.type === "db");
|
||||
if (dbLoaders.length > 0) {
|
||||
loaderConfig.db = createDbLoaderConfig(dbLoaders);
|
||||
}
|
||||
const dbLoaderConfig = new Document({});
|
||||
const configEntries = dbLoaders.map((ds) => {
|
||||
const dsConfig = ds.config as DbSourceConfig;
|
||||
return {
|
||||
uri: dsConfig.uri,
|
||||
queries: [dsConfig.queries],
|
||||
};
|
||||
});
|
||||
|
||||
// Create a new Document with the loaderConfig
|
||||
const yamlDoc = new Document(loaderConfig);
|
||||
const node = dbLoaderConfig.createNode(configEntries);
|
||||
node.commentBefore = ` The configuration for the database loader, only supports MySQL and PostgreSQL databases for now.
|
||||
uri: The URI for the database. E.g.: mysql+pymysql://user:password@localhost:3306/db or postgresql+psycopg2://user:password@localhost:5432/db
|
||||
query: The query to fetch data from the database. E.g.: SELECT * FROM table`;
|
||||
loaderConfig.set("db", node);
|
||||
}
|
||||
|
||||
// Write loaders config
|
||||
const loaderConfigPath = path.join(root, "config", "loaders.yaml");
|
||||
await fs.mkdir(path.join(root, "config"), { recursive: true });
|
||||
await fs.writeFile(loaderConfigPath, yaml.stringify(yamlDoc));
|
||||
}
|
||||
|
||||
function createWebLoaderConfig(dataSources: TemplateDataSource[]): any {
|
||||
const webLoaderConfig: Record<string, any> = {};
|
||||
|
||||
// Create config for browser driver arguments
|
||||
webLoaderConfig.driver_arguments = [
|
||||
"--no-sandbox",
|
||||
"--disable-dev-shm-usage",
|
||||
];
|
||||
|
||||
// Create config for urls
|
||||
const urlConfigs = dataSources
|
||||
.filter((ds) => ds.type === "web")
|
||||
.map((ds) => {
|
||||
const dsConfig = ds.config as WebSourceConfig;
|
||||
return {
|
||||
base_url: dsConfig.baseUrl,
|
||||
prefix: dsConfig.prefix,
|
||||
depth: dsConfig.depth,
|
||||
};
|
||||
});
|
||||
webLoaderConfig.urls = urlConfigs;
|
||||
|
||||
return webLoaderConfig;
|
||||
}
|
||||
|
||||
function createFileLoaderConfig(useLlamaParse?: boolean): any {
|
||||
return {
|
||||
use_llama_parse: useLlamaParse,
|
||||
};
|
||||
}
|
||||
|
||||
function createDbLoaderConfig(dbLoaders: TemplateDataSource[]): any {
|
||||
return dbLoaders.map((ds) => {
|
||||
const dsConfig = ds.config as DbSourceConfig;
|
||||
return {
|
||||
uri: dsConfig.uri,
|
||||
queries: [dsConfig.queries],
|
||||
};
|
||||
});
|
||||
await fs.writeFile(loaderConfigPath, yaml.stringify(loaderConfig));
|
||||
}
|
||||
|
||||
+45
-364
@@ -1,19 +1,13 @@
|
||||
import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { TOOL_SYSTEM_PROMPT_ENV_VAR, Tool } from "./tools";
|
||||
import {
|
||||
InstallTemplateArgs,
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateFramework,
|
||||
TemplateObservability,
|
||||
TemplateType,
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
|
||||
import { TSYSTEMS_LLMHUB_API_URL } from "./providers/llmhub";
|
||||
|
||||
export type EnvVar = {
|
||||
type EnvVar = {
|
||||
name?: string;
|
||||
description?: string;
|
||||
value?: string;
|
||||
@@ -35,20 +29,17 @@ const renderEnvVar = (envVars: EnvVar[]): string => {
|
||||
);
|
||||
};
|
||||
|
||||
const getVectorDBEnvs = (
|
||||
vectorDb?: TemplateVectorDB,
|
||||
framework?: TemplateFramework,
|
||||
): EnvVar[] => {
|
||||
if (!vectorDb || !framework) {
|
||||
const getVectorDBEnvs = (vectorDb?: TemplateVectorDB): EnvVar[] => {
|
||||
if (!vectorDb) {
|
||||
return [];
|
||||
}
|
||||
switch (vectorDb) {
|
||||
case "mongo":
|
||||
return [
|
||||
{
|
||||
name: "MONGODB_URI",
|
||||
name: "MONGO_URI",
|
||||
description:
|
||||
"For generating a connection URI, see https://www.mongodb.com/docs/manual/reference/connection-string/ \nThe MongoDB connection URI.",
|
||||
"For generating a connection URI, see https://docs.timescale.com/use-timescale/latest/services/create-a-service\nThe MongoDB connection URI.",
|
||||
},
|
||||
{
|
||||
name: "MONGODB_DATABASE",
|
||||
@@ -138,84 +129,6 @@ const getVectorDBEnvs = (
|
||||
"Optional API key for authenticating requests to Qdrant.",
|
||||
},
|
||||
];
|
||||
case "llamacloud":
|
||||
return [
|
||||
{
|
||||
name: "LLAMA_CLOUD_INDEX_NAME",
|
||||
description:
|
||||
"The name of the LlamaCloud index to use (part of the LlamaCloud project).",
|
||||
value: "test",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_PROJECT_NAME",
|
||||
description: "The name of the LlamaCloud project.",
|
||||
value: "Default",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_BASE_URL",
|
||||
description:
|
||||
"The base URL for the LlamaCloud API. Only change this for non-production environments",
|
||||
value: "https://api.cloud.llamaindex.ai",
|
||||
},
|
||||
{
|
||||
name: "LLAMA_CLOUD_ORGANIZATION_ID",
|
||||
description:
|
||||
"The organization ID for the LlamaCloud project (uses default organization if not specified)",
|
||||
},
|
||||
...(framework === "nextjs"
|
||||
? // activate index selector per default (not needed for non-NextJS backends as it's handled by createFrontendEnvFile)
|
||||
[
|
||||
{
|
||||
name: "NEXT_PUBLIC_USE_LLAMACLOUD",
|
||||
description:
|
||||
"Let's the user change indexes in LlamaCloud projects",
|
||||
value: "true",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
case "chroma":
|
||||
const envs = [
|
||||
{
|
||||
name: "CHROMA_COLLECTION",
|
||||
description: "The name of the collection in your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_HOST",
|
||||
description: "The API endpoint for your Chroma database",
|
||||
},
|
||||
{
|
||||
name: "CHROMA_PORT",
|
||||
description: "The port for your Chroma database",
|
||||
},
|
||||
];
|
||||
// TS Version doesn't support config local storage path
|
||||
if (framework === "fastapi") {
|
||||
envs.push({
|
||||
name: "CHROMA_PATH",
|
||||
description: `The local path to the Chroma database.
|
||||
Specify this if you are using a local Chroma database.
|
||||
Otherwise, use CHROMA_HOST and CHROMA_PORT config above`,
|
||||
});
|
||||
}
|
||||
return envs;
|
||||
case "weaviate":
|
||||
return [
|
||||
{
|
||||
name: "WEAVIATE_CLUSTER_URL",
|
||||
description:
|
||||
"The URL of the Weaviate cloud cluster, see: https://weaviate.io/developers/wcs/connect",
|
||||
},
|
||||
{
|
||||
name: "WEAVIATE_API_KEY",
|
||||
description: "The API key for the Weaviate cloud cluster",
|
||||
},
|
||||
{
|
||||
name: "WEAVIATE_INDEX_NAME",
|
||||
description:
|
||||
"(Optional) The collection name to use, default is LlamaIndex if not specified",
|
||||
},
|
||||
];
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
@@ -243,10 +156,6 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
description: "Dimension of the embedding model to use.",
|
||||
value: modelConfig.dimensions.toString(),
|
||||
},
|
||||
{
|
||||
name: "CONVERSATION_STARTERS",
|
||||
description: "The questions to help users get started (multi-line).",
|
||||
},
|
||||
...(modelConfig.provider === "openai"
|
||||
? [
|
||||
{
|
||||
@@ -264,130 +173,41 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "anthropic"
|
||||
? [
|
||||
{
|
||||
name: "ANTHROPIC_API_KEY",
|
||||
description: "The Anthropic API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "groq"
|
||||
? [
|
||||
{
|
||||
name: "GROQ_API_KEY",
|
||||
description: "The Groq API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "gemini"
|
||||
? [
|
||||
{
|
||||
name: "GOOGLE_API_KEY",
|
||||
description: "The Google API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "ollama"
|
||||
? [
|
||||
{
|
||||
name: "OLLAMA_BASE_URL",
|
||||
description:
|
||||
"The base URL for the Ollama API. Eg: http://127.0.0.1:11434",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "mistral"
|
||||
? [
|
||||
{
|
||||
name: "MISTRAL_API_KEY",
|
||||
description: "The Mistral API key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "azure-openai"
|
||||
? [
|
||||
{
|
||||
name: "AZURE_OPENAI_API_KEY",
|
||||
description: "The Azure OpenAI key to use.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_ENDPOINT",
|
||||
description: "The Azure OpenAI endpoint to use.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_API_VERSION",
|
||||
description: "The Azure OpenAI API version to use.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_LLM_DEPLOYMENT",
|
||||
description:
|
||||
"The Azure OpenAI deployment to use for LLM deployment.",
|
||||
},
|
||||
{
|
||||
name: "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
|
||||
description:
|
||||
"The Azure OpenAI deployment to use for embedding deployment.",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
...(modelConfig.provider === "t-systems"
|
||||
? [
|
||||
{
|
||||
name: "T_SYSTEMS_LLMHUB_BASE_URL",
|
||||
description:
|
||||
"The base URL for the T-Systems AI Foundation Model API. Eg: http://localhost:11434",
|
||||
value: TSYSTEMS_LLMHUB_API_URL,
|
||||
},
|
||||
{
|
||||
name: "T_SYSTEMS_LLMHUB_API_KEY",
|
||||
description: "API Key for T-System's AI Foundation Model.",
|
||||
value: modelConfig.apiKey,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
];
|
||||
};
|
||||
|
||||
const getFrameworkEnvs = (
|
||||
framework: TemplateFramework,
|
||||
framework?: TemplateFramework,
|
||||
port?: number,
|
||||
): EnvVar[] => {
|
||||
const sPort = port?.toString() || "8000";
|
||||
const result: EnvVar[] = [
|
||||
if (framework !== "fastapi") {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
{
|
||||
name: "FILESERVER_URL_PREFIX",
|
||||
description:
|
||||
"FILESERVER_URL_PREFIX is the URL prefix of the server storing the images generated by the interpreter.",
|
||||
value:
|
||||
framework === "nextjs"
|
||||
? // FIXME: if we are using nextjs, port should be 3000
|
||||
"http://localhost:3000/api/files"
|
||||
: `http://localhost:${sPort}/api/files`,
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: port?.toString() || "8000",
|
||||
},
|
||||
// TODO: Once LlamaIndexTS supports string templates, move this to `getEngineEnvs`
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: `Custom system prompt.
|
||||
Example:
|
||||
SYSTEM_PROMPT="
|
||||
We have provided context information below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given this information, please answer the question: {query_str}
|
||||
"`,
|
||||
},
|
||||
];
|
||||
if (framework === "fastapi") {
|
||||
result.push(
|
||||
...[
|
||||
{
|
||||
name: "APP_HOST",
|
||||
description: "The address to start the backend app.",
|
||||
value: "0.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "APP_PORT",
|
||||
description: "The port to start the backend app.",
|
||||
value: sPort,
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
const getEngineEnvs = (): EnvVar[] => {
|
||||
@@ -396,152 +216,21 @@ const getEngineEnvs = (): EnvVar[] => {
|
||||
name: "TOP_K",
|
||||
description:
|
||||
"The number of similar embeddings to return when retrieving documents.",
|
||||
},
|
||||
{
|
||||
name: "STREAM_TIMEOUT",
|
||||
description:
|
||||
"The time in milliseconds to wait for the stream to return a response.",
|
||||
value: "60000",
|
||||
value: "3",
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const getToolEnvs = (tools?: Tool[]): EnvVar[] => {
|
||||
if (!tools?.length) return [];
|
||||
const toolEnvs: EnvVar[] = [];
|
||||
tools.forEach((tool) => {
|
||||
if (tool.envVars?.length) {
|
||||
toolEnvs.push(
|
||||
// Don't include the system prompt env var here
|
||||
// It should be handled separately by merging with the default system prompt
|
||||
...tool.envVars.filter(
|
||||
(env) => env.name !== TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
return toolEnvs;
|
||||
};
|
||||
|
||||
const getSystemPromptEnv = (
|
||||
tools?: Tool[],
|
||||
dataSources?: TemplateDataSource[],
|
||||
framework?: TemplateFramework,
|
||||
): EnvVar[] => {
|
||||
const defaultSystemPrompt =
|
||||
"You are a helpful assistant who helps users with their questions.";
|
||||
|
||||
// build tool system prompt by merging all tool system prompts
|
||||
let toolSystemPrompt = "";
|
||||
tools?.forEach((tool) => {
|
||||
const toolSystemPromptEnv = tool.envVars?.find(
|
||||
(env) => env.name === TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
);
|
||||
if (toolSystemPromptEnv) {
|
||||
toolSystemPrompt += toolSystemPromptEnv.value + "\n";
|
||||
}
|
||||
});
|
||||
|
||||
const systemPrompt = toolSystemPrompt
|
||||
? `\"${toolSystemPrompt}\"`
|
||||
: defaultSystemPrompt;
|
||||
|
||||
const systemPromptEnv = [
|
||||
{
|
||||
name: "SYSTEM_PROMPT",
|
||||
description: "The system prompt for the AI model.",
|
||||
value: systemPrompt,
|
||||
},
|
||||
];
|
||||
|
||||
if (tools?.length == 0 && (dataSources?.length ?? 0 > 0)) {
|
||||
const citationPrompt = `'You have provided information from a knowledge base that has been passed to you in nodes of information.
|
||||
Each node has useful metadata such as node ID, file name, page, etc.
|
||||
Please add the citation to the data node for each sentence or paragraph that you reference in the provided information.
|
||||
The citation format is: . [citation:<node_id>]()
|
||||
Where the <node_id> is the unique identifier of the data node.
|
||||
|
||||
Example:
|
||||
We have two nodes:
|
||||
node_id: xyz
|
||||
file_name: llama.pdf
|
||||
|
||||
node_id: abc
|
||||
file_name: animal.pdf
|
||||
|
||||
User question: Tell me a fun fact about Llama.
|
||||
Your answer:
|
||||
A baby llama is called "Cria" [citation:xyz]().
|
||||
It often live in desert [citation:abc]().
|
||||
It\\'s cute animal.
|
||||
'`;
|
||||
systemPromptEnv.push({
|
||||
name: "SYSTEM_CITATION_PROMPT",
|
||||
description:
|
||||
"An additional system prompt to add citation when responding to user questions.",
|
||||
value: citationPrompt,
|
||||
});
|
||||
}
|
||||
|
||||
return systemPromptEnv;
|
||||
};
|
||||
|
||||
const getTemplateEnvs = (template?: TemplateType): EnvVar[] => {
|
||||
const nextQuestionEnvs: EnvVar[] = [
|
||||
{
|
||||
name: "NEXT_QUESTION_PROMPT",
|
||||
description: `Customize prompt to generate the next question suggestions based on the conversation history.
|
||||
Disable this prompt to disable the next question suggestions feature.`,
|
||||
value: `"You're a helpful assistant! Your task is to suggest the next question that user might ask.
|
||||
Here is the conversation history
|
||||
---------------------
|
||||
{conversation}
|
||||
---------------------
|
||||
Given the conversation history, please give me 3 questions that you might ask next!
|
||||
Your answer should be wrapped in three sticks which follows the following format:
|
||||
\`\`\`
|
||||
<question 1>
|
||||
<question 2>
|
||||
<question 3>
|
||||
\`\`\`"`,
|
||||
},
|
||||
];
|
||||
|
||||
if (template === "multiagent" || template === "streaming") {
|
||||
return nextQuestionEnvs;
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const getObservabilityEnvs = (
|
||||
observability?: TemplateObservability,
|
||||
): EnvVar[] => {
|
||||
if (observability === "llamatrace") {
|
||||
return [
|
||||
{
|
||||
name: "PHOENIX_API_KEY",
|
||||
description:
|
||||
"API key for LlamaTrace observability. Retrieve from https://llamatrace.com/login",
|
||||
},
|
||||
];
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
export const createBackendEnvFile = async (
|
||||
root: string,
|
||||
opts: Pick<
|
||||
InstallTemplateArgs,
|
||||
| "llamaCloudKey"
|
||||
| "vectorDb"
|
||||
| "modelConfig"
|
||||
| "framework"
|
||||
| "dataSources"
|
||||
| "template"
|
||||
| "externalPort"
|
||||
| "tools"
|
||||
| "observability"
|
||||
>,
|
||||
opts: {
|
||||
llamaCloudKey?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
modelConfig: ModelConfig;
|
||||
framework?: TemplateFramework;
|
||||
dataSources?: TemplateDataSource[];
|
||||
port?: number;
|
||||
},
|
||||
) => {
|
||||
// Init env values
|
||||
const envFileName = ".env";
|
||||
@@ -551,15 +240,13 @@ export const createBackendEnvFile = async (
|
||||
description: `The Llama Cloud API key.`,
|
||||
value: opts.llamaCloudKey,
|
||||
},
|
||||
// Add environment variables of each component
|
||||
// Add model environment variables
|
||||
...getModelEnvs(opts.modelConfig),
|
||||
// Add engine environment variables
|
||||
...getEngineEnvs(),
|
||||
...getVectorDBEnvs(opts.vectorDb, opts.framework),
|
||||
...getFrameworkEnvs(opts.framework, opts.externalPort),
|
||||
...getToolEnvs(opts.tools),
|
||||
...getTemplateEnvs(opts.template),
|
||||
...getObservabilityEnvs(opts.observability),
|
||||
...getSystemPromptEnv(opts.tools, opts.dataSources, opts.framework),
|
||||
// Add vector database environment variables
|
||||
...getVectorDBEnvs(opts.vectorDb),
|
||||
...getFrameworkEnvs(opts.framework, opts.port),
|
||||
];
|
||||
// Render and write env file
|
||||
const content = renderEnvVar(envVars);
|
||||
@@ -571,7 +258,6 @@ export const createFrontendEnvFile = async (
|
||||
root: string,
|
||||
opts: {
|
||||
customApiPath?: string;
|
||||
vectorDb?: TemplateVectorDB;
|
||||
},
|
||||
) => {
|
||||
const defaultFrontendEnvs = [
|
||||
@@ -582,11 +268,6 @@ export const createFrontendEnvFile = async (
|
||||
? opts.customApiPath
|
||||
: "http://localhost:8000/api/chat",
|
||||
},
|
||||
{
|
||||
name: "NEXT_PUBLIC_USE_LLAMACLOUD",
|
||||
description: "Let's the user change indexes in LlamaCloud projects",
|
||||
value: opts.vectorDb === "llamacloud" ? "true" : "false",
|
||||
},
|
||||
];
|
||||
const content = renderEnvVar(defaultFrontendEnvs);
|
||||
await fs.writeFile(path.join(root, ".env"), content);
|
||||
|
||||
+48
-78
@@ -8,8 +8,8 @@ import { writeLoadersConfig } from "./datasources";
|
||||
import { createBackendEnvFile, createFrontendEnvFile } from "./env-variables";
|
||||
import { PackageManager } from "./get-pkg-manager";
|
||||
import { installLlamapackProject } from "./llama-pack";
|
||||
import { makeDir } from "./make-dir";
|
||||
import { isHavingPoetryLockFile, tryPoetryRun } from "./poetry";
|
||||
import { isModelConfigured } from "./providers";
|
||||
import { installPythonTemplate } from "./python";
|
||||
import { downloadAndExtractRepo } from "./repo";
|
||||
import { ConfigFileType, writeToolsConfig } from "./tools";
|
||||
@@ -23,31 +23,6 @@ import {
|
||||
} from "./types";
|
||||
import { installTSTemplate } from "./typescript";
|
||||
|
||||
const checkForGenerateScript = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
llamaCloudKey?: string,
|
||||
useLlamaParse?: boolean,
|
||||
) => {
|
||||
const missingSettings = [];
|
||||
|
||||
if (!modelConfig.isConfigured()) {
|
||||
missingSettings.push("your model provider API key");
|
||||
}
|
||||
|
||||
const llamaCloudApiKey = llamaCloudKey ?? process.env["LLAMA_CLOUD_API_KEY"];
|
||||
const isRequiredLlamaCloudKey = useLlamaParse || vectorDb === "llamacloud";
|
||||
if (isRequiredLlamaCloudKey && !llamaCloudApiKey) {
|
||||
missingSettings.push("your LLAMA_CLOUD_API_KEY");
|
||||
}
|
||||
|
||||
if (vectorDb !== "none" && vectorDb !== "llamacloud") {
|
||||
missingSettings.push("your Vector DB environment variables");
|
||||
}
|
||||
|
||||
return missingSettings;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
async function generateContextData(
|
||||
framework: TemplateFramework,
|
||||
@@ -63,15 +38,12 @@ async function generateContextData(
|
||||
? "poetry run generate"
|
||||
: `${packageManager} run generate`,
|
||||
)}`;
|
||||
|
||||
const missingSettings = checkForGenerateScript(
|
||||
modelConfig,
|
||||
vectorDb,
|
||||
llamaCloudKey,
|
||||
useLlamaParse,
|
||||
);
|
||||
|
||||
if (!missingSettings.length) {
|
||||
const modelConfigured = isModelConfigured(modelConfig);
|
||||
const llamaCloudKeyConfigured = useLlamaParse
|
||||
? llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
: true;
|
||||
const hasVectorDb = vectorDb && vectorDb !== "none";
|
||||
if (modelConfigured && llamaCloudKeyConfigured && !hasVectorDb) {
|
||||
// If all the required environment variables are set, run the generate script
|
||||
if (framework === "fastapi") {
|
||||
if (isHavingPoetryLockFile()) {
|
||||
@@ -91,16 +63,22 @@ async function generateContextData(
|
||||
}
|
||||
}
|
||||
|
||||
const settingsMessage = `After setting ${missingSettings.join(" and ")}, run ${runGenerate} to generate the context data.`;
|
||||
console.log(`\n${settingsMessage}\n\n`);
|
||||
// generate the message of what to do to run the generate script manually
|
||||
const settings = [];
|
||||
if (!modelConfigured) settings.push("your model provider API key");
|
||||
if (!llamaCloudKeyConfigured) settings.push("your Llama Cloud key");
|
||||
if (hasVectorDb) settings.push("your Vector DB environment variables");
|
||||
const settingsMessage =
|
||||
settings.length > 0 ? `After setting ${settings.join(" and ")}, ` : "";
|
||||
const generateMessage = `run ${runGenerate} to generate the context data.`;
|
||||
console.log(`\n${settingsMessage}${generateMessage}\n\n`);
|
||||
}
|
||||
}
|
||||
|
||||
const prepareContextData = async (
|
||||
const copyContextData = async (
|
||||
root: string,
|
||||
dataSources: TemplateDataSource[],
|
||||
) => {
|
||||
await makeDir(path.join(root, "data"));
|
||||
for (const dataSource of dataSources) {
|
||||
const dataSourceConfig = dataSource?.config as FileSourceConfig;
|
||||
// Copy local data
|
||||
@@ -143,15 +121,12 @@ export const installTemplate = async (
|
||||
|
||||
if (props.framework === "fastapi") {
|
||||
await installPythonTemplate(props);
|
||||
if (props.vectorDb !== "llamacloud") {
|
||||
// write loaders configuration (currently Python only)
|
||||
// not needed for LlamaCloud as it has its own loaders
|
||||
await writeLoadersConfig(
|
||||
props.root,
|
||||
props.dataSources,
|
||||
props.useLlamaParse,
|
||||
);
|
||||
}
|
||||
// write loaders configuration (currently Python only)
|
||||
await writeLoadersConfig(
|
||||
props.root,
|
||||
props.dataSources,
|
||||
props.useLlamaParse,
|
||||
);
|
||||
} else {
|
||||
await installTSTemplate(props);
|
||||
}
|
||||
@@ -167,44 +142,39 @@ export const installTemplate = async (
|
||||
// This is a backend, so we need to copy the test data and create the env file.
|
||||
|
||||
// Copy the environment file to the target directory.
|
||||
if (
|
||||
props.template === "streaming" ||
|
||||
props.template === "multiagent" ||
|
||||
props.template === "extractor"
|
||||
) {
|
||||
await createBackendEnvFile(props.root, props);
|
||||
}
|
||||
await createBackendEnvFile(props.root, {
|
||||
modelConfig: props.modelConfig,
|
||||
llamaCloudKey: props.llamaCloudKey,
|
||||
vectorDb: props.vectorDb,
|
||||
framework: props.framework,
|
||||
dataSources: props.dataSources,
|
||||
port: props.externalPort,
|
||||
});
|
||||
|
||||
await prepareContextData(
|
||||
props.root,
|
||||
props.dataSources.filter((ds) => ds.type === "file"),
|
||||
);
|
||||
|
||||
if (
|
||||
props.dataSources.length > 0 &&
|
||||
(props.postInstallAction === "runApp" ||
|
||||
props.postInstallAction === "dependencies")
|
||||
) {
|
||||
if (props.dataSources.length > 0) {
|
||||
console.log("\nGenerating context data...\n");
|
||||
await generateContextData(
|
||||
props.framework,
|
||||
props.modelConfig,
|
||||
props.packageManager,
|
||||
props.vectorDb,
|
||||
props.llamaCloudKey,
|
||||
props.useLlamaParse,
|
||||
await copyContextData(
|
||||
props.root,
|
||||
props.dataSources.filter((ds) => ds.type === "file"),
|
||||
);
|
||||
if (
|
||||
props.postInstallAction === "runApp" ||
|
||||
props.postInstallAction === "dependencies"
|
||||
) {
|
||||
await generateContextData(
|
||||
props.framework,
|
||||
props.modelConfig,
|
||||
props.packageManager,
|
||||
props.vectorDb,
|
||||
props.llamaCloudKey,
|
||||
props.useLlamaParse,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Create outputs directory
|
||||
await makeDir(path.join(props.root, "output/tools"));
|
||||
await makeDir(path.join(props.root, "output/uploaded"));
|
||||
await makeDir(path.join(props.root, "output/llamacloud"));
|
||||
} else {
|
||||
// this is a frontend for a full-stack app, create .env file with model information
|
||||
await createFrontendEnvFile(props.root, {
|
||||
customApiPath: props.customApiPath,
|
||||
vectorDb: props.vectorDb,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = [
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-2.1",
|
||||
"claude-instant-1.2",
|
||||
];
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
|
||||
// TODO: get embedding vector dimensions from the anthropic sdk (currently not supported)
|
||||
// Use huggingface embedding models for now
|
||||
enum HuggingFaceEmbeddingModelType {
|
||||
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
|
||||
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
|
||||
}
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
|
||||
dimensions: 384,
|
||||
},
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
|
||||
dimensions: 768,
|
||||
},
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type AnthropicQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askAnthropicQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: AnthropicQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["ANTHROPIC_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Anthropic API key (or leave blank to use ANTHROPIC_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions =
|
||||
EMBEDDING_MODELS[
|
||||
embeddingModel as HuggingFaceEmbeddingModelType
|
||||
].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams, ModelConfigQuestionsParams } from ".";
|
||||
import { questionHandlers } from "../../questions";
|
||||
|
||||
const ALL_AZURE_OPENAI_CHAT_MODELS: Record<string, { openAIModel: string }> = {
|
||||
"gpt-35-turbo": { openAIModel: "gpt-3.5-turbo" },
|
||||
"gpt-35-turbo-16k": {
|
||||
openAIModel: "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"gpt-4o": { openAIModel: "gpt-4o" },
|
||||
"gpt-4o-mini": { openAIModel: "gpt-4o-mini" },
|
||||
"gpt-4": { openAIModel: "gpt-4" },
|
||||
"gpt-4-32k": { openAIModel: "gpt-4-32k" },
|
||||
"gpt-4-turbo": {
|
||||
openAIModel: "gpt-4-turbo",
|
||||
},
|
||||
"gpt-4-turbo-2024-04-09": {
|
||||
openAIModel: "gpt-4-turbo",
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
openAIModel: "gpt-4-vision-preview",
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
openAIModel: "gpt-4-1106-preview",
|
||||
},
|
||||
"gpt-4o-2024-05-13": {
|
||||
openAIModel: "gpt-4o-2024-05-13",
|
||||
},
|
||||
"gpt-4o-mini-2024-07-18": {
|
||||
openAIModel: "gpt-4o-mini-2024-07-18",
|
||||
},
|
||||
};
|
||||
|
||||
const ALL_AZURE_OPENAI_EMBEDDING_MODELS: Record<
|
||||
string,
|
||||
{
|
||||
dimensions: number;
|
||||
openAIModel: string;
|
||||
}
|
||||
> = {
|
||||
"text-embedding-3-small": {
|
||||
dimensions: 1536,
|
||||
openAIModel: "text-embedding-3-small",
|
||||
},
|
||||
"text-embedding-3-large": {
|
||||
dimensions: 3072,
|
||||
openAIModel: "text-embedding-3-large",
|
||||
},
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4o";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askAzureQuestions({
|
||||
openAiKey,
|
||||
askModels,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey: openAiKey || process.env.AZURE_OPENAI_KEY,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
// the Azure model provider can't be fully configured as endpoint and deployment names have to be configured with env variables
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: getAvailableModelChoices(),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: getAvailableEmbeddingModelChoices(),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = getDimensions(embeddingModel);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
function getAvailableModelChoices() {
|
||||
return Object.keys(ALL_AZURE_OPENAI_CHAT_MODELS).map((key) => ({
|
||||
title: key,
|
||||
value: key,
|
||||
}));
|
||||
}
|
||||
|
||||
function getAvailableEmbeddingModelChoices() {
|
||||
return Object.keys(ALL_AZURE_OPENAI_EMBEDDING_MODELS).map((key) => ({
|
||||
title: key,
|
||||
value: key,
|
||||
}));
|
||||
}
|
||||
|
||||
function getDimensions(modelName: string) {
|
||||
return ALL_AZURE_OPENAI_EMBEDDING_MODELS[modelName].dimensions;
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"];
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<string, ModelData> = {
|
||||
"embedding-001": { dimensions: 768 },
|
||||
"text-embedding-004": { dimensions: 768 },
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type GeminiQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askGeminiQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: GeminiQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["GOOGLE_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Google API key (or leave blank to use GOOGLE_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.GOOGLE_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
import got from "got";
|
||||
import ora from "ora";
|
||||
import { red } from "picocolors";
|
||||
|
||||
const GROQ_API_URL = "https://api.groq.com/openai/v1";
|
||||
|
||||
async function getAvailableModelChoicesGroq(apiKey: string) {
|
||||
if (!apiKey) {
|
||||
throw new Error("Need Groq API key to retrieve model choices");
|
||||
}
|
||||
|
||||
const spinner = ora("Fetching available models from Groq").start();
|
||||
try {
|
||||
const response = await got(`${GROQ_API_URL}/models`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
timeout: 5000,
|
||||
responseType: "json",
|
||||
});
|
||||
const data: any = await response.body;
|
||||
spinner.stop();
|
||||
|
||||
// Filter out the Whisper models
|
||||
return data.data
|
||||
.filter((model: any) => !model.id.toLowerCase().includes("whisper"))
|
||||
.map((el: any) => {
|
||||
return {
|
||||
title: el.id,
|
||||
value: el.id,
|
||||
};
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
spinner.stop();
|
||||
console.log(error);
|
||||
if ((error as any).response?.statusCode === 401) {
|
||||
console.log(
|
||||
red(
|
||||
"Invalid Groq API key provided! Please provide a valid key and try again!",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(red("Request failed: " + error));
|
||||
}
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_MODEL = "llama3-70b-8192";
|
||||
|
||||
// Use huggingface embedding models for now as Groq doesn't support embedding models
|
||||
enum HuggingFaceEmbeddingModelType {
|
||||
XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
|
||||
XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
|
||||
}
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
|
||||
dimensions: 384,
|
||||
},
|
||||
[HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
|
||||
dimensions: 768,
|
||||
},
|
||||
};
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type GroqQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askGroqQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: GroqQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["GROQ_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.GROQ_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const modelChoices = await getAvailableModelChoicesGroq(config.apiKey!);
|
||||
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: modelChoices,
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions =
|
||||
EMBEDDING_MODELS[
|
||||
embeddingModel as HuggingFaceEmbeddingModelType
|
||||
].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
+18
-42
@@ -1,22 +1,15 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { questionHandlers } from "../../questions";
|
||||
import { ModelConfig, ModelProvider, TemplateFramework } from "../types";
|
||||
import { askAnthropicQuestions } from "./anthropic";
|
||||
import { askAzureQuestions } from "./azure";
|
||||
import { askGeminiQuestions } from "./gemini";
|
||||
import { askGroqQuestions } from "./groq";
|
||||
import { askLLMHubQuestions } from "./llmhub";
|
||||
import { askMistralQuestions } from "./mistral";
|
||||
import { ModelConfig, ModelProvider } from "../types";
|
||||
import { askOllamaQuestions } from "./ollama";
|
||||
import { askOpenAIQuestions } from "./openai";
|
||||
import { askOpenAIQuestions, isOpenAIConfigured } from "./openai";
|
||||
|
||||
const DEFAULT_MODEL_PROVIDER = "openai";
|
||||
|
||||
export type ModelConfigQuestionsParams = {
|
||||
openAiKey?: string;
|
||||
askModels: boolean;
|
||||
framework?: TemplateFramework;
|
||||
};
|
||||
|
||||
export type ModelConfigParams = Omit<ModelConfig, "provider">;
|
||||
@@ -24,29 +17,21 @@ export type ModelConfigParams = Omit<ModelConfig, "provider">;
|
||||
export async function askModelConfig({
|
||||
askModels,
|
||||
openAiKey,
|
||||
framework,
|
||||
}: ModelConfigQuestionsParams): Promise<ModelConfig> {
|
||||
let modelProvider: ModelProvider = DEFAULT_MODEL_PROVIDER;
|
||||
if (askModels && !ciInfo.isCI) {
|
||||
let choices = [
|
||||
{ title: "OpenAI", value: "openai" },
|
||||
{ title: "Groq", value: "groq" },
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
{ title: "Anthropic", value: "anthropic" },
|
||||
{ title: "Gemini", value: "gemini" },
|
||||
{ title: "Mistral", value: "mistral" },
|
||||
{ title: "AzureOpenAI", value: "azure-openai" },
|
||||
];
|
||||
|
||||
if (framework === "fastapi") {
|
||||
choices.push({ title: "T-Systems", value: "t-systems" });
|
||||
}
|
||||
const { provider } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "provider",
|
||||
message: "Which model provider would you like to use",
|
||||
choices: choices,
|
||||
choices: [
|
||||
{
|
||||
title: "OpenAI",
|
||||
value: "openai",
|
||||
},
|
||||
{ title: "Ollama", value: "ollama" },
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
@@ -59,24 +44,6 @@ export async function askModelConfig({
|
||||
case "ollama":
|
||||
modelConfig = await askOllamaQuestions({ askModels });
|
||||
break;
|
||||
case "groq":
|
||||
modelConfig = await askGroqQuestions({ askModels });
|
||||
break;
|
||||
case "anthropic":
|
||||
modelConfig = await askAnthropicQuestions({ askModels });
|
||||
break;
|
||||
case "gemini":
|
||||
modelConfig = await askGeminiQuestions({ askModels });
|
||||
break;
|
||||
case "mistral":
|
||||
modelConfig = await askMistralQuestions({ askModels });
|
||||
break;
|
||||
case "azure-openai":
|
||||
modelConfig = await askAzureQuestions({ askModels });
|
||||
break;
|
||||
case "t-systems":
|
||||
modelConfig = await askLLMHubQuestions({ askModels });
|
||||
break;
|
||||
default:
|
||||
modelConfig = await askOpenAIQuestions({
|
||||
openAiKey,
|
||||
@@ -88,3 +55,12 @@ export async function askModelConfig({
|
||||
provider: modelProvider,
|
||||
};
|
||||
}
|
||||
|
||||
export function isModelConfigured(modelConfig: ModelConfig): boolean {
|
||||
switch (modelConfig.provider) {
|
||||
case "openai":
|
||||
return isOpenAIConfigured(modelConfig);
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import got from "got";
|
||||
import ora from "ora";
|
||||
import { red } from "picocolors";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers } from "../../questions";
|
||||
|
||||
export const TSYSTEMS_LLMHUB_API_URL =
|
||||
"https://llm-server.llmhub.t-systems.net/v2";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-3.5-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
const LLMHUB_MODELS = [
|
||||
"gpt-35-turbo",
|
||||
"gpt-4-32k-1",
|
||||
"gpt-4-32k-canada",
|
||||
"gpt-4-32k-france",
|
||||
"gpt-4-turbo-128k-france",
|
||||
"Llama2-70b-Instruct",
|
||||
"Llama-3-70B-Instruct",
|
||||
"Mixtral-8x7B-Instruct-v0.1",
|
||||
"mistral-large-32k-france",
|
||||
"CodeLlama-2",
|
||||
];
|
||||
const LLMHUB_EMBEDDING_MODELS = [
|
||||
"text-embedding-ada-002",
|
||||
"text-embedding-ada-002-france",
|
||||
"jina-embeddings-v2-base-de",
|
||||
"jina-embeddings-v2-base-code",
|
||||
"text-embedding-bge-m3",
|
||||
];
|
||||
|
||||
type LLMHubQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askLLMHubQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: LLMHubQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["T_SYSTEMS_LLMHUB_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message: askModels
|
||||
? "Please provide your LLMHub API key (or leave blank to use T_SYSTEMS_LLMHUB_API_KEY env variable):"
|
||||
: "Please provide your LLMHub API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
if (askModels && !value) {
|
||||
if (process.env.T_SYSTEMS_LLMHUB_API_KEY) {
|
||||
return true;
|
||||
}
|
||||
return "T_SYSTEMS_LLMHUB_API_KEY env variable is not set - key is required";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.T_SYSTEMS_LLMHUB_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: await getAvailableModelChoices(false, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: await getAvailableModelChoices(true, config.apiKey),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = getDimensions(embeddingModel);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
async function getAvailableModelChoices(
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
throw new Error("Need LLMHub key to retrieve model choices");
|
||||
}
|
||||
const isLLMModel = (modelId: string) => {
|
||||
return LLMHUB_MODELS.includes(modelId);
|
||||
};
|
||||
|
||||
const isEmbeddingModel = (modelId: string) => {
|
||||
return LLMHUB_EMBEDDING_MODELS.includes(modelId);
|
||||
};
|
||||
|
||||
const spinner = ora("Fetching available models").start();
|
||||
try {
|
||||
const response = await got(`${TSYSTEMS_LLMHUB_API_URL}/models`, {
|
||||
headers: {
|
||||
Authorization: "Bearer " + apiKey,
|
||||
},
|
||||
timeout: 5000,
|
||||
responseType: "json",
|
||||
});
|
||||
const data: any = await response.body;
|
||||
spinner.stop();
|
||||
return data.data
|
||||
.filter((model: any) =>
|
||||
selectEmbedding ? isEmbeddingModel(model.id) : isLLMModel(model.id),
|
||||
)
|
||||
.map((el: any) => {
|
||||
return {
|
||||
title: el.id,
|
||||
value: el.id,
|
||||
};
|
||||
});
|
||||
} catch (error) {
|
||||
spinner.stop();
|
||||
if ((error as any).response?.statusCode === 401) {
|
||||
console.log(
|
||||
red(
|
||||
"Invalid LLMHub API key provided! Please provide a valid key and try again!",
|
||||
),
|
||||
);
|
||||
} else {
|
||||
console.log(red("Request failed: " + error));
|
||||
}
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
function getDimensions(modelName: string) {
|
||||
// Assuming dimensions similar to OpenAI for simplicity. Update if different.
|
||||
return modelName === "text-embedding-004" ? 768 : 1536;
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
import ciInfo from "ci-info";
|
||||
import prompts from "prompts";
|
||||
import { ModelConfigParams } from ".";
|
||||
import { questionHandlers, toChoice } from "../../questions";
|
||||
|
||||
const MODELS = ["mistral-tiny", "mistral-small", "mistral-medium"];
|
||||
type ModelData = {
|
||||
dimensions: number;
|
||||
};
|
||||
const EMBEDDING_MODELS: Record<string, ModelData> = {
|
||||
"mistral-embed": { dimensions: 1024 },
|
||||
};
|
||||
|
||||
const DEFAULT_MODEL = MODELS[0];
|
||||
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
|
||||
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
|
||||
|
||||
type MistralQuestionsParams = {
|
||||
apiKey?: string;
|
||||
askModels: boolean;
|
||||
};
|
||||
|
||||
export async function askMistralQuestions({
|
||||
askModels,
|
||||
apiKey,
|
||||
}: MistralQuestionsParams): Promise<ModelConfigParams> {
|
||||
const config: ModelConfigParams = {
|
||||
apiKey,
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: DEFAULT_DIMENSIONS,
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["MISTRAL_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
const { key } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "key",
|
||||
message:
|
||||
"Please provide your Mistral API key (or leave blank to use MISTRAL_API_KEY env variable):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.apiKey = key || process.env.MISTRAL_API_KEY;
|
||||
}
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
const useDefaults = ciInfo.isCI || !askModels;
|
||||
if (!useDefaults) {
|
||||
const { model } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "model",
|
||||
message: "Which LLM model would you like to use?",
|
||||
choices: MODELS.map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.model = model;
|
||||
|
||||
const { embeddingModel } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
name: "embeddingModel",
|
||||
message: "Which embedding model would you like to use?",
|
||||
choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
|
||||
initial: 0,
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
config.embeddingModel = embeddingModel;
|
||||
config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
@@ -29,9 +29,6 @@ export async function askOllamaQuestions({
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL].dimensions,
|
||||
isConfigured(): boolean {
|
||||
return true;
|
||||
},
|
||||
};
|
||||
|
||||
// use default model values in CI or if user should not be asked
|
||||
|
||||
+12
-10
@@ -8,7 +8,7 @@ import { questionHandlers } from "../../questions";
|
||||
|
||||
const OPENAI_API_URL = "https://api.openai.com/v1";
|
||||
|
||||
const DEFAULT_MODEL = "gpt-4o-mini";
|
||||
const DEFAULT_MODEL = "gpt-4-turbo";
|
||||
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
|
||||
|
||||
export async function askOpenAIQuestions({
|
||||
@@ -20,15 +20,6 @@ export async function askOpenAIQuestions({
|
||||
model: DEFAULT_MODEL,
|
||||
embeddingModel: DEFAULT_EMBEDDING_MODEL,
|
||||
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
|
||||
isConfigured(): boolean {
|
||||
if (config.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["OPENAI_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
if (!config.apiKey) {
|
||||
@@ -40,6 +31,7 @@ export async function askOpenAIQuestions({
|
||||
? "Please provide your OpenAI API key (or leave blank to use OPENAI_API_KEY env variable):"
|
||||
: "Please provide your OpenAI API key (leave blank to skip):",
|
||||
validate: (value: string) => {
|
||||
console.log(value);
|
||||
if (askModels && !value) {
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
return true;
|
||||
@@ -86,6 +78,16 @@ export async function askOpenAIQuestions({
|
||||
return config;
|
||||
}
|
||||
|
||||
export function isOpenAIConfigured(params: ModelConfigParams): boolean {
|
||||
if (params.apiKey) {
|
||||
return true;
|
||||
}
|
||||
if (process.env["OPENAI_API_KEY"]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async function getAvailableModelChoices(
|
||||
selectEmbedding: boolean,
|
||||
apiKey?: string,
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
/* Function to conditionally load the global-agent/bootstrap module */
|
||||
export async function initializeGlobalAgent() {
|
||||
if (process.env.GLOBAL_AGENT_HTTP_PROXY) {
|
||||
/* Dynamically import global-agent/bootstrap */
|
||||
await import("global-agent/bootstrap");
|
||||
console.log("Proxy enabled via global-agent.");
|
||||
}
|
||||
}
|
||||
+67
-222
@@ -12,7 +12,6 @@ import {
|
||||
InstallTemplateArgs,
|
||||
ModelConfig,
|
||||
TemplateDataSource,
|
||||
TemplateType,
|
||||
TemplateVectorDB,
|
||||
} from "./types";
|
||||
|
||||
@@ -25,9 +24,8 @@ interface Dependency {
|
||||
const getAdditionalDependencies = (
|
||||
modelConfig: ModelConfig,
|
||||
vectorDb?: TemplateVectorDB,
|
||||
dataSources?: TemplateDataSource[],
|
||||
dataSource?: TemplateDataSource,
|
||||
tools?: Tool[],
|
||||
templateType?: TemplateType,
|
||||
) => {
|
||||
const dependencies: Dependency[] = [];
|
||||
|
||||
@@ -36,109 +34,76 @@ const getAdditionalDependencies = (
|
||||
case "mongo": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-mongodb",
|
||||
version: "^0.3.1",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "pg": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-postgres",
|
||||
version: "^0.2.5",
|
||||
version: "^0.1.1",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "pinecone": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-pinecone",
|
||||
version: "^0.2.1",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "milvus": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-milvus",
|
||||
version: "^0.2.0",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymilvus",
|
||||
version: "2.4.4",
|
||||
version: "2.3.7",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "astra": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-astra-db",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "qdrant": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-qdrant",
|
||||
version: "^0.3.0",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "chroma": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-chroma",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "weaviate": {
|
||||
dependencies.push({
|
||||
name: "llama-index-vector-stores-weaviate",
|
||||
version: "^1.1.1",
|
||||
version: "^0.1.5",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add data source dependencies
|
||||
if (dataSources) {
|
||||
for (const ds of dataSources) {
|
||||
const dsType = ds?.type;
|
||||
switch (dsType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.2.2",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
case "llamacloud":
|
||||
dependencies.push({
|
||||
name: "llama-index-indices-managed-llama-cloud",
|
||||
version: "^0.3.1",
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
const dataSourceType = dataSource?.type;
|
||||
switch (dataSourceType) {
|
||||
case "file":
|
||||
dependencies.push({
|
||||
name: "docx2txt",
|
||||
version: "^0.8",
|
||||
});
|
||||
break;
|
||||
case "web":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-web",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
break;
|
||||
case "db":
|
||||
dependencies.push({
|
||||
name: "llama-index-readers-database",
|
||||
version: "^0.1.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "pymysql",
|
||||
version: "^1.1.0",
|
||||
extras: ["rsa"],
|
||||
});
|
||||
dependencies.push({
|
||||
name: "psycopg2",
|
||||
version: "^2.9.9",
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
// Add tools dependencies
|
||||
console.log("Adding tools dependencies");
|
||||
tools?.forEach((tool) => {
|
||||
tool.dependencies?.forEach((dep) => {
|
||||
dependencies.push(dep);
|
||||
@@ -149,99 +114,17 @@ const getAdditionalDependencies = (
|
||||
case "ollama":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-ollama",
|
||||
version: "0.3.0",
|
||||
version: "0.1.2",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-ollama",
|
||||
version: "0.3.0",
|
||||
version: "0.1.2",
|
||||
});
|
||||
break;
|
||||
case "openai":
|
||||
if (templateType !== "multiagent") {
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-openai",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-openai",
|
||||
version: "^0.2.3",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "^0.3.0",
|
||||
});
|
||||
}
|
||||
break;
|
||||
case "groq":
|
||||
// Fastembed==0.2.0 does not support python3.13 at the moment
|
||||
// Fixed the python version less than 3.13
|
||||
dependencies.push({
|
||||
name: "python",
|
||||
version: "^3.11,<3.13",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-groq",
|
||||
version: "0.2.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-fastembed",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
// Fastembed==0.2.0 does not support python3.13 at the moment
|
||||
// Fixed the python version less than 3.13
|
||||
dependencies.push({
|
||||
name: "python",
|
||||
version: "^3.11,<3.13",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-anthropic",
|
||||
version: "0.3.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-fastembed",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
break;
|
||||
case "gemini":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-gemini",
|
||||
version: "0.3.4",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-gemini",
|
||||
version: "^0.2.0",
|
||||
});
|
||||
break;
|
||||
case "mistral":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-mistralai",
|
||||
version: "0.2.1",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-mistralai",
|
||||
version: "0.2.0",
|
||||
});
|
||||
break;
|
||||
case "azure-openai":
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-azure-openai",
|
||||
version: "0.2.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-embeddings-azure-openai",
|
||||
version: "0.2.4",
|
||||
});
|
||||
break;
|
||||
case "t-systems":
|
||||
dependencies.push({
|
||||
name: "llama-index-agent-openai",
|
||||
version: "0.3.0",
|
||||
});
|
||||
dependencies.push({
|
||||
name: "llama-index-llms-openai-like",
|
||||
version: "0.2.0",
|
||||
version: "0.2.2",
|
||||
});
|
||||
break;
|
||||
}
|
||||
@@ -251,7 +134,7 @@ const getAdditionalDependencies = (
|
||||
|
||||
const mergePoetryDependencies = (
|
||||
dependencies: Dependency[],
|
||||
existingDependencies: Record<string, Omit<Dependency, "name"> | string>,
|
||||
existingDependencies: Record<string, Omit<Dependency, "name">>,
|
||||
) => {
|
||||
for (const dependency of dependencies) {
|
||||
let value = existingDependencies[dependency.name] ?? {};
|
||||
@@ -270,13 +153,7 @@ const mergePoetryDependencies = (
|
||||
);
|
||||
}
|
||||
|
||||
// Serialize separately only if extras are provided
|
||||
if (value.extras && value.extras.length > 0) {
|
||||
existingDependencies[dependency.name] = value;
|
||||
} else {
|
||||
// Otherwise, serialize just the version string
|
||||
existingDependencies[dependency.name] = value.version;
|
||||
}
|
||||
existingDependencies[dependency.name] = value;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -380,75 +257,43 @@ export const installPythonTemplate = async ({
|
||||
cwd: path.join(compPath, "vectordbs", "python", vectorDb ?? "none"),
|
||||
});
|
||||
|
||||
if (vectorDb !== "llamacloud") {
|
||||
// Copy all loaders to enginePath
|
||||
// Not needed for LlamaCloud as it has its own loaders
|
||||
const loaderPath = path.join(enginePath, "loaders");
|
||||
await copy("**", loaderPath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "loaders", "python"),
|
||||
});
|
||||
}
|
||||
|
||||
// Copy settings.py to app
|
||||
await copy("**", path.join(root, "app"), {
|
||||
cwd: path.join(compPath, "settings", "python"),
|
||||
// Copy all loaders to enginePath
|
||||
const loaderPath = path.join(enginePath, "loaders");
|
||||
await copy("**", loaderPath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "loaders", "python"),
|
||||
});
|
||||
|
||||
// Copy services
|
||||
if (template == "streaming" || template == "multiagent") {
|
||||
await copy("**", path.join(root, "app", "api", "services"), {
|
||||
cwd: path.join(compPath, "services", "python"),
|
||||
});
|
||||
// Select and copy engine code based on data sources and tools
|
||||
let engine;
|
||||
tools = tools ?? [];
|
||||
if (dataSources.length > 0 && tools.length === 0) {
|
||||
console.log("\nNo tools selected - use optimized context chat engine\n");
|
||||
engine = "chat";
|
||||
} else {
|
||||
engine = "agent";
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
});
|
||||
|
||||
if (template === "streaming") {
|
||||
// For the streaming template only:
|
||||
// Select and copy engine code based on data sources and tools
|
||||
let engine;
|
||||
if (dataSources.length > 0 && (!tools || tools.length === 0)) {
|
||||
console.log("\nNo tools selected - use optimized context chat engine\n");
|
||||
engine = "chat";
|
||||
} else {
|
||||
engine = "agent";
|
||||
}
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "engines", "python", engine),
|
||||
const addOnDependencies = dataSources
|
||||
.map((ds) => getAdditionalDependencies(modelConfig, vectorDb, ds, tools))
|
||||
.flat();
|
||||
|
||||
if (observability === "opentelemetry") {
|
||||
addOnDependencies.push({
|
||||
name: "traceloop-sdk",
|
||||
version: "^0.15.11",
|
||||
});
|
||||
}
|
||||
|
||||
console.log("Adding additional dependencies");
|
||||
|
||||
const addOnDependencies = getAdditionalDependencies(
|
||||
modelConfig,
|
||||
vectorDb,
|
||||
dataSources,
|
||||
tools,
|
||||
template,
|
||||
);
|
||||
|
||||
if (observability && observability !== "none") {
|
||||
if (observability === "traceloop") {
|
||||
addOnDependencies.push({
|
||||
name: "traceloop-sdk",
|
||||
version: "^0.15.11",
|
||||
});
|
||||
}
|
||||
|
||||
if (observability === "llamatrace") {
|
||||
addOnDependencies.push({
|
||||
name: "llama-index-callbacks-arize-phoenix",
|
||||
version: "^0.1.6",
|
||||
});
|
||||
}
|
||||
|
||||
const templateObservabilityPath = path.join(
|
||||
templatesDir,
|
||||
"components",
|
||||
"observability",
|
||||
"python",
|
||||
observability,
|
||||
"opentelemetry",
|
||||
);
|
||||
await copy("**", path.join(root, "app"), {
|
||||
cwd: templateObservabilityPath,
|
||||
|
||||
+48
-59
@@ -23,77 +23,66 @@ const createProcess = (
|
||||
});
|
||||
};
|
||||
|
||||
export function runReflexApp(
|
||||
appPath: string,
|
||||
frontendPort?: number,
|
||||
backendPort?: number,
|
||||
) {
|
||||
const commandArgs = ["run", "reflex", "run"];
|
||||
if (frontendPort) {
|
||||
commandArgs.push("--frontend-port", frontendPort.toString());
|
||||
}
|
||||
if (backendPort) {
|
||||
commandArgs.push("--backend-port", backendPort.toString());
|
||||
}
|
||||
return createProcess("poetry", commandArgs, {
|
||||
stdio: "inherit",
|
||||
cwd: appPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function runFastAPIApp(appPath: string, port: number) {
|
||||
const commandArgs = ["run", "uvicorn", "main:app", "--port=" + port];
|
||||
|
||||
return createProcess("poetry", commandArgs, {
|
||||
stdio: "inherit",
|
||||
cwd: appPath,
|
||||
});
|
||||
}
|
||||
|
||||
export function runTSApp(appPath: string, port: number) {
|
||||
return createProcess("npm", ["run", "dev"], {
|
||||
stdio: "inherit",
|
||||
cwd: appPath,
|
||||
env: { ...process.env, PORT: `${port}` },
|
||||
});
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export async function runApp(
|
||||
appPath: string,
|
||||
template: string,
|
||||
frontend: boolean,
|
||||
framework: TemplateFramework,
|
||||
port?: number,
|
||||
externalPort?: number,
|
||||
): Promise<any> {
|
||||
const processes: ChildProcess[] = [];
|
||||
let backendAppProcess: ChildProcess;
|
||||
let frontendAppProcess: ChildProcess | undefined;
|
||||
const frontendPort = port || 3000;
|
||||
let backendPort = externalPort || 8000;
|
||||
|
||||
// Callback to kill all sub processes if the main process is killed
|
||||
// Callback to kill app processes
|
||||
process.on("exit", () => {
|
||||
console.log("Killing app processes...");
|
||||
processes.forEach((p) => p.kill());
|
||||
backendAppProcess.kill();
|
||||
frontendAppProcess?.kill();
|
||||
});
|
||||
|
||||
// Default sub app paths
|
||||
const backendPath = path.join(appPath, "backend");
|
||||
const frontendPath = path.join(appPath, "frontend");
|
||||
|
||||
if (template === "extractor") {
|
||||
processes.push(runReflexApp(appPath, port, externalPort));
|
||||
}
|
||||
if (template === "streaming" || template === "multiagent") {
|
||||
if (framework === "fastapi" || framework === "express") {
|
||||
const backendRunner = framework === "fastapi" ? runFastAPIApp : runTSApp;
|
||||
if (frontend) {
|
||||
processes.push(backendRunner(backendPath, externalPort || 8000));
|
||||
processes.push(runTSApp(frontendPath, port || 3000));
|
||||
} else {
|
||||
processes.push(backendRunner(appPath, externalPort || 8000));
|
||||
}
|
||||
} else if (framework === "nextjs") {
|
||||
processes.push(runTSApp(appPath, port || 3000));
|
||||
}
|
||||
let backendCommand = "";
|
||||
let backendArgs: string[];
|
||||
if (framework === "fastapi") {
|
||||
backendCommand = "poetry";
|
||||
backendArgs = [
|
||||
"run",
|
||||
"uvicorn",
|
||||
"main:app",
|
||||
"--host=0.0.0.0",
|
||||
"--port=" + backendPort,
|
||||
];
|
||||
} else if (framework === "nextjs") {
|
||||
backendCommand = "npm";
|
||||
backendArgs = ["run", "dev"];
|
||||
backendPort = frontendPort;
|
||||
} else {
|
||||
backendCommand = "npm";
|
||||
backendArgs = ["run", "dev"];
|
||||
}
|
||||
|
||||
return Promise.all(processes);
|
||||
if (frontend) {
|
||||
return new Promise((resolve, reject) => {
|
||||
backendAppProcess = createProcess(backendCommand, backendArgs, {
|
||||
stdio: "inherit",
|
||||
cwd: path.join(appPath, "backend"),
|
||||
env: { ...process.env, PORT: `${backendPort}` },
|
||||
});
|
||||
frontendAppProcess = createProcess("npm", ["run", "dev"], {
|
||||
stdio: "inherit",
|
||||
cwd: path.join(appPath, "frontend"),
|
||||
env: { ...process.env, PORT: `${frontendPort}` },
|
||||
});
|
||||
});
|
||||
} else {
|
||||
return new Promise((resolve, reject) => {
|
||||
backendAppProcess = createProcess(backendCommand, backendArgs, {
|
||||
stdio: "inherit",
|
||||
cwd: path.join(appPath),
|
||||
env: { ...process.env, PORT: `${backendPort}` },
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
+7
-192
@@ -2,25 +2,15 @@ import fs from "fs/promises";
|
||||
import path from "path";
|
||||
import { red } from "picocolors";
|
||||
import yaml from "yaml";
|
||||
import { EnvVar } from "./env-variables";
|
||||
import { makeDir } from "./make-dir";
|
||||
import { TemplateFramework } from "./types";
|
||||
|
||||
export const TOOL_SYSTEM_PROMPT_ENV_VAR = "TOOL_SYSTEM_PROMPT";
|
||||
|
||||
export enum ToolType {
|
||||
LLAMAHUB = "llamahub",
|
||||
LOCAL = "local",
|
||||
}
|
||||
|
||||
export type Tool = {
|
||||
display: string;
|
||||
name: string;
|
||||
config?: Record<string, any>;
|
||||
dependencies?: ToolDependencies[];
|
||||
supportedFrameworks?: Array<TemplateFramework>;
|
||||
type: ToolType;
|
||||
envVars?: EnvVar[];
|
||||
};
|
||||
|
||||
export type ToolDependencies = {
|
||||
@@ -30,7 +20,7 @@ export type ToolDependencies = {
|
||||
|
||||
export const supportedTools: Tool[] = [
|
||||
{
|
||||
display: "Google Search",
|
||||
display: "Google Search (configuration required after installation)",
|
||||
name: "google.GoogleSearchToolSpec",
|
||||
config: {
|
||||
engine:
|
||||
@@ -41,41 +31,10 @@ export const supportedTools: Tool[] = [
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-google",
|
||||
version: "^0.2.0",
|
||||
version: "0.1.2",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for google search tool.",
|
||||
value: `You are a Google search agent. You help users to get information from Google search.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
// For python app, we will use a local DuckDuckGo search tool (instead of DuckDuckGo search tool in LlamaHub)
|
||||
// to get the same results as the TS app.
|
||||
display: "DuckDuckGo Search",
|
||||
name: "duckduckgo",
|
||||
dependencies: [
|
||||
{
|
||||
name: "duckduckgo-search",
|
||||
version: "6.1.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "nextjs", "express"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for DuckDuckGo search tool.",
|
||||
value: `You are a DuckDuckGo search agent.
|
||||
You can use the duckduckgo search tool to get information from the web to answer user questions.
|
||||
For better results, you can specify the region parameter to get results from a specific region but it's optional.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Wikipedia",
|
||||
@@ -83,139 +42,10 @@ For better results, you can specify the region parameter to get results from a s
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-wikipedia",
|
||||
version: "^0.2.0",
|
||||
version: "0.1.2",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for wiki tool.",
|
||||
value: `You are a Wikipedia agent. You help users to get information from Wikipedia.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Weather",
|
||||
name: "weather",
|
||||
dependencies: [],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for weather tool.",
|
||||
value: `You are a weather forecast agent. You help users to get the weather forecast for a given location.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Code Interpreter",
|
||||
name: "interpreter",
|
||||
dependencies: [
|
||||
{
|
||||
name: "e2b_code_interpreter",
|
||||
version: "0.0.7",
|
||||
},
|
||||
],
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "E2B_API_KEY",
|
||||
description:
|
||||
"E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for code interpreter tool.",
|
||||
value: `-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "OpenAPI action",
|
||||
name: "openapi_action.OpenAPIActionToolSpec",
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-openapi",
|
||||
version: "0.2.0",
|
||||
},
|
||||
{
|
||||
name: "jsonschema",
|
||||
version: "^4.22.0",
|
||||
},
|
||||
{
|
||||
name: "llama-index-tools-requests",
|
||||
version: "0.2.0",
|
||||
},
|
||||
],
|
||||
config: {
|
||||
openapi_uri: "The URL or file path of the OpenAPI schema",
|
||||
},
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for openapi action tool.",
|
||||
value:
|
||||
"You are an OpenAPI action agent. You help users to make requests to the provided OpenAPI schema.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Image Generator",
|
||||
name: "img_gen",
|
||||
supportedFrameworks: ["fastapi", "express", "nextjs"],
|
||||
type: ToolType.LOCAL,
|
||||
envVars: [
|
||||
{
|
||||
name: "STABILITY_API_KEY",
|
||||
description:
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for image generator tool.",
|
||||
value: `You are an image generator agent. You help users to generate images using the Stability API.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
display: "Azure Code Interpreter",
|
||||
name: "azure_code_interpreter.AzureCodeInterpreterToolSpec",
|
||||
supportedFrameworks: ["fastapi", "nextjs", "express"],
|
||||
type: ToolType.LLAMAHUB,
|
||||
dependencies: [
|
||||
{
|
||||
name: "llama-index-tools-azure-code-interpreter",
|
||||
version: "0.2.0",
|
||||
},
|
||||
],
|
||||
envVars: [
|
||||
{
|
||||
name: "AZURE_POOL_MANAGEMENT_ENDPOINT",
|
||||
description:
|
||||
"Please follow this guideline to create and get the pool management endpoint: https://learn.microsoft.com/azure/container-apps/sessions?tabs=azure-cli",
|
||||
},
|
||||
{
|
||||
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
|
||||
description: "System prompt for Azure code interpreter tool.",
|
||||
value: `-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.`,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -242,15 +72,9 @@ export const getTools = (toolsName: string[]): Tool[] => {
|
||||
return tools;
|
||||
};
|
||||
|
||||
export const toolRequiresConfig = (tool: Tool): boolean => {
|
||||
const hasConfig = Object.keys(tool.config || {}).length > 0;
|
||||
const hasEmptyEnvVar = tool.envVars?.some((envVar) => !envVar.value) ?? false;
|
||||
return hasConfig || hasEmptyEnvVar;
|
||||
};
|
||||
|
||||
export const toolsRequireConfig = (tools?: Tool[]): boolean => {
|
||||
if (tools) {
|
||||
return tools?.some(toolRequiresConfig);
|
||||
return tools?.some((tool) => Object.keys(tool.config || {}).length > 0);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@@ -265,19 +89,10 @@ export const writeToolsConfig = async (
|
||||
tools: Tool[] = [],
|
||||
type: ConfigFileType = ConfigFileType.YAML,
|
||||
) => {
|
||||
const configContent: {
|
||||
[key in ToolType]: Record<string, any>;
|
||||
} = {
|
||||
local: {},
|
||||
llamahub: {},
|
||||
};
|
||||
if (tools.length === 0) return; // no tools selected, no config need
|
||||
const configContent: Record<string, any> = {};
|
||||
tools.forEach((tool) => {
|
||||
if (tool.type === ToolType.LLAMAHUB) {
|
||||
configContent.llamahub[tool.name] = tool.config ?? {};
|
||||
}
|
||||
if (tool.type === ToolType.LOCAL) {
|
||||
configContent.local[tool.name] = tool.config ?? {};
|
||||
}
|
||||
configContent[tool.name] = tool.config ?? {};
|
||||
});
|
||||
const configPath = path.join(root, "config");
|
||||
await makeDir(configPath);
|
||||
|
||||
+5
-22
@@ -1,29 +1,15 @@
|
||||
import { PackageManager } from "../helpers/get-pkg-manager";
|
||||
import { Tool } from "./tools";
|
||||
|
||||
export type ModelProvider =
|
||||
| "openai"
|
||||
| "groq"
|
||||
| "ollama"
|
||||
| "anthropic"
|
||||
| "gemini"
|
||||
| "mistral"
|
||||
| "azure-openai"
|
||||
| "t-systems";
|
||||
export type ModelProvider = "openai" | "ollama";
|
||||
export type ModelConfig = {
|
||||
provider: ModelProvider;
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
embeddingModel: string;
|
||||
dimensions: number;
|
||||
isConfigured(): boolean;
|
||||
};
|
||||
export type TemplateType =
|
||||
| "extractor"
|
||||
| "streaming"
|
||||
| "community"
|
||||
| "llamapack"
|
||||
| "multiagent";
|
||||
export type TemplateType = "streaming" | "community" | "llamapack";
|
||||
export type TemplateFramework = "nextjs" | "express" | "fastapi";
|
||||
export type TemplateUI = "html" | "shadcn";
|
||||
export type TemplateVectorDB =
|
||||
@@ -33,10 +19,7 @@ export type TemplateVectorDB =
|
||||
| "pinecone"
|
||||
| "milvus"
|
||||
| "astra"
|
||||
| "qdrant"
|
||||
| "chroma"
|
||||
| "llamacloud"
|
||||
| "weaviate";
|
||||
| "qdrant";
|
||||
export type TemplatePostInstallAction =
|
||||
| "none"
|
||||
| "VSCode"
|
||||
@@ -46,8 +29,8 @@ export type TemplateDataSource = {
|
||||
type: TemplateDataSourceType;
|
||||
config: TemplateDataSourceConfig;
|
||||
};
|
||||
export type TemplateDataSourceType = "file" | "web" | "db" | "llamacloud";
|
||||
export type TemplateObservability = "none" | "traceloop" | "llamatrace";
|
||||
export type TemplateDataSourceType = "file" | "web" | "db";
|
||||
export type TemplateObservability = "none" | "opentelemetry";
|
||||
// Config for both file and folder
|
||||
export type FileSourceConfig = {
|
||||
path: string;
|
||||
|
||||
+6
-54
@@ -1,7 +1,7 @@
|
||||
import fs from "fs/promises";
|
||||
import os from "os";
|
||||
import path from "path";
|
||||
import { bold, cyan, yellow } from "picocolors";
|
||||
import { bold, cyan } from "picocolors";
|
||||
import { assetRelocator, copy } from "../helpers/copy";
|
||||
import { callPackageManager } from "../helpers/install";
|
||||
import { templatesDir } from "./dir";
|
||||
@@ -33,8 +33,7 @@ export const installTSTemplate = async ({
|
||||
* Copy the template files to the target directory.
|
||||
*/
|
||||
console.log("\nInitializing project with template:", template, "\n");
|
||||
const type = template === "multiagent" ? "streaming" : template; // use nextjs streaming template for multiagent
|
||||
const templatePath = path.join(templatesDir, "types", type, framework);
|
||||
const templatePath = path.join(templatesDir, "types", template, framework);
|
||||
const copySource = ["**"];
|
||||
|
||||
await copy(copySource, root, {
|
||||
@@ -71,7 +70,7 @@ export const installTSTemplate = async ({
|
||||
);
|
||||
|
||||
const webpackConfigOtelFile = path.join(root, "webpack.config.o11y.mjs");
|
||||
if (observability === "traceloop") {
|
||||
if (observability === "opentelemetry") {
|
||||
const webpackConfigDefaultFile = path.join(root, "webpack.config.mjs");
|
||||
await fs.rm(webpackConfigDefaultFile);
|
||||
await fs.rename(webpackConfigOtelFile, webpackConfigDefaultFile);
|
||||
@@ -105,20 +104,8 @@ export const installTSTemplate = async ({
|
||||
: path.join("src", "controllers");
|
||||
const enginePath = path.join(root, relativeEngineDestPath, "engine");
|
||||
|
||||
// copy llamaindex code for TS templates
|
||||
await copy("**", path.join(root, relativeEngineDestPath, "llamaindex"), {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "llamaindex", "typescript"),
|
||||
});
|
||||
|
||||
// copy vector db component
|
||||
if (vectorDb === "llamacloud") {
|
||||
console.log(
|
||||
`\nUsing managed index from LlamaCloud. Ensure the ${yellow("LLAMA_CLOUD_* environment variables are set correctly.")}`,
|
||||
);
|
||||
} else {
|
||||
console.log("\nUsing vector DB:", vectorDb ?? "none");
|
||||
}
|
||||
console.log("\nUsing vector DB:", vectorDb, "\n");
|
||||
await copy("**", enginePath, {
|
||||
parents: true,
|
||||
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
|
||||
@@ -180,7 +167,6 @@ export const installTSTemplate = async ({
|
||||
framework,
|
||||
ui,
|
||||
observability,
|
||||
vectorDb,
|
||||
});
|
||||
|
||||
if (postInstallAction === "runApp" || postInstallAction === "dependencies") {
|
||||
@@ -201,16 +187,9 @@ async function updatePackageJson({
|
||||
framework,
|
||||
ui,
|
||||
observability,
|
||||
vectorDb,
|
||||
}: Pick<
|
||||
InstallTemplateArgs,
|
||||
| "root"
|
||||
| "appName"
|
||||
| "dataSources"
|
||||
| "framework"
|
||||
| "ui"
|
||||
| "observability"
|
||||
| "vectorDb"
|
||||
"root" | "appName" | "dataSources" | "framework" | "ui" | "observability"
|
||||
> & {
|
||||
relativeEngineDestPath: string;
|
||||
}): Promise<any> {
|
||||
@@ -257,34 +236,7 @@ async function updatePackageJson({
|
||||
};
|
||||
}
|
||||
|
||||
if (vectorDb === "pg") {
|
||||
packageJson.dependencies = {
|
||||
...packageJson.dependencies,
|
||||
pg: "^8.12.0",
|
||||
};
|
||||
}
|
||||
|
||||
if (vectorDb === "qdrant") {
|
||||
packageJson.dependencies = {
|
||||
...packageJson.dependencies,
|
||||
"@qdrant/js-client-rest": "^1.11.0",
|
||||
};
|
||||
}
|
||||
if (vectorDb === "mongo") {
|
||||
packageJson.dependencies = {
|
||||
...packageJson.dependencies,
|
||||
mongodb: "^6.7.0",
|
||||
};
|
||||
}
|
||||
|
||||
if (vectorDb === "milvus") {
|
||||
packageJson.dependencies = {
|
||||
...packageJson.dependencies,
|
||||
"@zilliz/milvus2-sdk-node": "^2.4.6",
|
||||
};
|
||||
}
|
||||
|
||||
if (observability === "traceloop") {
|
||||
if (observability === "opentelemetry") {
|
||||
packageJson.dependencies = {
|
||||
...packageJson.dependencies,
|
||||
"@traceloop/node-server-sdk": "^0.5.19",
|
||||
|
||||
@@ -9,19 +9,15 @@ import prompts from "prompts";
|
||||
import terminalLink from "terminal-link";
|
||||
import checkForUpdate from "update-check";
|
||||
import { createApp } from "./create-app";
|
||||
import { EXAMPLE_FILE, getDataSources } from "./helpers/datasources";
|
||||
import { getDataSources } from "./helpers/datasources";
|
||||
import { getPkgManager } from "./helpers/get-pkg-manager";
|
||||
import { isFolderEmpty } from "./helpers/is-folder-empty";
|
||||
import { initializeGlobalAgent } from "./helpers/proxy";
|
||||
import { runApp } from "./helpers/run-app";
|
||||
import { getTools } from "./helpers/tools";
|
||||
import { validateNpmName } from "./helpers/validate-pkg";
|
||||
import packageJson from "./package.json";
|
||||
import { QuestionArgs, askQuestions, onPromptState } from "./questions";
|
||||
|
||||
// Run the initialization function
|
||||
initializeGlobalAgent();
|
||||
|
||||
let projectPath: string = "";
|
||||
|
||||
const handleSigTerm = () => process.exit(0);
|
||||
@@ -90,20 +86,6 @@ const program = new Commander.Command(packageJson.name)
|
||||
`
|
||||
|
||||
Select to use an example PDF as data source.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
"--web-source <url>",
|
||||
`
|
||||
|
||||
Specify a website URL to use as a data source.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
"--db-source <connection-string>",
|
||||
`
|
||||
|
||||
Specify a database connection string to use as a data source.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
@@ -187,14 +169,7 @@ const program = new Commander.Command(packageJson.name)
|
||||
"--ask-models",
|
||||
`
|
||||
|
||||
Allow interactive selection of LLM and embedding models of different model providers.
|
||||
`,
|
||||
)
|
||||
.option(
|
||||
"--ask-examples",
|
||||
`
|
||||
|
||||
Allow interactive selection of community templates and LlamaPacks.
|
||||
Select LLM and embedding models.
|
||||
`,
|
||||
)
|
||||
.allowUnknownOption()
|
||||
@@ -209,47 +184,14 @@ if (process.argv.includes("--tools")) {
|
||||
program.tools = getTools(program.tools.split(","));
|
||||
}
|
||||
}
|
||||
if (
|
||||
process.argv.includes("--no-llama-parse") ||
|
||||
program.template === "extractor"
|
||||
) {
|
||||
if (process.argv.includes("--no-llama-parse")) {
|
||||
program.useLlamaParse = false;
|
||||
}
|
||||
program.askModels = process.argv.includes("--ask-models");
|
||||
program.askExamples = process.argv.includes("--ask-examples");
|
||||
if (process.argv.includes("--no-files")) {
|
||||
program.dataSources = [];
|
||||
} else if (process.argv.includes("--example-file")) {
|
||||
} else {
|
||||
program.dataSources = getDataSources(program.files, program.exampleFile);
|
||||
} else if (process.argv.includes("--llamacloud")) {
|
||||
program.dataSources = [
|
||||
{
|
||||
type: "llamacloud",
|
||||
config: {},
|
||||
},
|
||||
EXAMPLE_FILE,
|
||||
];
|
||||
} else if (process.argv.includes("--web-source")) {
|
||||
program.dataSources = [
|
||||
{
|
||||
type: "web",
|
||||
config: {
|
||||
baseUrl: program.webSource,
|
||||
prefix: program.webSource,
|
||||
depth: 1,
|
||||
},
|
||||
},
|
||||
];
|
||||
} else if (process.argv.includes("--db-source")) {
|
||||
program.dataSources = [
|
||||
{
|
||||
type: "db",
|
||||
config: {
|
||||
uri: program.dbSource,
|
||||
queries: program.dbQuery || "SELECT * FROM mytable",
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
const packageManager = !!program.useNpm
|
||||
@@ -387,7 +329,6 @@ Please check ${cyan(
|
||||
console.log(`Running app in ${root}...`);
|
||||
await runApp(
|
||||
root,
|
||||
program.template,
|
||||
program.frontend,
|
||||
program.framework,
|
||||
program.port,
|
||||
|
||||
+2
-3
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "create-llama",
|
||||
"version": "0.2.10",
|
||||
"version": "0.1.0",
|
||||
"description": "Create LlamaIndex-powered apps with one command",
|
||||
"keywords": [
|
||||
"rag",
|
||||
@@ -9,7 +9,7 @@
|
||||
],
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/run-llama/create-llama",
|
||||
"url": "https://github.com/run-llama/LlamaIndexTS",
|
||||
"directory": "packages/create-llama"
|
||||
},
|
||||
"license": "MIT",
|
||||
@@ -52,7 +52,6 @@
|
||||
"cross-spawn": "7.0.3",
|
||||
"fast-glob": "3.3.1",
|
||||
"fs-extra": "11.2.0",
|
||||
"global-agent": "^3.0.0",
|
||||
"got": "10.7.0",
|
||||
"ollama": "^0.5.0",
|
||||
"ora": "^8.0.1",
|
||||
|
||||
Generated
+147
-267
File diff suppressed because it is too large
Load Diff
+68
-162
@@ -9,26 +9,20 @@ import {
|
||||
TemplateDataSource,
|
||||
TemplateDataSourceType,
|
||||
TemplateFramework,
|
||||
TemplateType,
|
||||
} from "./helpers";
|
||||
import { COMMUNITY_OWNER, COMMUNITY_REPO } from "./helpers/constant";
|
||||
import { EXAMPLE_FILE } from "./helpers/datasources";
|
||||
import { templatesDir } from "./helpers/dir";
|
||||
import { getAvailableLlamapackOptions } from "./helpers/llama-pack";
|
||||
import { askModelConfig } from "./helpers/providers";
|
||||
import { askModelConfig, isModelConfigured } from "./helpers/providers";
|
||||
import { getProjectOptions } from "./helpers/repo";
|
||||
import {
|
||||
supportedTools,
|
||||
toolRequiresConfig,
|
||||
toolsRequireConfig,
|
||||
} from "./helpers/tools";
|
||||
import { supportedTools, toolsRequireConfig } from "./helpers/tools";
|
||||
|
||||
export type QuestionArgs = Omit<
|
||||
InstallAppArgs,
|
||||
"appPath" | "packageManager"
|
||||
> & {
|
||||
askModels?: boolean;
|
||||
askExamples?: boolean;
|
||||
};
|
||||
const supportedContextFileTypes = [
|
||||
".pdf",
|
||||
@@ -103,8 +97,6 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
|
||||
{ title: "Milvus", value: "milvus" },
|
||||
{ title: "Astra", value: "astra" },
|
||||
{ title: "Qdrant", value: "qdrant" },
|
||||
{ title: "ChromaDB", value: "chroma" },
|
||||
{ title: "Weaviate", value: "weaviate" },
|
||||
];
|
||||
|
||||
const vectordbLang = framework === "fastapi" ? "python" : "typescript";
|
||||
@@ -125,15 +117,8 @@ const getVectorDbChoices = (framework: TemplateFramework) => {
|
||||
export const getDataSourceChoices = (
|
||||
framework: TemplateFramework,
|
||||
selectedDataSource: TemplateDataSource[],
|
||||
template?: TemplateType,
|
||||
) => {
|
||||
// If LlamaCloud is already selected, don't show any other options
|
||||
if (selectedDataSource.find((s) => s.type === "llamacloud")) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const choices = [];
|
||||
|
||||
if (selectedDataSource.length > 0) {
|
||||
choices.push({
|
||||
title: "No",
|
||||
@@ -141,39 +126,31 @@ export const getDataSourceChoices = (
|
||||
});
|
||||
}
|
||||
if (selectedDataSource === undefined || selectedDataSource.length === 0) {
|
||||
if (template !== "multiagent") {
|
||||
choices.push({
|
||||
title: "No datasource",
|
||||
value: "none",
|
||||
});
|
||||
}
|
||||
choices.push({
|
||||
title:
|
||||
process.platform !== "linux"
|
||||
? "Use an example PDF"
|
||||
: "Use an example PDF (you can add your own data files later)",
|
||||
title: "No data, just a simple chat or agent",
|
||||
value: "none",
|
||||
});
|
||||
choices.push({
|
||||
title: "Use an example PDF",
|
||||
value: "exampleFile",
|
||||
});
|
||||
}
|
||||
|
||||
// Linux has many distros so we won't support file/folder picker for now
|
||||
if (process.platform !== "linux") {
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
}
|
||||
choices.push(
|
||||
{
|
||||
title: `Use local files (${supportedContextFileTypes.join(", ")})`,
|
||||
value: "file",
|
||||
},
|
||||
{
|
||||
title:
|
||||
process.platform === "win32"
|
||||
? "Use a local folder"
|
||||
: "Use local folders",
|
||||
value: "folder",
|
||||
},
|
||||
);
|
||||
|
||||
if (framework === "fastapi" && template !== "extractor") {
|
||||
if (framework === "fastapi") {
|
||||
choices.push({
|
||||
title: "Use website content (requires Chrome)",
|
||||
value: "web",
|
||||
@@ -183,13 +160,6 @@ export const getDataSourceChoices = (
|
||||
value: "db",
|
||||
});
|
||||
}
|
||||
|
||||
if (!selectedDataSource.length && template !== "extractor") {
|
||||
choices.push({
|
||||
title: "Use managed index from LlamaCloud",
|
||||
value: "llamacloud",
|
||||
});
|
||||
}
|
||||
return choices;
|
||||
};
|
||||
|
||||
@@ -287,8 +257,7 @@ export const askQuestions = async (
|
||||
},
|
||||
];
|
||||
|
||||
const modelConfigured =
|
||||
!program.llamapack && program.modelConfig.isConfigured();
|
||||
const modelConfigured = isModelConfigured(program.modelConfig);
|
||||
// If using LlamaParse, require LlamaCloud API key
|
||||
const llamaCloudKeyConfigured = program.useLlamaParse
|
||||
? program.llamaCloudKey || process.env["LLAMA_CLOUD_API_KEY"]
|
||||
@@ -299,7 +268,8 @@ export const askQuestions = async (
|
||||
!hasVectorDb &&
|
||||
modelConfigured &&
|
||||
llamaCloudKeyConfigured &&
|
||||
!toolsRequireConfig(program.tools)
|
||||
!toolsRequireConfig(program.tools) &&
|
||||
!program.llamapack
|
||||
) {
|
||||
actionChoices.push({
|
||||
title:
|
||||
@@ -337,24 +307,15 @@ export const askQuestions = async (
|
||||
name: "template",
|
||||
message: "Which template would you like to use?",
|
||||
choices: [
|
||||
{ title: "Agentic RAG (e.g. chat with docs)", value: "streaming" },
|
||||
{ title: "Chat", value: "streaming" },
|
||||
{
|
||||
title: "Multi-agent app (using workflows)",
|
||||
value: "multiagent",
|
||||
title: `Community template from ${styledRepo}`,
|
||||
value: "community",
|
||||
},
|
||||
{
|
||||
title: "Example using a LlamaPack",
|
||||
value: "llamapack",
|
||||
},
|
||||
{ title: "Structured Extractor", value: "extractor" },
|
||||
...(program.askExamples
|
||||
? [
|
||||
{
|
||||
title: `Community template from ${styledRepo}`,
|
||||
value: "community",
|
||||
},
|
||||
{
|
||||
title: "Example using a LlamaPack",
|
||||
value: "llamapack",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
@@ -410,15 +371,6 @@ export const askQuestions = async (
|
||||
return; // early return - no further questions needed for llamapack projects
|
||||
}
|
||||
|
||||
if (program.template === "multiagent") {
|
||||
// TODO: multi-agents currently only supports FastAPI
|
||||
program.framework = preferences.framework = "fastapi";
|
||||
} else if (program.template === "extractor") {
|
||||
// Extractor template only supports FastAPI, empty data sources, and llamacloud
|
||||
// So we just use example file for extractor template, this allows user to choose vector database later
|
||||
program.dataSources = [EXAMPLE_FILE];
|
||||
program.framework = preferences.framework = "fastapi";
|
||||
}
|
||||
if (!program.framework) {
|
||||
if (ciInfo.isCI) {
|
||||
program.framework = getPrefOrDefault("framework");
|
||||
@@ -444,11 +396,9 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(program.framework === "express" || program.framework === "fastapi") &&
|
||||
(program.template === "streaming" || program.template === "multiagent")
|
||||
) {
|
||||
if (program.framework === "express" || program.framework === "fastapi") {
|
||||
// if a backend-only framework is selected, ask whether we should create a frontend
|
||||
// (only for streaming backends)
|
||||
if (program.frontend === undefined) {
|
||||
if (ciInfo.isCI) {
|
||||
program.frontend = getPrefOrDefault("frontend");
|
||||
@@ -484,7 +434,7 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.observability && program.template === "streaming") {
|
||||
if (!program.observability) {
|
||||
if (ciInfo.isCI) {
|
||||
program.observability = getPrefOrDefault("observability");
|
||||
} else {
|
||||
@@ -495,10 +445,7 @@ export const askQuestions = async (
|
||||
message: "Would you like to set up observability?",
|
||||
choices: [
|
||||
{ title: "No", value: "none" },
|
||||
...(program.framework === "fastapi"
|
||||
? [{ title: "LlamaTrace", value: "llamatrace" }]
|
||||
: []),
|
||||
{ title: "Traceloop", value: "traceloop" },
|
||||
{ title: "OpenTelemetry", value: "opentelemetry" },
|
||||
],
|
||||
initial: 0,
|
||||
},
|
||||
@@ -514,7 +461,6 @@ export const askQuestions = async (
|
||||
const modelConfig = await askModelConfig({
|
||||
openAiKey,
|
||||
askModels: program.askModels ?? false,
|
||||
framework: program.framework,
|
||||
});
|
||||
program.modelConfig = modelConfig;
|
||||
preferences.modelConfig = modelConfig;
|
||||
@@ -528,12 +474,6 @@ export const askQuestions = async (
|
||||
// continue asking user for data sources if none are initially provided
|
||||
while (true) {
|
||||
const firstQuestion = program.dataSources.length === 0;
|
||||
const choices = getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
program.template,
|
||||
);
|
||||
if (choices.length === 0) break;
|
||||
const { selectedSource } = await prompts(
|
||||
{
|
||||
type: "select",
|
||||
@@ -541,7 +481,10 @@ export const askQuestions = async (
|
||||
message: firstQuestion
|
||||
? "Which data source would you like to use?"
|
||||
: "Would you like to add another data source?",
|
||||
choices,
|
||||
choices: getDataSourceChoices(
|
||||
program.framework,
|
||||
program.dataSources,
|
||||
),
|
||||
initial: firstQuestion ? 1 : 0,
|
||||
},
|
||||
questionHandlers,
|
||||
@@ -637,88 +580,52 @@ export const askQuestions = async (
|
||||
type: "db",
|
||||
config: await prompts(dbPrompts, questionHandlers),
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "llamacloud": {
|
||||
program.dataSources.push({
|
||||
type: "llamacloud",
|
||||
config: {},
|
||||
});
|
||||
program.dataSources.push(EXAMPLE_FILE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const isUsingLlamaCloud = program.dataSources.some(
|
||||
(ds) => ds.type === "llamacloud",
|
||||
);
|
||||
// Asking for LlamaParse if user selected file or folder data source
|
||||
if (
|
||||
program.dataSources.some((ds) => ds.type === "file") &&
|
||||
program.useLlamaParse === undefined
|
||||
) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
|
||||
// Asking for LlamaParse if user selected file data source
|
||||
if (isUsingLlamaCloud) {
|
||||
// default to use LlamaParse if using LlamaCloud
|
||||
program.useLlamaParse = preferences.useLlamaParse = true;
|
||||
} else {
|
||||
// Extractor template doesn't support LlamaParse and LlamaCloud right now (cannot use asyncio loop in Reflex)
|
||||
if (
|
||||
program.useLlamaParse === undefined &&
|
||||
program.template !== "extractor"
|
||||
) {
|
||||
// if already set useLlamaParse, don't ask again
|
||||
if (program.dataSources.some((ds) => ds.type === "file")) {
|
||||
if (ciInfo.isCI) {
|
||||
program.useLlamaParse = getPrefOrDefault("useLlamaParse");
|
||||
} else {
|
||||
const { useLlamaParse } = await prompts(
|
||||
{
|
||||
type: "toggle",
|
||||
name: "useLlamaParse",
|
||||
message:
|
||||
"Would you like to use LlamaParse (improved parser for RAG - requires API key)?",
|
||||
initial: false,
|
||||
active: "yes",
|
||||
inactive: "no",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.useLlamaParse = useLlamaParse;
|
||||
preferences.useLlamaParse = useLlamaParse;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ask for LlamaCloud API key when using a LlamaCloud index or LlamaParse
|
||||
if (isUsingLlamaCloud || program.useLlamaParse) {
|
||||
if (!program.llamaCloudKey) {
|
||||
// if already set, don't ask again
|
||||
if (ciInfo.isCI) {
|
||||
program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
|
||||
} else {
|
||||
// Ask for LlamaCloud API key
|
||||
// Ask for LlamaCloud API key
|
||||
if (useLlamaParse && program.llamaCloudKey === undefined) {
|
||||
const { llamaCloudKey } = await prompts(
|
||||
{
|
||||
type: "text",
|
||||
name: "llamaCloudKey",
|
||||
message:
|
||||
"Please provide your LlamaCloud API key (leave blank to skip):",
|
||||
"Please provide your LlamaIndex Cloud API key (leave blank to skip):",
|
||||
},
|
||||
questionHandlers,
|
||||
);
|
||||
program.llamaCloudKey = preferences.llamaCloudKey =
|
||||
llamaCloudKey || process.env.LLAMA_CLOUD_API_KEY;
|
||||
program.llamaCloudKey = llamaCloudKey;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isUsingLlamaCloud) {
|
||||
// When using a LlamaCloud index, don't ask for vector database and use code in `llamacloud` folder for vector database
|
||||
const vectorDb = "llamacloud";
|
||||
program.vectorDb = vectorDb;
|
||||
preferences.vectorDb = vectorDb;
|
||||
} else if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
if (program.dataSources.length > 0 && !program.vectorDb) {
|
||||
if (ciInfo.isCI) {
|
||||
program.vectorDb = getPrefOrDefault("vectorDb");
|
||||
} else {
|
||||
@@ -737,8 +644,7 @@ export const askQuestions = async (
|
||||
}
|
||||
}
|
||||
|
||||
if (!program.tools && program.template === "streaming") {
|
||||
// TODO: allow to select tools also for multi-agent framework
|
||||
if (!program.tools) {
|
||||
if (ciInfo.isCI) {
|
||||
program.tools = getPrefOrDefault("tools");
|
||||
} else {
|
||||
@@ -746,7 +652,7 @@ export const askQuestions = async (
|
||||
t.supportedFrameworks?.includes(program.framework),
|
||||
);
|
||||
const toolChoices = options.map((tool) => ({
|
||||
title: `${tool.display}${toolRequiresConfig(tool) ? " (needs configuration)" : ""}`,
|
||||
title: tool.display,
|
||||
value: tool.name,
|
||||
}));
|
||||
const { toolsName } = await prompts({
|
||||
|
||||
+7
-14
@@ -1,26 +1,20 @@
|
||||
import os
|
||||
|
||||
from app.engine.index import IndexConfig, get_index
|
||||
from app.engine.tools import ToolFactory
|
||||
from llama_index.core.agent import AgentRunner
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.core.agent import AgentRunner
|
||||
from llama_index.core.tools.query_engine import QueryEngineTool
|
||||
from app.engine.tools import ToolFactory
|
||||
from app.engine.index import get_index
|
||||
|
||||
|
||||
def get_chat_engine(filters=None, params=None, event_handlers=None):
|
||||
def get_chat_engine():
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = int(os.getenv("TOP_K", 0))
|
||||
top_k = os.getenv("TOP_K", "3")
|
||||
tools = []
|
||||
callback_manager = CallbackManager(handlers=event_handlers or [])
|
||||
|
||||
# Add query tool if index exists
|
||||
index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
|
||||
index = get_index(index_config)
|
||||
index = get_index()
|
||||
if index is not None:
|
||||
query_engine = index.as_query_engine(
|
||||
filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
|
||||
)
|
||||
query_engine = index.as_query_engine(similarity_top_k=int(top_k))
|
||||
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
|
||||
tools.append(query_engine_tool)
|
||||
|
||||
@@ -31,6 +25,5 @@ def get_chat_engine(filters=None, params=None, event_handlers=None):
|
||||
llm=Settings.llm,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_tool(tool_name: str, **kwargs) -> list[FunctionTool]:
|
||||
try:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"llama_index.tools.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**kwargs)
|
||||
return tool_spec.to_tool_list()
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Unsupported tool: {tool_name}") from e
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
f"Could not create tool: {tool_name}. With config: {kwargs}"
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for name, config in tool_configs.items():
|
||||
tools += ToolFactory.create_tool(name, **config)
|
||||
return tools
|
||||
@@ -1,53 +0,0 @@
|
||||
import os
|
||||
import yaml
|
||||
import importlib
|
||||
from llama_index.core.tools.tool_spec.base import BaseToolSpec
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
class ToolType:
|
||||
LLAMAHUB = "llamahub"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class ToolFactory:
|
||||
TOOL_SOURCE_PACKAGE_MAP = {
|
||||
ToolType.LLAMAHUB: "llama_index.tools",
|
||||
ToolType.LOCAL: "app.engine.tools",
|
||||
}
|
||||
|
||||
def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
|
||||
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
|
||||
try:
|
||||
if "ToolSpec" in tool_name:
|
||||
tool_package, tool_cls_name = tool_name.split(".")
|
||||
module_name = f"{source_package}.{tool_package}"
|
||||
module = importlib.import_module(module_name)
|
||||
tool_class = getattr(module, tool_cls_name)
|
||||
tool_spec: BaseToolSpec = tool_class(**config)
|
||||
return tool_spec.to_tool_list()
|
||||
else:
|
||||
module = importlib.import_module(f"{source_package}.{tool_name}")
|
||||
tools = module.get_tools(**config)
|
||||
if not all(isinstance(tool, FunctionTool) for tool in tools):
|
||||
raise ValueError(
|
||||
f"The module {module} does not contain valid tools"
|
||||
)
|
||||
return tools
|
||||
except ImportError as e:
|
||||
raise ValueError(f"Failed to import tool {tool_name}: {e}")
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Failed to load tool {tool_name}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def from_env() -> list[FunctionTool]:
|
||||
tools = []
|
||||
if os.path.exists("config/tools.yaml"):
|
||||
with open("config/tools.yaml", "r") as f:
|
||||
tool_configs = yaml.safe_load(f)
|
||||
for tool_type, config_entries in tool_configs.items():
|
||||
for tool_name, config in config_entries.items():
|
||||
tools.extend(
|
||||
ToolFactory.load_tools(tool_type, tool_name, config)
|
||||
)
|
||||
return tools
|
||||
@@ -1,36 +0,0 @@
|
||||
from llama_index.core.tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
region: str = "wt-wt",
|
||||
max_results: int = 10,
|
||||
):
|
||||
"""
|
||||
Use this function to search for any query in DuckDuckGo.
|
||||
Args:
|
||||
query (str): The query to search in DuckDuckGo.
|
||||
region Optional(str): The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...
|
||||
max_results Optional(int): The maximum number of results to be returned. Default is 10.
|
||||
"""
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"duckduckgo_search package is required to use this function."
|
||||
"Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`"
|
||||
)
|
||||
|
||||
params = {
|
||||
"keywords": query,
|
||||
"region": region,
|
||||
"max_results": max_results,
|
||||
}
|
||||
results = []
|
||||
with DDGS() as ddg:
|
||||
results = list(ddg.text(**params))
|
||||
return results
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(duckduckgo_search)]
|
||||
@@ -1,108 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorToolOutput(BaseModel):
|
||||
is_success: bool = Field(
|
||||
...,
|
||||
description="Whether the image generation was successful.",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="The URL of the generated image.",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
None,
|
||||
description="The error message if the image generation failed.",
|
||||
)
|
||||
|
||||
|
||||
class ImageGeneratorTool:
|
||||
_IMG_OUTPUT_FORMAT = "webp"
|
||||
_IMG_OUTPUT_DIR = "output/tool"
|
||||
_IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if not api_key:
|
||||
api_key = os.getenv("STABILITY_API_KEY")
|
||||
self._api_key = api_key
|
||||
self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if self._api_key is None:
|
||||
raise ValueError(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
|
||||
)
|
||||
if self.fileserver_url_prefix is None:
|
||||
raise ValueError("FILESERVER_URL_PREFIX is required.")
|
||||
|
||||
def _prepare_output_dir(self):
|
||||
"""
|
||||
Create the output directory if it doesn't exist
|
||||
"""
|
||||
if not os.path.exists(self._IMG_OUTPUT_DIR):
|
||||
os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
def _save_image(self, image_data: bytes):
|
||||
self._prepare_output_dir()
|
||||
filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
|
||||
output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
url = f"{os.getenv('FILESERVER_URL_PREFIX')}/{self._IMG_OUTPUT_DIR}/{filename}"
|
||||
logger.info(f"Saved image to {output_path}.\nURL: {url}")
|
||||
return url
|
||||
|
||||
def _call_stability_api(self, prompt: str):
|
||||
headers = {
|
||||
"authorization": f"Bearer {self._api_key}",
|
||||
"accept": "image/*",
|
||||
}
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"output_format": self._IMG_OUTPUT_FORMAT,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self._IMG_GEN_API,
|
||||
headers=headers,
|
||||
files={"none": ""},
|
||||
data=data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
|
||||
"""
|
||||
Use this tool to generate an image based on the prompt.
|
||||
Args:
|
||||
prompt (str): The prompt to generate the image from.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Call the Stability API
|
||||
response = self._call_stability_api(prompt)
|
||||
|
||||
# Save the image and get the URL
|
||||
image_url = self._save_image(response.content)
|
||||
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=True,
|
||||
image_url=image_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e, exc_info=True)
|
||||
return ImageGeneratorToolOutput(
|
||||
is_success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(ImageGeneratorTool(**kwargs).generate_image)]
|
||||
@@ -1,142 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Optional
|
||||
from llama_index.core.tools import FunctionTool
|
||||
from e2b_code_interpreter import CodeInterpreter
|
||||
from e2b_code_interpreter.models import Logs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterpreterExtraResult(BaseModel):
|
||||
type: str
|
||||
content: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class E2BToolOutput(BaseModel):
|
||||
is_error: bool
|
||||
logs: Logs
|
||||
results: List[InterpreterExtraResult] = []
|
||||
|
||||
|
||||
class E2BCodeInterpreter:
|
||||
output_dir = "output/tool"
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
if api_key is None:
|
||||
api_key = os.getenv("E2B_API_KEY")
|
||||
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key"
|
||||
)
|
||||
if not filesever_url_prefix:
|
||||
raise ValueError(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox"
|
||||
)
|
||||
|
||||
self.filesever_url_prefix = filesever_url_prefix
|
||||
self.interpreter = CodeInterpreter(api_key=api_key)
|
||||
|
||||
def __del__(self):
|
||||
self.interpreter.close()
|
||||
|
||||
def get_output_path(self, filename: str) -> str:
|
||||
# if output directory doesn't exist, create it
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
return os.path.join(self.output_dir, filename)
|
||||
|
||||
def save_to_disk(self, base64_data: str, ext: str) -> Dict:
|
||||
filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename
|
||||
buffer = base64.b64decode(base64_data)
|
||||
output_path = self.get_output_path(filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(buffer)
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write to file {output_path}: {str(e)}")
|
||||
raise e
|
||||
|
||||
logger.info(f"Saved file to {output_path}")
|
||||
|
||||
return {
|
||||
"outputPath": output_path,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
def get_file_url(self, filename: str) -> str:
|
||||
return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}"
|
||||
|
||||
def parse_result(self, result) -> List[InterpreterExtraResult]:
|
||||
"""
|
||||
The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64
|
||||
We save each result to disk and return saved file metadata (extension, filename, url)
|
||||
"""
|
||||
if not result:
|
||||
return []
|
||||
|
||||
output = []
|
||||
|
||||
try:
|
||||
formats = result.formats()
|
||||
results = [result[format] for format in formats]
|
||||
|
||||
for ext, data in zip(formats, results):
|
||||
match ext:
|
||||
case "png" | "svg" | "jpeg" | "pdf":
|
||||
result = self.save_to_disk(data, ext)
|
||||
filename = result["filename"]
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
filename=filename,
|
||||
url=self.get_file_url(filename),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
output.append(
|
||||
InterpreterExtraResult(
|
||||
type=ext,
|
||||
content=data,
|
||||
)
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(error, exc_info=True)
|
||||
logger.error("Error when parsing output from E2b interpreter tool", error)
|
||||
|
||||
return output
|
||||
|
||||
def interpret(self, code: str) -> E2BToolOutput:
|
||||
"""
|
||||
Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error.
|
||||
|
||||
Parameters:
|
||||
code (str): The python code to be executed in a single cell.
|
||||
"""
|
||||
logger.info(
|
||||
f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}"
|
||||
)
|
||||
exec = self.interpreter.notebook.exec_cell(code)
|
||||
|
||||
if exec.error:
|
||||
logger.error("Error when executing code", exec.error)
|
||||
output = E2BToolOutput(is_error=True, logs=exec.logs, results=[])
|
||||
else:
|
||||
if len(exec.results) == 0:
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=[])
|
||||
else:
|
||||
results = self.parse_result(exec.results[0])
|
||||
output = E2BToolOutput(is_error=False, logs=exec.logs, results=results)
|
||||
return output
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)]
|
||||
@@ -1,78 +0,0 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from llama_index.tools.openapi import OpenAPIToolSpec
|
||||
from llama_index.tools.requests import RequestsToolSpec
|
||||
|
||||
|
||||
class OpenAPIActionToolSpec(OpenAPIToolSpec, RequestsToolSpec):
|
||||
"""
|
||||
A combination of OpenAPI and Requests tool specs that can parse OpenAPI specs and make requests.
|
||||
|
||||
openapi_uri: str: The file path or URL to the OpenAPI spec.
|
||||
domain_headers: dict: Whitelist domains and the headers to use.
|
||||
"""
|
||||
|
||||
spec_functions = OpenAPIToolSpec.spec_functions + RequestsToolSpec.spec_functions
|
||||
# Cached parsed specs by URI
|
||||
_specs: Dict[str, Tuple[Dict, List[str]]] = {}
|
||||
|
||||
def __init__(self, openapi_uri: str, domain_headers: dict = None, **kwargs):
|
||||
if domain_headers is None:
|
||||
domain_headers = {}
|
||||
if openapi_uri not in self._specs:
|
||||
openapi_spec, servers = self._load_openapi_spec(openapi_uri)
|
||||
self._specs[openapi_uri] = (openapi_spec, servers)
|
||||
else:
|
||||
openapi_spec, servers = self._specs[openapi_uri]
|
||||
|
||||
# Add the servers to the domain headers if they are not already present
|
||||
for server in servers:
|
||||
if server not in domain_headers:
|
||||
domain_headers[server] = {}
|
||||
|
||||
OpenAPIToolSpec.__init__(self, spec=openapi_spec)
|
||||
RequestsToolSpec.__init__(self, domain_headers)
|
||||
|
||||
@staticmethod
|
||||
def _load_openapi_spec(uri: str) -> Tuple[Dict, List[str]]:
|
||||
"""
|
||||
Load an OpenAPI spec from a URI.
|
||||
|
||||
Args:
|
||||
uri (str): A file path or URL to the OpenAPI spec.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
import yaml
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if uri.startswith("http"):
|
||||
import requests
|
||||
|
||||
response = requests.get(uri)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: "
|
||||
f"Failed to load OpenAPI spec from {uri}, status code: {response.status_code}"
|
||||
)
|
||||
spec = yaml.safe_load(response.text)
|
||||
elif uri.startswith("file"):
|
||||
filepath = urlparse(uri).path
|
||||
with open(filepath, "r") as file:
|
||||
spec = yaml.safe_load(file)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI URI provided. "
|
||||
"Only HTTP and file path are supported."
|
||||
)
|
||||
# Add the servers to the whitelist
|
||||
try:
|
||||
servers = [
|
||||
urlparse(server["url"]).netloc for server in spec.get("servers", [])
|
||||
]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Could not initialize OpenAPIActionToolSpec: Invalid OpenAPI spec provided. "
|
||||
"Could not get `servers` from the spec."
|
||||
) from e
|
||||
return spec, servers
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Open Meteo weather map tool spec."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import pytz
|
||||
from llama_index.core.tools import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenMeteoWeather:
|
||||
geo_api = "https://geocoding-api.open-meteo.com/v1"
|
||||
weather_api = "https://api.open-meteo.com/v1"
|
||||
|
||||
@classmethod
|
||||
def _get_geo_location(cls, location: str) -> dict:
|
||||
"""Get geo location from location name."""
|
||||
params = {"name": location, "count": 10, "language": "en", "format": "json"}
|
||||
response = requests.get(f"{cls.geo_api}/search", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to fetch geo location: {response.status_code}")
|
||||
else:
|
||||
data = response.json()
|
||||
result = data["results"][0]
|
||||
geo_location = {
|
||||
"id": result["id"],
|
||||
"name": result["name"],
|
||||
"latitude": result["latitude"],
|
||||
"longitude": result["longitude"],
|
||||
}
|
||||
return geo_location
|
||||
|
||||
@classmethod
|
||||
def get_weather_information(cls, location: str) -> dict:
|
||||
"""Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
"""
|
||||
logger.info(
|
||||
f"Calling open-meteo api to get weather information of location: {location}"
|
||||
)
|
||||
geo_location = cls._get_geo_location(location)
|
||||
timezone = pytz.timezone("UTC").zone
|
||||
params = {
|
||||
"latitude": geo_location["latitude"],
|
||||
"longitude": geo_location["longitude"],
|
||||
"current": "temperature_2m,weather_code",
|
||||
"hourly": "temperature_2m,weather_code",
|
||||
"daily": "weather_code",
|
||||
"timezone": timezone,
|
||||
}
|
||||
response = requests.get(f"{cls.weather_api}/forecast", params=params)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch weather information: {response.status_code}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_tools(**kwargs):
|
||||
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]
|
||||
+8
-12
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
|
||||
from fastapi import HTTPException
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
from app.engine.index import get_index
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def get_query_engine(output_cls):
|
||||
top_k = int(os.getenv("TOP_K", 0))
|
||||
def get_chat_engine():
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
top_k = os.getenv("TOP_K", 3)
|
||||
|
||||
index = get_index()
|
||||
if index is None:
|
||||
@@ -18,10 +16,8 @@ def get_query_engine(output_cls):
|
||||
),
|
||||
)
|
||||
|
||||
sllm = Settings.llm.as_structured_llm(output_cls)
|
||||
|
||||
return index.as_query_engine(
|
||||
llm=sllm,
|
||||
response_mode="tree_summarize",
|
||||
**({"similarity_top_k": top_k} if top_k != 0 else {}),
|
||||
return index.as_chat_engine(
|
||||
similarity_top_k=int(top_k),
|
||||
system_prompt=system_prompt,
|
||||
chat_mode="condense_plus_context",
|
||||
)
|
||||
@@ -1,48 +0,0 @@
|
||||
import os
|
||||
|
||||
from app.engine.index import IndexConfig, get_index
|
||||
from app.engine.node_postprocessors import NodeCitationProcessor
|
||||
from fastapi import HTTPException
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from llama_index.core.chat_engine import CondensePlusContextChatEngine
|
||||
from llama_index.core.memory import ChatMemoryBuffer
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
def get_chat_engine(filters=None, params=None, event_handlers=None):
|
||||
system_prompt = os.getenv("SYSTEM_PROMPT")
|
||||
citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None)
|
||||
top_k = int(os.getenv("TOP_K", 0))
|
||||
llm = Settings.llm
|
||||
memory = ChatMemoryBuffer.from_defaults(
|
||||
token_limit=llm.metadata.context_window - 256
|
||||
)
|
||||
callback_manager = CallbackManager(handlers=event_handlers or [])
|
||||
|
||||
node_postprocessors = []
|
||||
if citation_prompt:
|
||||
node_postprocessors = [NodeCitationProcessor()]
|
||||
system_prompt = f"{system_prompt}\n{citation_prompt}"
|
||||
|
||||
index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
|
||||
index = get_index(index_config)
|
||||
if index is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(
|
||||
"StorageContext is empty - call 'poetry run generate' to generate the storage first"
|
||||
),
|
||||
)
|
||||
|
||||
retriever = index.as_retriever(
|
||||
filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
|
||||
)
|
||||
|
||||
return CondensePlusContextChatEngine(
|
||||
llm=llm,
|
||||
memory=memory,
|
||||
system_prompt=system_prompt,
|
||||
retriever=retriever,
|
||||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_index.core import QueryBundle
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
|
||||
class NodeCitationProcessor(BaseNodePostprocessor):
|
||||
"""
|
||||
Append node_id into metadata for citation purpose.
|
||||
Config SYSTEM_CITATION_PROMPT in your runtime environment variable to enable this feature.
|
||||
"""
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: List[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> List[NodeWithScore]:
|
||||
for node_score in nodes:
|
||||
node_score.node.metadata["node_id"] = node_score.node.node_id
|
||||
return nodes
|
||||
@@ -1,51 +1,37 @@
|
||||
import {
|
||||
BaseToolWithCall,
|
||||
ChatEngine,
|
||||
OpenAIAgent,
|
||||
QueryEngineTool,
|
||||
} from "llamaindex";
|
||||
import { BaseTool, OpenAIAgent, QueryEngineTool } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { getDataSource } from "./index";
|
||||
import { generateFilters } from "./queryFilter";
|
||||
import { createTools } from "./tools";
|
||||
import { STORAGE_CACHE_DIR } from "./shared";
|
||||
|
||||
export async function createChatEngine(documentIds?: string[], params?: any) {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
export async function createChatEngine() {
|
||||
let tools: BaseTool[] = [];
|
||||
|
||||
// Add a query engine tool if we have a data source
|
||||
// Delete this code if you don't have a data source
|
||||
const index = await getDataSource(params);
|
||||
const index = await getDataSource();
|
||||
if (index) {
|
||||
tools.push(
|
||||
new QueryEngineTool({
|
||||
queryEngine: index.asQueryEngine({
|
||||
preFilters: generateFilters(documentIds || []),
|
||||
}),
|
||||
queryEngine: index.asQueryEngine(),
|
||||
metadata: {
|
||||
name: "data_query_engine",
|
||||
description: `A query engine for documents from your data source.`,
|
||||
description: `A query engine for documents in storage folder: ${STORAGE_CACHE_DIR}`,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
const configFile = path.join("config", "tools.json");
|
||||
let toolConfig: any;
|
||||
try {
|
||||
// add tools from config file if it exists
|
||||
toolConfig = JSON.parse(await fs.readFile(configFile, "utf8"));
|
||||
} catch (e) {
|
||||
console.info(`Could not read ${configFile} file. Using no tools.`);
|
||||
}
|
||||
if (toolConfig) {
|
||||
tools.push(...(await createTools(toolConfig)));
|
||||
}
|
||||
const config = JSON.parse(
|
||||
await fs.readFile(path.join("config", "tools.json"), "utf8"),
|
||||
);
|
||||
tools = tools.concat(await ToolsFactory.createTools(config));
|
||||
} catch {}
|
||||
|
||||
const agent = new OpenAIAgent({
|
||||
return new OpenAIAgent({
|
||||
tools,
|
||||
systemPrompt: process.env.SYSTEM_PROMPT,
|
||||
}) as unknown as ChatEngine;
|
||||
|
||||
return agent;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import { search } from "duck-duck-scrape";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
export type DuckDuckGoParameter = {
|
||||
query: string;
|
||||
region?: string;
|
||||
};
|
||||
|
||||
export type DuckDuckGoToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>> = {
|
||||
name: "duckduckgo",
|
||||
description: "Use this function to search for any query in DuckDuckGo.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
query: {
|
||||
type: "string",
|
||||
description: "The query to search in DuckDuckGo.",
|
||||
},
|
||||
region: {
|
||||
type: "string",
|
||||
description:
|
||||
"Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...",
|
||||
nullable: true,
|
||||
},
|
||||
},
|
||||
required: ["query"],
|
||||
},
|
||||
};
|
||||
|
||||
type DuckDuckGoSearchResult = {
|
||||
title: string;
|
||||
description: string;
|
||||
url: string;
|
||||
};
|
||||
|
||||
export class DuckDuckGoSearchTool implements BaseTool<DuckDuckGoParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<DuckDuckGoParameter>>;
|
||||
|
||||
constructor(params: DuckDuckGoToolParams) {
|
||||
this.metadata = params.metadata ?? DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: DuckDuckGoParameter) {
|
||||
const { query, region } = input;
|
||||
const options = region ? { region } : {};
|
||||
const searchResults = await search(query, options);
|
||||
|
||||
return searchResults.results.map((result) => {
|
||||
return {
|
||||
title: result.title,
|
||||
description: result.description,
|
||||
url: result.url,
|
||||
} as DuckDuckGoSearchResult;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { FormData } from "formdata-node";
|
||||
import fs from "fs";
|
||||
import got from "got";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import path from "node:path";
|
||||
import { Readable } from "stream";
|
||||
|
||||
export type ImgGeneratorParameter = {
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
};
|
||||
|
||||
export type ImgGeneratorToolOutput = {
|
||||
isSuccess: boolean;
|
||||
imageUrl?: string;
|
||||
errorMessage?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = {
|
||||
name: "image_generator",
|
||||
description: `Use this function to generate an image based on the prompt.`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
prompt: {
|
||||
type: "string",
|
||||
description: "The prompt to generate the image",
|
||||
},
|
||||
},
|
||||
required: ["prompt"],
|
||||
},
|
||||
};
|
||||
|
||||
export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> {
|
||||
readonly IMG_OUTPUT_FORMAT = "webp";
|
||||
readonly IMG_OUTPUT_DIR = "output/tool";
|
||||
readonly IMG_GEN_API =
|
||||
"https://api.stability.ai/v2beta/stable-image/generate/core";
|
||||
|
||||
metadata: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
|
||||
|
||||
constructor(params?: ImgGeneratorToolParams) {
|
||||
this.checkRequiredEnvVars();
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: ImgGeneratorParameter): Promise<ImgGeneratorToolOutput> {
|
||||
return await this.generateImage(input.prompt);
|
||||
}
|
||||
|
||||
private generateImage = async (
|
||||
prompt: string,
|
||||
): Promise<ImgGeneratorToolOutput> => {
|
||||
try {
|
||||
const buffer = await this.promptToImgBuffer(prompt);
|
||||
const imageUrl = this.saveImage(buffer);
|
||||
return { isSuccess: true, imageUrl };
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return {
|
||||
isSuccess: false,
|
||||
errorMessage: "Failed to generate image. Please try again.",
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
private promptToImgBuffer = async (prompt: string) => {
|
||||
const form = new FormData();
|
||||
form.append("prompt", prompt);
|
||||
form.append("output_format", this.IMG_OUTPUT_FORMAT);
|
||||
const buffer = await got
|
||||
.post(this.IMG_GEN_API, {
|
||||
// Not sure why it shows an type error when passing form to body
|
||||
// Although I follow document: https://github.com/sindresorhus/got/blob/main/documentation/2-options.md#body
|
||||
// Tt still works fine, so I make casting to unknown to avoid the typescript warning
|
||||
// Found a similar issue: https://github.com/sindresorhus/got/discussions/1877
|
||||
body: form as unknown as Buffer | Readable | string,
|
||||
headers: {
|
||||
Authorization: `Bearer ${process.env.STABILITY_API_KEY}`,
|
||||
Accept: "image/*",
|
||||
},
|
||||
})
|
||||
.buffer();
|
||||
return buffer;
|
||||
};
|
||||
|
||||
private saveImage = (buffer: Buffer) => {
|
||||
const filename = `${crypto.randomUUID()}.${this.IMG_OUTPUT_FORMAT}`;
|
||||
const outputPath = path.join(this.IMG_OUTPUT_DIR, filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
const url = `${process.env.FILESERVER_URL_PREFIX}/${this.IMG_OUTPUT_DIR}/${filename}`;
|
||||
console.log(`Saved image to ${outputPath}.\nURL: ${url}`);
|
||||
return url;
|
||||
};
|
||||
|
||||
private checkRequiredEnvVars = () => {
|
||||
if (!process.env.STABILITY_API_KEY) {
|
||||
throw new Error(
|
||||
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
|
||||
);
|
||||
}
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output after generation",
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
import { BaseToolWithCall } from "llamaindex";
|
||||
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
|
||||
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
|
||||
import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen";
|
||||
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
|
||||
import { OpenAPIActionTool } from "./openapi-action";
|
||||
import { WeatherTool, WeatherToolParams } from "./weather";
|
||||
|
||||
type ToolCreator = (config: unknown) => Promise<BaseToolWithCall[]>;
|
||||
|
||||
export async function createTools(toolConfig: {
|
||||
local: Record<string, unknown>;
|
||||
llamahub: any;
|
||||
}): Promise<BaseToolWithCall[]> {
|
||||
// add local tools from the 'tools' folder (if configured)
|
||||
const tools = await createLocalTools(toolConfig.local);
|
||||
// add tools from LlamaIndexTS (if configured)
|
||||
tools.push(...(await ToolsFactory.createTools(toolConfig.llamahub)));
|
||||
return tools;
|
||||
}
|
||||
|
||||
const toolFactory: Record<string, ToolCreator> = {
|
||||
weather: async (config: unknown) => {
|
||||
return [new WeatherTool(config as WeatherToolParams)];
|
||||
},
|
||||
interpreter: async (config: unknown) => {
|
||||
return [new InterpreterTool(config as InterpreterToolParams)];
|
||||
},
|
||||
"openapi_action.OpenAPIActionToolSpec": async (config: unknown) => {
|
||||
const { openapi_uri, domain_headers } = config as {
|
||||
openapi_uri: string;
|
||||
domain_headers: Record<string, Record<string, string>>;
|
||||
};
|
||||
const openAPIActionTool = new OpenAPIActionTool(
|
||||
openapi_uri,
|
||||
domain_headers,
|
||||
);
|
||||
return await openAPIActionTool.toToolFunctions();
|
||||
},
|
||||
duckduckgo: async (config: unknown) => {
|
||||
return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)];
|
||||
},
|
||||
img_gen: async (config: unknown) => {
|
||||
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
|
||||
},
|
||||
};
|
||||
|
||||
async function createLocalTools(
|
||||
localConfig: Record<string, unknown>,
|
||||
): Promise<BaseToolWithCall[]> {
|
||||
const tools: BaseToolWithCall[] = [];
|
||||
|
||||
for (const [key, toolConfig] of Object.entries(localConfig)) {
|
||||
if (key in toolFactory) {
|
||||
const newTools = await toolFactory[key](toolConfig);
|
||||
tools.push(...newTools);
|
||||
}
|
||||
}
|
||||
|
||||
return tools;
|
||||
}
|
||||
@@ -1,189 +0,0 @@
|
||||
import { CodeInterpreter, Logs, Result } from "@e2b/code-interpreter";
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import fs from "fs";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
import crypto from "node:crypto";
|
||||
import path from "node:path";
|
||||
|
||||
export type InterpreterParameter = {
|
||||
code: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
apiKey?: string;
|
||||
fileServerURLPrefix?: string;
|
||||
};
|
||||
|
||||
export type InterpreterToolOutput = {
|
||||
isError: boolean;
|
||||
logs: Logs;
|
||||
extraResult: InterpreterExtraResult[];
|
||||
};
|
||||
|
||||
type InterpreterExtraType =
|
||||
| "html"
|
||||
| "markdown"
|
||||
| "svg"
|
||||
| "png"
|
||||
| "jpeg"
|
||||
| "pdf"
|
||||
| "latex"
|
||||
| "json"
|
||||
| "javascript";
|
||||
|
||||
export type InterpreterExtraResult = {
|
||||
type: InterpreterExtraType;
|
||||
content?: string;
|
||||
filename?: string;
|
||||
url?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = {
|
||||
name: "interpreter",
|
||||
description:
|
||||
"Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
code: {
|
||||
type: "string",
|
||||
description: "The python code to execute in a single cell.",
|
||||
},
|
||||
},
|
||||
required: ["code"],
|
||||
},
|
||||
};
|
||||
|
||||
export class InterpreterTool implements BaseTool<InterpreterParameter> {
|
||||
private readonly outputDir = "output/tool";
|
||||
private apiKey?: string;
|
||||
private fileServerURLPrefix?: string;
|
||||
metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>;
|
||||
codeInterpreter?: CodeInterpreter;
|
||||
|
||||
constructor(params?: InterpreterToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
this.apiKey = params?.apiKey || process.env.E2B_API_KEY;
|
||||
this.fileServerURLPrefix =
|
||||
params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX;
|
||||
|
||||
if (!this.apiKey) {
|
||||
throw new Error(
|
||||
"E2B_API_KEY key is required to run code interpreter. Get it here: https://e2b.dev/docs/getting-started/api-key",
|
||||
);
|
||||
}
|
||||
if (!this.fileServerURLPrefix) {
|
||||
throw new Error(
|
||||
"FILESERVER_URL_PREFIX is required to display file output from sandbox",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public async initInterpreter() {
|
||||
if (!this.codeInterpreter) {
|
||||
this.codeInterpreter = await CodeInterpreter.create({
|
||||
apiKey: this.apiKey,
|
||||
});
|
||||
}
|
||||
return this.codeInterpreter;
|
||||
}
|
||||
|
||||
public async codeInterpret(code: string): Promise<InterpreterToolOutput> {
|
||||
console.log(
|
||||
`\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`,
|
||||
);
|
||||
const interpreter = await this.initInterpreter();
|
||||
const exec = await interpreter.notebook.execCell(code);
|
||||
if (exec.error) console.error("[Code Interpreter error]", exec.error);
|
||||
const extraResult = await this.getExtraResult(exec.results[0]);
|
||||
const result: InterpreterToolOutput = {
|
||||
isError: !!exec.error,
|
||||
logs: exec.logs,
|
||||
extraResult,
|
||||
};
|
||||
return result;
|
||||
}
|
||||
|
||||
async call(input: InterpreterParameter): Promise<InterpreterToolOutput> {
|
||||
const result = await this.codeInterpret(input.code);
|
||||
return result;
|
||||
}
|
||||
|
||||
async close() {
|
||||
await this.codeInterpreter?.close();
|
||||
}
|
||||
|
||||
private async getExtraResult(
|
||||
res?: Result,
|
||||
): Promise<InterpreterExtraResult[]> {
|
||||
if (!res) return [];
|
||||
const output: InterpreterExtraResult[] = [];
|
||||
|
||||
try {
|
||||
const formats = res.formats(); // formats available for the result. Eg: ['png', ...]
|
||||
const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format
|
||||
|
||||
// save base64 data to file and return the url
|
||||
for (let i = 0; i < formats.length; i++) {
|
||||
const ext = formats[i];
|
||||
const data = results[i];
|
||||
switch (ext) {
|
||||
case "png":
|
||||
case "jpeg":
|
||||
case "svg":
|
||||
case "pdf":
|
||||
const { filename } = this.saveToDisk(data, ext);
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
filename,
|
||||
url: this.getFileUrl(filename),
|
||||
});
|
||||
break;
|
||||
default:
|
||||
output.push({
|
||||
type: ext as InterpreterExtraType,
|
||||
content: data,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error when parsing e2b response", error);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Consider saving to cloud storage instead but it may cost more for you
|
||||
// See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file
|
||||
private saveToDisk(
|
||||
base64Data: string,
|
||||
ext: string,
|
||||
): {
|
||||
outputPath: string;
|
||||
filename: string;
|
||||
} {
|
||||
const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename
|
||||
const buffer = Buffer.from(base64Data, "base64");
|
||||
const outputPath = this.getOutputPath(filename);
|
||||
fs.writeFileSync(outputPath, buffer);
|
||||
console.log(`Saved file to ${outputPath}`);
|
||||
return {
|
||||
outputPath,
|
||||
filename,
|
||||
};
|
||||
}
|
||||
|
||||
private getOutputPath(filename: string): string {
|
||||
// if outputDir doesn't exist, create it
|
||||
if (!fs.existsSync(this.outputDir)) {
|
||||
fs.mkdirSync(this.outputDir, { recursive: true });
|
||||
}
|
||||
return path.join(this.outputDir, filename);
|
||||
}
|
||||
|
||||
private getFileUrl(filename: string): string {
|
||||
return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`;
|
||||
}
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
import SwaggerParser from "@apidevtools/swagger-parser";
|
||||
import { JSONSchemaType } from "ajv";
|
||||
import got from "got";
|
||||
import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface DomainHeaders {
|
||||
[key: string]: { [header: string]: string };
|
||||
}
|
||||
|
||||
type Input = {
|
||||
url: string;
|
||||
params: object;
|
||||
};
|
||||
|
||||
type APIInfo = {
|
||||
description: string;
|
||||
title: string;
|
||||
};
|
||||
|
||||
export class OpenAPIActionTool {
|
||||
// cache the loaded specs by URL
|
||||
private static specs: Record<string, any> = {};
|
||||
|
||||
private readonly INVALID_URL_PROMPT =
|
||||
"This url did not include a hostname or scheme. Please determine the complete URL and try again.";
|
||||
|
||||
private createLoadSpecMetaData = (info: APIInfo) => {
|
||||
return {
|
||||
name: "load_openapi_spec",
|
||||
description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`,
|
||||
};
|
||||
};
|
||||
|
||||
private readonly createMethodCallMetaData = (
|
||||
method: "POST" | "PATCH" | "GET",
|
||||
info: APIInfo,
|
||||
) => {
|
||||
return {
|
||||
name: `${method.toLowerCase()}_request`,
|
||||
description: `Use this to call the ${method} method on the API named ${info.title}`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
url: {
|
||||
type: "string",
|
||||
description: `The url to make the ${method} request against`,
|
||||
},
|
||||
params: {
|
||||
type: "object",
|
||||
description:
|
||||
method === "GET"
|
||||
? "the URL parameters to provide with the get request"
|
||||
: `the key-value pairs to provide with the ${method} request`,
|
||||
},
|
||||
},
|
||||
required: ["url"],
|
||||
},
|
||||
} as ToolMetadata<JSONSchemaType<Input>>;
|
||||
};
|
||||
|
||||
constructor(
|
||||
public openapi_uri: string,
|
||||
public domainHeaders: DomainHeaders = {},
|
||||
) {}
|
||||
|
||||
async loadOpenapiSpec(url: string): Promise<any> {
|
||||
const api = await SwaggerParser.validate(url);
|
||||
return {
|
||||
servers: "servers" in api ? api.servers : "",
|
||||
info: { description: api.info.description, title: api.info.title },
|
||||
endpoints: api.paths,
|
||||
};
|
||||
}
|
||||
|
||||
async getRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const data = await got
|
||||
.get(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
searchParams: input.params as URLSearchParams,
|
||||
})
|
||||
.json();
|
||||
return data as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async postRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.post(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
async patchRequest(input: Input): Promise<JSONValue> {
|
||||
if (!this.validUrl(input.url)) {
|
||||
return this.INVALID_URL_PROMPT;
|
||||
}
|
||||
try {
|
||||
const res = await got.patch(input.url, {
|
||||
headers: this.getHeadersForUrl(input.url),
|
||||
json: input.params,
|
||||
});
|
||||
return res.body as JSONValue;
|
||||
} catch (error) {
|
||||
return error as JSONValue;
|
||||
}
|
||||
}
|
||||
|
||||
public async toToolFunctions() {
|
||||
if (!OpenAPIActionTool.specs[this.openapi_uri]) {
|
||||
console.log(`Loading spec for URL: ${this.openapi_uri}`);
|
||||
const spec = await this.loadOpenapiSpec(this.openapi_uri);
|
||||
OpenAPIActionTool.specs[this.openapi_uri] = spec;
|
||||
}
|
||||
const spec = OpenAPIActionTool.specs[this.openapi_uri];
|
||||
// TODO: read endpoints with parameters from spec and create one tool for each endpoint
|
||||
// For now, we just create a tool for each HTTP method which does not work well for passing parameters
|
||||
return [
|
||||
FunctionTool.from(() => {
|
||||
return spec;
|
||||
}, this.createLoadSpecMetaData(spec.info)),
|
||||
FunctionTool.from(
|
||||
this.getRequest.bind(this),
|
||||
this.createMethodCallMetaData("GET", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.postRequest.bind(this),
|
||||
this.createMethodCallMetaData("POST", spec.info),
|
||||
),
|
||||
FunctionTool.from(
|
||||
this.patchRequest.bind(this),
|
||||
this.createMethodCallMetaData("PATCH", spec.info),
|
||||
),
|
||||
];
|
||||
}
|
||||
|
||||
private validUrl(url: string): boolean {
|
||||
const parsed = new URL(url);
|
||||
return !!parsed.protocol && !!parsed.hostname;
|
||||
}
|
||||
|
||||
private getDomain(url: string): string {
|
||||
const parsed = new URL(url);
|
||||
return parsed.hostname;
|
||||
}
|
||||
|
||||
private getHeadersForUrl(url: string): { [header: string]: string } {
|
||||
const domain = this.getDomain(url);
|
||||
return this.domainHeaders[domain] || {};
|
||||
}
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
import type { JSONSchemaType } from "ajv";
|
||||
import { BaseTool, ToolMetadata } from "llamaindex";
|
||||
|
||||
interface GeoLocation {
|
||||
id: string;
|
||||
name: string;
|
||||
latitude: number;
|
||||
longitude: number;
|
||||
}
|
||||
|
||||
export type WeatherParameter = {
|
||||
location: string;
|
||||
};
|
||||
|
||||
export type WeatherToolParams = {
|
||||
metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
};
|
||||
|
||||
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = {
|
||||
name: "get_weather_information",
|
||||
description: `
|
||||
Use this function to get the weather of any given location.
|
||||
Note that the weather code should follow WMO Weather interpretation codes (WW):
|
||||
0: Clear sky
|
||||
1, 2, 3: Mainly clear, partly cloudy, and overcast
|
||||
45, 48: Fog and depositing rime fog
|
||||
51, 53, 55: Drizzle: Light, moderate, and dense intensity
|
||||
56, 57: Freezing Drizzle: Light and dense intensity
|
||||
61, 63, 65: Rain: Slight, moderate and heavy intensity
|
||||
66, 67: Freezing Rain: Light and heavy intensity
|
||||
71, 73, 75: Snow fall: Slight, moderate, and heavy intensity
|
||||
77: Snow grains
|
||||
80, 81, 82: Rain showers: Slight, moderate, and violent
|
||||
85, 86: Snow showers slight and heavy
|
||||
95: Thunderstorm: Slight or moderate
|
||||
96, 99: Thunderstorm with slight and heavy hail
|
||||
`,
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The location to get the weather information",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
};
|
||||
|
||||
export class WeatherTool implements BaseTool<WeatherParameter> {
|
||||
metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>;
|
||||
|
||||
private getGeoLocation = async (location: string): Promise<GeoLocation> => {
|
||||
const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
const { id, name, latitude, longitude } = data.results[0];
|
||||
return { id, name, latitude, longitude };
|
||||
};
|
||||
|
||||
private getWeatherByLocation = async (location: string) => {
|
||||
console.log(
|
||||
"Calling open-meteo api to get weather information of location:",
|
||||
location,
|
||||
);
|
||||
const { latitude, longitude } = await this.getGeoLocation(location);
|
||||
const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}¤t=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`;
|
||||
const response = await fetch(apiUrl);
|
||||
const data = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
constructor(params?: WeatherToolParams) {
|
||||
this.metadata = params?.metadata || DEFAULT_META_DATA;
|
||||
}
|
||||
|
||||
async call(input: WeatherParameter) {
|
||||
return await this.getWeatherByLocation(input.location);
|
||||
}
|
||||
}
|
||||
@@ -1,32 +1,20 @@
|
||||
import { ContextChatEngine, Settings } from "llamaindex";
|
||||
import { getDataSource } from "./index";
|
||||
import { nodeCitationProcessor } from "./nodePostprocessors";
|
||||
import { generateFilters } from "./queryFilter";
|
||||
|
||||
export async function createChatEngine(documentIds?: string[], params?: any) {
|
||||
const index = await getDataSource(params);
|
||||
export async function createChatEngine() {
|
||||
const index = await getDataSource();
|
||||
if (!index) {
|
||||
throw new Error(
|
||||
`StorageContext is empty - call 'npm run generate' to generate the storage first`,
|
||||
);
|
||||
}
|
||||
const retriever = index.asRetriever({
|
||||
similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined,
|
||||
filters: generateFilters(documentIds || []),
|
||||
});
|
||||
|
||||
const systemPrompt = process.env.SYSTEM_PROMPT;
|
||||
const citationPrompt = process.env.SYSTEM_CITATION_PROMPT;
|
||||
const prompt =
|
||||
[systemPrompt, citationPrompt].filter((p) => p).join("\n") || undefined;
|
||||
const nodePostprocessors = citationPrompt
|
||||
? [nodeCitationProcessor]
|
||||
: undefined;
|
||||
const retriever = index.asRetriever();
|
||||
retriever.similarityTopK = process.env.TOP_K
|
||||
? parseInt(process.env.TOP_K)
|
||||
: 3;
|
||||
|
||||
return new ContextChatEngine({
|
||||
chatModel: Settings.llm,
|
||||
retriever,
|
||||
systemPrompt: prompt,
|
||||
nodePostprocessors,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import {
|
||||
BaseNodePostprocessor,
|
||||
MessageContent,
|
||||
NodeWithScore,
|
||||
} from "llamaindex";
|
||||
|
||||
class NodeCitationProcessor implements BaseNodePostprocessor {
|
||||
/**
|
||||
* Append node_id into metadata for citation purpose.
|
||||
* Config SYSTEM_CITATION_PROMPT in your runtime environment variable to enable this feature.
|
||||
*/
|
||||
async postprocessNodes(
|
||||
nodes: NodeWithScore[],
|
||||
query?: MessageContent,
|
||||
): Promise<NodeWithScore[]> {
|
||||
for (const nodeScore of nodes) {
|
||||
if (!nodeScore.node || !nodeScore.node.metadata) {
|
||||
continue; // Skip nodes with missing properties
|
||||
}
|
||||
nodeScore.node.metadata["node_id"] = nodeScore.node.id_;
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
}
|
||||
|
||||
export const nodeCitationProcessor = new NodeCitationProcessor();
|
||||
@@ -1,63 +0,0 @@
|
||||
import fs from "fs";
|
||||
import { getExtractors } from "../../engine/loader";
|
||||
|
||||
const MIME_TYPE_TO_EXT: Record<string, string> = {
|
||||
"application/pdf": "pdf",
|
||||
"text/plain": "txt",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
"docx",
|
||||
};
|
||||
|
||||
const UPLOADED_FOLDER = "output/uploaded";
|
||||
|
||||
export async function storeAndParseFile(
|
||||
filename: string,
|
||||
fileBuffer: Buffer,
|
||||
mimeType: string,
|
||||
) {
|
||||
const documents = await loadDocuments(fileBuffer, mimeType);
|
||||
await saveDocument(filename, fileBuffer, mimeType);
|
||||
for (const document of documents) {
|
||||
document.metadata = {
|
||||
...document.metadata,
|
||||
file_name: filename,
|
||||
private: "true", // to separate private uploads from public documents
|
||||
};
|
||||
}
|
||||
return documents;
|
||||
}
|
||||
|
||||
async function loadDocuments(fileBuffer: Buffer, mimeType: string) {
|
||||
const extractors = getExtractors();
|
||||
const reader = extractors[MIME_TYPE_TO_EXT[mimeType]];
|
||||
|
||||
if (!reader) {
|
||||
throw new Error(`Unsupported document type: ${mimeType}`);
|
||||
}
|
||||
console.log(`Processing uploaded document of type: ${mimeType}`);
|
||||
return await reader.loadDataAsContent(fileBuffer);
|
||||
}
|
||||
|
||||
async function saveDocument(
|
||||
filename: string,
|
||||
fileBuffer: Buffer,
|
||||
mimeType: string,
|
||||
) {
|
||||
const fileExt = MIME_TYPE_TO_EXT[mimeType];
|
||||
if (!fileExt) throw new Error(`Unsupported document type: ${mimeType}`);
|
||||
|
||||
const filepath = `${UPLOADED_FOLDER}/${filename}`;
|
||||
const fileurl = `${process.env.FILESERVER_URL_PREFIX}/${filepath}`;
|
||||
|
||||
if (!fs.existsSync(UPLOADED_FOLDER)) {
|
||||
fs.mkdirSync(UPLOADED_FOLDER, { recursive: true });
|
||||
}
|
||||
await fs.promises.writeFile(filepath, fileBuffer);
|
||||
|
||||
console.log(`Saved document file to ${filepath}.\nURL: ${fileurl}`);
|
||||
return {
|
||||
filename,
|
||||
filepath,
|
||||
fileurl,
|
||||
};
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
import {
|
||||
Document,
|
||||
IngestionPipeline,
|
||||
Settings,
|
||||
SimpleNodeParser,
|
||||
VectorStoreIndex,
|
||||
} from "llamaindex";
|
||||
|
||||
export async function runPipeline(
|
||||
currentIndex: VectorStoreIndex,
|
||||
documents: Document[],
|
||||
) {
|
||||
// Use ingestion pipeline to process the documents into nodes and add them to the vector store
|
||||
const pipeline = new IngestionPipeline({
|
||||
transformations: [
|
||||
new SimpleNodeParser({
|
||||
chunkSize: Settings.chunkSize,
|
||||
chunkOverlap: Settings.chunkOverlap,
|
||||
}),
|
||||
Settings.embedModel,
|
||||
],
|
||||
});
|
||||
const nodes = await pipeline.run({ documents });
|
||||
await currentIndex.insertNodes(nodes);
|
||||
currentIndex.storageContext.docStore.persist();
|
||||
console.log("Added nodes to the vector store.");
|
||||
return documents.map((document) => document.id_);
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
import { LLamaCloudFileService, VectorStoreIndex } from "llamaindex";
|
||||
import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
|
||||
import { storeAndParseFile } from "./helper";
|
||||
import { runPipeline } from "./pipeline";
|
||||
|
||||
export async function uploadDocument(
|
||||
index: VectorStoreIndex | LlamaCloudIndex,
|
||||
filename: string,
|
||||
raw: string,
|
||||
): Promise<string[]> {
|
||||
const [header, content] = raw.split(",");
|
||||
const mimeType = header.replace("data:", "").replace(";base64", "");
|
||||
const fileBuffer = Buffer.from(content, "base64");
|
||||
|
||||
if (index instanceof LlamaCloudIndex) {
|
||||
// trigger LlamaCloudIndex API to upload the file and run the pipeline
|
||||
const projectId = await index.getProjectId();
|
||||
const pipelineId = await index.getPipelineId();
|
||||
return [
|
||||
await LLamaCloudFileService.addFileToPipeline(
|
||||
projectId,
|
||||
pipelineId,
|
||||
new File([fileBuffer], filename, { type: mimeType }),
|
||||
{ private: "true" },
|
||||
),
|
||||
];
|
||||
}
|
||||
|
||||
// run the pipeline for other vector store indexes
|
||||
const documents = await storeAndParseFile(filename, fileBuffer, mimeType);
|
||||
return runPipeline(index, documents);
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
import { JSONValue } from "ai";
|
||||
import { MessageContent, MessageContentDetail } from "llamaindex";
|
||||
|
||||
export type DocumentFileType = "csv" | "pdf" | "txt" | "docx";
|
||||
|
||||
export type DocumentFileContent = {
|
||||
type: "ref" | "text";
|
||||
value: string[] | string;
|
||||
};
|
||||
|
||||
export type DocumentFile = {
|
||||
id: string;
|
||||
filename: string;
|
||||
filesize: number;
|
||||
filetype: DocumentFileType;
|
||||
content: DocumentFileContent;
|
||||
};
|
||||
|
||||
type Annotation = {
|
||||
type: string;
|
||||
data: object;
|
||||
};
|
||||
|
||||
export function retrieveDocumentIds(annotations?: JSONValue[]): string[] {
|
||||
if (!annotations) return [];
|
||||
|
||||
const ids: string[] = [];
|
||||
|
||||
for (const annotation of annotations) {
|
||||
const { type, data } = getValidAnnotation(annotation);
|
||||
if (
|
||||
type === "document_file" &&
|
||||
"files" in data &&
|
||||
Array.isArray(data.files)
|
||||
) {
|
||||
const files = data.files as DocumentFile[];
|
||||
for (const file of files) {
|
||||
if (Array.isArray(file.content.value)) {
|
||||
// it's an array, so it's an array of doc IDs
|
||||
for (const id of file.content.value) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
export function convertMessageContent(
|
||||
content: string,
|
||||
annotations?: JSONValue[],
|
||||
): MessageContent {
|
||||
if (!annotations) return content;
|
||||
return [
|
||||
{
|
||||
type: "text",
|
||||
text: content,
|
||||
},
|
||||
...convertAnnotations(annotations),
|
||||
];
|
||||
}
|
||||
|
||||
function convertAnnotations(annotations: JSONValue[]): MessageContentDetail[] {
|
||||
const content: MessageContentDetail[] = [];
|
||||
annotations.forEach((annotation: JSONValue) => {
|
||||
const { type, data } = getValidAnnotation(annotation);
|
||||
// convert image
|
||||
if (type === "image" && "url" in data && typeof data.url === "string") {
|
||||
content.push({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: data.url,
|
||||
},
|
||||
});
|
||||
}
|
||||
// convert the content of files to a text message
|
||||
if (
|
||||
type === "document_file" &&
|
||||
"files" in data &&
|
||||
Array.isArray(data.files)
|
||||
) {
|
||||
// get all CSV files and convert their whole content to one text message
|
||||
// currently CSV files are the only files where we send the whole content - we don't use an index
|
||||
const csvFiles: DocumentFile[] = data.files.filter(
|
||||
(file: DocumentFile) => file.filetype === "csv",
|
||||
);
|
||||
if (csvFiles && csvFiles.length > 0) {
|
||||
const csvContents = csvFiles.map((file: DocumentFile) => {
|
||||
const fileContent = Array.isArray(file.content.value)
|
||||
? file.content.value.join("\n")
|
||||
: file.content.value;
|
||||
return "```csv\n" + fileContent + "\n```";
|
||||
});
|
||||
const text =
|
||||
"Use the following CSV content:\n" + csvContents.join("\n\n");
|
||||
content.push({
|
||||
type: "text",
|
||||
text,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
function getValidAnnotation(annotation: JSONValue): Annotation {
|
||||
if (
|
||||
!(
|
||||
annotation &&
|
||||
typeof annotation === "object" &&
|
||||
"type" in annotation &&
|
||||
typeof annotation.type === "string" &&
|
||||
"data" in annotation &&
|
||||
annotation.data &&
|
||||
typeof annotation.data === "object"
|
||||
)
|
||||
) {
|
||||
throw new Error("Client sent invalid annotation. Missing data and type");
|
||||
}
|
||||
return { type: annotation.type, data: annotation.data };
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
import { StreamData } from "ai";
|
||||
import {
|
||||
CallbackManager,
|
||||
LLamaCloudFileService,
|
||||
Metadata,
|
||||
MetadataMode,
|
||||
NodeWithScore,
|
||||
ToolCall,
|
||||
ToolOutput,
|
||||
} from "llamaindex";
|
||||
import path from "node:path";
|
||||
import { DATA_DIR } from "../../engine/loader";
|
||||
import { downloadFile } from "./file";
|
||||
|
||||
const LLAMA_CLOUD_DOWNLOAD_FOLDER = "output/llamacloud";
|
||||
|
||||
export function appendSourceData(
|
||||
data: StreamData,
|
||||
sourceNodes?: NodeWithScore<Metadata>[],
|
||||
) {
|
||||
if (!sourceNodes?.length) return;
|
||||
try {
|
||||
const nodes = sourceNodes.map((node) => ({
|
||||
metadata: node.node.metadata,
|
||||
id: node.node.id_,
|
||||
score: node.score ?? null,
|
||||
url: getNodeUrl(node.node.metadata),
|
||||
text: node.node.getContent(MetadataMode.NONE),
|
||||
}));
|
||||
data.appendMessageAnnotation({
|
||||
type: "sources",
|
||||
data: {
|
||||
nodes,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error appending source data:", error);
|
||||
}
|
||||
}
|
||||
|
||||
export function appendEventData(data: StreamData, title?: string) {
|
||||
if (!title) return;
|
||||
data.appendMessageAnnotation({
|
||||
type: "events",
|
||||
data: {
|
||||
title,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function appendToolData(
|
||||
data: StreamData,
|
||||
toolCall: ToolCall,
|
||||
toolOutput: ToolOutput,
|
||||
) {
|
||||
data.appendMessageAnnotation({
|
||||
type: "tools",
|
||||
data: {
|
||||
toolCall: {
|
||||
id: toolCall.id,
|
||||
name: toolCall.name,
|
||||
input: toolCall.input,
|
||||
},
|
||||
toolOutput: {
|
||||
output: toolOutput.output,
|
||||
isError: toolOutput.isError,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function createStreamTimeout(stream: StreamData) {
|
||||
const timeout = Number(process.env.STREAM_TIMEOUT ?? 1000 * 60 * 5); // default to 5 minutes
|
||||
const t = setTimeout(() => {
|
||||
appendEventData(stream, `Stream timed out after ${timeout / 1000} seconds`);
|
||||
stream.close();
|
||||
}, timeout);
|
||||
return t;
|
||||
}
|
||||
|
||||
export function createCallbackManager(stream: StreamData) {
|
||||
const callbackManager = new CallbackManager();
|
||||
|
||||
callbackManager.on("retrieve-end", (data) => {
|
||||
const { nodes, query } = data.detail;
|
||||
appendSourceData(stream, nodes);
|
||||
appendEventData(stream, `Retrieving context for query: '${query}'`);
|
||||
appendEventData(
|
||||
stream,
|
||||
`Retrieved ${nodes.length} sources to use as context for the query`,
|
||||
);
|
||||
downloadFilesFromNodes(nodes); // don't await to avoid blocking chat streaming
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-call", (event) => {
|
||||
const { name, input } = event.detail.toolCall;
|
||||
const inputString = Object.entries(input)
|
||||
.map(([key, value]) => `${key}: ${value}`)
|
||||
.join(", ");
|
||||
appendEventData(
|
||||
stream,
|
||||
`Using tool: '${name}' with inputs: '${inputString}'`,
|
||||
);
|
||||
});
|
||||
|
||||
callbackManager.on("llm-tool-result", (event) => {
|
||||
const { toolCall, toolResult } = event.detail;
|
||||
appendToolData(stream, toolCall, toolResult);
|
||||
});
|
||||
|
||||
return callbackManager;
|
||||
}
|
||||
|
||||
function getNodeUrl(metadata: Metadata) {
|
||||
if (!process.env.FILESERVER_URL_PREFIX) {
|
||||
console.warn(
|
||||
"FILESERVER_URL_PREFIX is not set. File URLs will not be generated.",
|
||||
);
|
||||
}
|
||||
const fileName = metadata["file_name"];
|
||||
if (fileName && process.env.FILESERVER_URL_PREFIX) {
|
||||
// file_name exists and file server is configured
|
||||
const pipelineId = metadata["pipeline_id"];
|
||||
if (pipelineId) {
|
||||
const name = toDownloadedName(pipelineId, fileName);
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/${LLAMA_CLOUD_DOWNLOAD_FOLDER}/${name}`;
|
||||
}
|
||||
const isPrivate = metadata["private"] === "true";
|
||||
if (isPrivate) {
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/output/uploaded/${fileName}`;
|
||||
}
|
||||
const filePath = metadata["file_path"];
|
||||
const dataDir = path.resolve(DATA_DIR);
|
||||
|
||||
if (filePath && dataDir) {
|
||||
const relativePath = path.relative(dataDir, filePath);
|
||||
return `${process.env.FILESERVER_URL_PREFIX}/data/${relativePath}`;
|
||||
}
|
||||
}
|
||||
// fallback to URL in metadata (e.g. for websites)
|
||||
return metadata["URL"];
|
||||
}
|
||||
|
||||
async function downloadFilesFromNodes(nodes: NodeWithScore<Metadata>[]) {
|
||||
try {
|
||||
const files = nodesToLlamaCloudFiles(nodes);
|
||||
for (const { pipelineId, fileName, downloadedName } of files) {
|
||||
const downloadUrl = await LLamaCloudFileService.getFileUrl(
|
||||
pipelineId,
|
||||
fileName,
|
||||
);
|
||||
if (downloadUrl) {
|
||||
await downloadFile(
|
||||
downloadUrl,
|
||||
downloadedName,
|
||||
LLAMA_CLOUD_DOWNLOAD_FOLDER,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error downloading files from nodes:", error);
|
||||
}
|
||||
}
|
||||
|
||||
function nodesToLlamaCloudFiles(nodes: NodeWithScore<Metadata>[]) {
|
||||
const files: Array<{
|
||||
pipelineId: string;
|
||||
fileName: string;
|
||||
downloadedName: string;
|
||||
}> = [];
|
||||
for (const node of nodes) {
|
||||
const pipelineId = node.node.metadata["pipeline_id"];
|
||||
const fileName = node.node.metadata["file_name"];
|
||||
if (!pipelineId || !fileName) continue;
|
||||
const isDuplicate = files.some(
|
||||
(f) => f.pipelineId === pipelineId && f.fileName === fileName,
|
||||
);
|
||||
if (!isDuplicate) {
|
||||
files.push({
|
||||
pipelineId,
|
||||
fileName,
|
||||
downloadedName: toDownloadedName(pipelineId, fileName),
|
||||
});
|
||||
}
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
||||
function toDownloadedName(pipelineId: string, fileName: string) {
|
||||
return `${pipelineId}$${fileName}`;
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
import fs from "node:fs";
|
||||
import https from "node:https";
|
||||
import path from "node:path";
|
||||
|
||||
export async function downloadFile(
|
||||
urlToDownload: string,
|
||||
filename: string,
|
||||
folder = "output/uploaded",
|
||||
) {
|
||||
try {
|
||||
const downloadedPath = path.join(folder, filename);
|
||||
|
||||
// Check if file already exists
|
||||
if (fs.existsSync(downloadedPath)) return;
|
||||
|
||||
const file = fs.createWriteStream(downloadedPath);
|
||||
https
|
||||
.get(urlToDownload, (response) => {
|
||||
response.pipe(file);
|
||||
file.on("finish", () => {
|
||||
file.close(() => {
|
||||
console.log("File downloaded successfully");
|
||||
});
|
||||
});
|
||||
})
|
||||
.on("error", (err) => {
|
||||
fs.unlink(downloadedPath, () => {
|
||||
console.error("Error downloading file:", err);
|
||||
throw err;
|
||||
});
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(`Error downloading file: ${error}`);
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
import { ChatMessage, Settings } from "llamaindex";
|
||||
|
||||
export async function generateNextQuestions(conversation: ChatMessage[]) {
|
||||
const llm = Settings.llm;
|
||||
const NEXT_QUESTION_PROMPT = process.env.NEXT_QUESTION_PROMPT;
|
||||
if (!NEXT_QUESTION_PROMPT) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// Format conversation
|
||||
const conversationText = conversation
|
||||
.map((message) => `${message.role}: ${message.content}`)
|
||||
.join("\n");
|
||||
const message = NEXT_QUESTION_PROMPT.replace(
|
||||
"{conversation}",
|
||||
conversationText,
|
||||
);
|
||||
|
||||
try {
|
||||
const response = await llm.complete({ prompt: message });
|
||||
const questions = extractQuestions(response.text);
|
||||
return questions;
|
||||
} catch (error) {
|
||||
console.error("Error when generating the next questions: ", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: instead of parsing the LLM's result we can use structured predict, once LITS supports it
|
||||
function extractQuestions(text: string): string[] {
|
||||
// Extract the text inside the triple backticks
|
||||
// @ts-ignore
|
||||
const contentMatch = text.match(/```(.*?)```/s);
|
||||
const content = contentMatch ? contentMatch[1] : "";
|
||||
|
||||
// Split the content by newlines to get each question
|
||||
const questions = content
|
||||
.split("\n")
|
||||
.map((question) => question.trim())
|
||||
.filter((question) => question !== "");
|
||||
|
||||
return questions;
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Dict
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, validator
|
||||
from llama_index.core.indices.vector_store import VectorStoreIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.config import DATA_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
|
||||
data_dir: str = "data"
|
||||
use_llama_parse: bool = False
|
||||
|
||||
@validator("data_dir")
|
||||
def data_dir_must_exist(cls, v):
|
||||
if not os.path.isdir(v):
|
||||
raise ValueError(f"Directory '{v}' does not exist")
|
||||
return v
|
||||
|
||||
|
||||
def llama_parse_parser():
|
||||
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
|
||||
@@ -19,56 +20,15 @@ def llama_parse_parser():
|
||||
"LLAMA_CLOUD_API_KEY environment variable is not set. "
|
||||
"Please set it in .env file or in your shell environment then run again!"
|
||||
)
|
||||
parser = LlamaParse(
|
||||
result_type="markdown",
|
||||
verbose=True,
|
||||
language="en",
|
||||
ignore_errors=False,
|
||||
)
|
||||
parser = LlamaParse(result_type="markdown", verbose=True, language="en")
|
||||
return parser
|
||||
|
||||
|
||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
||||
|
||||
parser = llama_parse_parser()
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
try:
|
||||
file_extractor = None
|
||||
if config.use_llama_parse:
|
||||
# LlamaParse is async first,
|
||||
# so we need to use nest_asyncio to run it in sync mode
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
file_extractor = llama_parse_extractor()
|
||||
reader = SimpleDirectoryReader(
|
||||
DATA_DIR,
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
raise_on_error=True,
|
||||
file_extractor=file_extractor,
|
||||
)
|
||||
return reader.load_data()
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
_, _, exc_traceback = sys.exc_info()
|
||||
function_name = traceback.extract_tb(exc_traceback)[-1].name
|
||||
if function_name == "_add_files":
|
||||
logger.warning(
|
||||
f"Failed to load file documents, error message: {e} . Return as empty document list."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
# Raise the error if it is not the case of empty data dir
|
||||
raise e
|
||||
reader = SimpleDirectoryReader(config.data_dir, recursive=True, filename_as_id=True)
|
||||
if config.use_llama_parse:
|
||||
parser = llama_parse_parser()
|
||||
reader.file_extractor = {".pdf": parser}
|
||||
return reader.load_data()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
||||
@@ -1,24 +1,9 @@
|
||||
import {
|
||||
FILE_EXT_TO_READER,
|
||||
SimpleDirectoryReader,
|
||||
} from "llamaindex/readers/SimpleDirectoryReader";
|
||||
import { SimpleDirectoryReader } from "llamaindex";
|
||||
|
||||
export const DATA_DIR = "./data";
|
||||
|
||||
export function getExtractors() {
|
||||
return FILE_EXT_TO_READER;
|
||||
}
|
||||
|
||||
export async function getDocuments() {
|
||||
const documents = await new SimpleDirectoryReader().loadData({
|
||||
return await new SimpleDirectoryReader().loadData({
|
||||
directoryPath: DATA_DIR,
|
||||
});
|
||||
// Set private=false to mark the document as public (required for filtering)
|
||||
for (const document of documents) {
|
||||
document.metadata = {
|
||||
...document.metadata,
|
||||
private: "false",
|
||||
};
|
||||
}
|
||||
return documents;
|
||||
}
|
||||
|
||||
@@ -1,38 +1,19 @@
|
||||
import { LlamaParseReader } from "llamaindex/readers/LlamaParseReader";
|
||||
import {
|
||||
FILE_EXT_TO_READER,
|
||||
LlamaParseReader,
|
||||
SimpleDirectoryReader,
|
||||
} from "llamaindex/readers/SimpleDirectoryReader";
|
||||
} from "llamaindex";
|
||||
|
||||
export const DATA_DIR = "./data";
|
||||
|
||||
export function getExtractors() {
|
||||
const llamaParseParser = new LlamaParseReader({ resultType: "markdown" });
|
||||
const extractors = FILE_EXT_TO_READER;
|
||||
// Change all the supported extractors to LlamaParse
|
||||
// except for .txt, it doesn't need to be parsed
|
||||
for (const key in extractors) {
|
||||
if (key === "txt") {
|
||||
continue;
|
||||
}
|
||||
extractors[key] = llamaParseParser;
|
||||
}
|
||||
return extractors;
|
||||
}
|
||||
|
||||
export async function getDocuments() {
|
||||
const reader = new SimpleDirectoryReader();
|
||||
const extractors = getExtractors();
|
||||
const documents = await reader.loadData({
|
||||
// Load PDFs using LlamaParseReader
|
||||
return await reader.loadData({
|
||||
directoryPath: DATA_DIR,
|
||||
fileExtToReader: extractors,
|
||||
fileExtToReader: {
|
||||
...FILE_EXT_TO_READER,
|
||||
pdf: new LlamaParseReader({ resultType: "markdown" }),
|
||||
},
|
||||
});
|
||||
// Set private=false to mark the document as public (required for filtering)
|
||||
for (const document of documents) {
|
||||
document.metadata = {
|
||||
...document.metadata,
|
||||
private: "false",
|
||||
};
|
||||
}
|
||||
return documents;
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import llama_index.core
|
||||
import os
|
||||
|
||||
|
||||
def init_observability():
|
||||
PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
|
||||
if not PHOENIX_API_KEY:
|
||||
raise ValueError("PHOENIX_API_KEY environment variable is not set")
|
||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={PHOENIX_API_KEY}"
|
||||
llama_index.core.set_global_handler(
|
||||
"arize_phoenix", endpoint="https://llamatrace.com/v1/traces"
|
||||
)
|
||||
@@ -6,7 +6,7 @@ authors = ["Marcus Schiesser <mail@marcusschiesser.de>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11,<4.0"
|
||||
python = "^3.11,<3.12"
|
||||
llama-index = "^0.10.6"
|
||||
llama-index-readers-file = "^0.1.3"
|
||||
python-dotenv = "^1.0.0"
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.engine.index import IndexConfig, get_index
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.readers.file.base import (
|
||||
_try_loading_included_file_formats as get_file_loaders_map,
|
||||
)
|
||||
from llama_index.core.schema import Document
|
||||
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
|
||||
from llama_index.readers.file import FlatReader
|
||||
|
||||
|
||||
def get_llamaparse_parser():
|
||||
from app.engine.loaders import load_configs
|
||||
from app.engine.loaders.file import FileLoaderConfig, llama_parse_parser
|
||||
|
||||
config = load_configs()
|
||||
file_loader_config = FileLoaderConfig(**config["file"])
|
||||
if file_loader_config.use_llama_parse:
|
||||
return llama_parse_parser()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def default_file_loaders_map():
|
||||
default_loaders = get_file_loaders_map()
|
||||
default_loaders[".txt"] = FlatReader
|
||||
return default_loaders
|
||||
|
||||
|
||||
class PrivateFileService:
|
||||
PRIVATE_STORE_PATH = "output/uploaded"
|
||||
|
||||
@staticmethod
|
||||
def preprocess_base64_file(base64_content: str) -> Tuple[bytes, str | None]:
|
||||
header, data = base64_content.split(",", 1)
|
||||
mime_type = header.split(";")[0].split(":", 1)[1]
|
||||
extension = mimetypes.guess_extension(mime_type)
|
||||
# File data as bytes
|
||||
return base64.b64decode(data), extension
|
||||
|
||||
@staticmethod
|
||||
def store_and_parse_file(file_name, file_data, extension) -> List[Document]:
|
||||
# Store file to the private directory
|
||||
os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True)
|
||||
file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name))
|
||||
|
||||
# write file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_data)
|
||||
|
||||
# Load file to documents
|
||||
# If LlamaParse is enabled, use it to parse the file
|
||||
# Otherwise, use the default file loaders
|
||||
reader = get_llamaparse_parser()
|
||||
if reader is None:
|
||||
reader_cls = default_file_loaders_map().get(extension)
|
||||
if reader_cls is None:
|
||||
raise ValueError(f"File extension {extension} is not supported")
|
||||
reader = reader_cls()
|
||||
documents = reader.load_data(file_path)
|
||||
# Add custom metadata
|
||||
for doc in documents:
|
||||
doc.metadata["file_name"] = file_name
|
||||
doc.metadata["private"] = "true"
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def process_file(
|
||||
file_name: str, base64_content: str, params: Optional[dict] = None
|
||||
) -> List[str]:
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
index_config = IndexConfig(**params)
|
||||
current_index = get_index(index_config)
|
||||
|
||||
# Insert the documents into the index
|
||||
if isinstance(current_index, LlamaCloudIndex):
|
||||
from app.engine.service import LLamaCloudFileService
|
||||
|
||||
project_id = current_index._get_project_id()
|
||||
pipeline_id = current_index._get_pipeline_id()
|
||||
# LlamaCloudIndex is a managed index so we can directly use the files
|
||||
upload_file = (file_name, BytesIO(file_data))
|
||||
return [
|
||||
LLamaCloudFileService.add_file_to_pipeline(
|
||||
project_id,
|
||||
pipeline_id,
|
||||
upload_file,
|
||||
custom_metadata={
|
||||
# Set private=true to mark the document as private user docs (required for filtering)
|
||||
"private": "true",
|
||||
},
|
||||
)
|
||||
]
|
||||
else:
|
||||
# First process documents into nodes
|
||||
documents = PrivateFileService.store_and_parse_file(
|
||||
file_name, file_data, extension
|
||||
)
|
||||
pipeline = IngestionPipeline()
|
||||
nodes = pipeline.run(documents=documents)
|
||||
|
||||
# Add the nodes to the index and persist it
|
||||
if current_index is None:
|
||||
current_index = VectorStoreIndex(nodes=nodes)
|
||||
else:
|
||||
current_index.insert_nodes(nodes=nodes)
|
||||
current_index.storage_context.persist(
|
||||
persist_dir=os.environ.get("STORAGE_DIR", "storage")
|
||||
)
|
||||
|
||||
# Return the document ids
|
||||
return [doc.doc_id for doc in documents]
|
||||
@@ -1,78 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from app.api.routers.models import Message
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class NextQuestionSuggestion:
|
||||
"""
|
||||
Suggest the next questions that user might ask based on the conversation history
|
||||
Disable this feature by removing the NEXT_QUESTION_PROMPT environment variable
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_configured_prompt(cls) -> Optional[str]:
|
||||
prompt = os.getenv("NEXT_QUESTION_PROMPT", None)
|
||||
if not prompt:
|
||||
return None
|
||||
return PromptTemplate(prompt)
|
||||
|
||||
@classmethod
|
||||
async def suggest_next_questions_all_messages(
|
||||
cls,
|
||||
messages: List[Message],
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Suggest the next questions that user might ask based on the conversation history
|
||||
Return None if suggestion is disabled or there is an error
|
||||
"""
|
||||
prompt_template = cls.get_configured_prompt()
|
||||
if not prompt_template:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Reduce the cost by only using the last two messages
|
||||
last_user_message = None
|
||||
last_assistant_message = None
|
||||
for message in reversed(messages):
|
||||
if message.role == "user":
|
||||
last_user_message = f"User: {message.content}"
|
||||
elif message.role == "assistant":
|
||||
last_assistant_message = f"Assistant: {message.content}"
|
||||
if last_user_message and last_assistant_message:
|
||||
break
|
||||
conversation: str = f"{last_user_message}\n{last_assistant_message}"
|
||||
|
||||
# Call the LLM and parse questions from the output
|
||||
prompt = prompt_template.format(conversation=conversation)
|
||||
output = await Settings.llm.acomplete(prompt)
|
||||
questions = cls._extract_questions(output.text)
|
||||
|
||||
return questions
|
||||
except Exception as e:
|
||||
logger.error(f"Error when generating next question: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_questions(cls, text: str) -> List[str]:
|
||||
content_match = re.search(r"```(.*?)```", text, re.DOTALL)
|
||||
content = content_match.group(1) if content_match else ""
|
||||
return content.strip().split("\n")
|
||||
|
||||
@classmethod
|
||||
async def suggest_next_questions(
|
||||
cls,
|
||||
chat_history: List[Message],
|
||||
response: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Suggest the next questions that user might ask based on the chat history and the last response
|
||||
"""
|
||||
messages = chat_history + [Message(role="assistant", content=response)]
|
||||
return await cls.suggest_next_questions_all_messages(messages)
|
||||
@@ -1,64 +0,0 @@
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.core.settings import Settings
|
||||
from typing import Dict
|
||||
import os
|
||||
|
||||
DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
|
||||
|
||||
|
||||
class TSIEmbedding(OpenAIEmbedding):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._query_engine = self._text_engine = self.model_name
|
||||
|
||||
|
||||
def llm_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
model = os.getenv("MODEL", DEFAULT_MODEL)
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"temperature": float(temperature),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def embedding_config_from_env() -> Dict:
|
||||
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
|
||||
|
||||
model = os.getenv("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
|
||||
dimension = os.getenv("EMBEDDING_DIM", DEFAULT_EMBEDDING_DIM)
|
||||
api_key = os.getenv("T_SYSTEMS_LLMHUB_API_KEY")
|
||||
api_base = os.getenv("T_SYSTEMS_LLMHUB_BASE_URL")
|
||||
|
||||
config = {
|
||||
"model_name": model,
|
||||
"dimension": int(dimension) if dimension is not None else None,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def init_llmhub():
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
|
||||
llm_configs = llm_config_from_env()
|
||||
embedding_configs = embedding_config_from_env()
|
||||
|
||||
Settings.embed_model = TSIEmbedding(**embedding_configs)
|
||||
Settings.llm = OpenAILike(
|
||||
**llm_configs,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=False,
|
||||
context_window=4096,
|
||||
)
|
||||
@@ -1,166 +0,0 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.settings import Settings
|
||||
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
match model_provider:
|
||||
case "openai":
|
||||
init_openai()
|
||||
case "groq":
|
||||
init_groq()
|
||||
case "ollama":
|
||||
init_ollama()
|
||||
case "anthropic":
|
||||
init_anthropic()
|
||||
case "gemini":
|
||||
init_gemini()
|
||||
case "mistral":
|
||||
init_mistral()
|
||||
case "azure-openai":
|
||||
init_azure_openai()
|
||||
case "t-systems":
|
||||
from .llmhub import init_llmhub
|
||||
|
||||
init_llmhub()
|
||||
case _:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
|
||||
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
|
||||
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
|
||||
def init_ollama():
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
|
||||
base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
request_timeout = float(
|
||||
os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
|
||||
)
|
||||
Settings.embed_model = OllamaEmbedding(
|
||||
base_url=base_url,
|
||||
model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
)
|
||||
Settings.llm = Ollama(
|
||||
base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
|
||||
)
|
||||
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
|
||||
|
||||
def init_azure_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.llms.azure_openai import AzureOpenAI
|
||||
|
||||
llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
|
||||
embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
|
||||
azure_config = {
|
||||
"api_key": os.environ["AZURE_OPENAI_API_KEY"],
|
||||
"azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
"api_version": os.getenv("AZURE_OPENAI_API_VERSION")
|
||||
or os.getenv("OPENAI_API_VERSION"),
|
||||
}
|
||||
|
||||
Settings.llm = AzureOpenAI(
|
||||
model=os.getenv("MODEL"),
|
||||
max_tokens=int(max_tokens) if max_tokens is not None else None,
|
||||
temperature=float(temperature),
|
||||
deployment_name=llm_deployment,
|
||||
**azure_config,
|
||||
)
|
||||
|
||||
Settings.embed_model = AzureOpenAIEmbedding(
|
||||
model=os.getenv("EMBEDDING_MODEL"),
|
||||
dimensions=int(dimensions) if dimensions is not None else None,
|
||||
deployment_name=embedding_deployment,
|
||||
**azure_config,
|
||||
)
|
||||
|
||||
|
||||
def init_fastembed():
|
||||
"""
|
||||
Use Qdrant Fastembed as the local embedding provider.
|
||||
"""
|
||||
from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
||||
|
||||
embed_model_map: Dict[str, str] = {
|
||||
# Small and multilingual
|
||||
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
# Large and multilingual
|
||||
"paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
|
||||
}
|
||||
|
||||
# This will download the model automatically if it is not already downloaded
|
||||
Settings.embed_model = FastEmbedEmbedding(
|
||||
model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
)
|
||||
|
||||
|
||||
def init_groq():
|
||||
from llama_index.llms.groq import Groq
|
||||
|
||||
Settings.llm = Groq(model=os.getenv("MODEL"))
|
||||
# Groq does not provide embeddings, so we use FastEmbed instead
|
||||
init_fastembed()
|
||||
|
||||
|
||||
def init_anthropic():
|
||||
from llama_index.llms.anthropic import Anthropic
|
||||
|
||||
model_map: Dict[str, str] = {
|
||||
"claude-3-opus": "claude-3-opus-20240229",
|
||||
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||
"claude-2.1": "claude-2.1",
|
||||
"claude-instant-1.2": "claude-instant-1.2",
|
||||
}
|
||||
|
||||
Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
|
||||
# Anthropic does not provide embeddings, so we use FastEmbed instead
|
||||
init_fastembed()
|
||||
|
||||
|
||||
def init_gemini():
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.llms.gemini import Gemini
|
||||
|
||||
model_name = f"models/{os.getenv('MODEL')}"
|
||||
embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
|
||||
|
||||
Settings.llm = Gemini(model=model_name)
|
||||
Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
||||
|
||||
|
||||
def init_mistral():
|
||||
from llama_index.embeddings.mistralai import MistralAIEmbedding
|
||||
from llama_index.llms.mistralai import MistralAI
|
||||
|
||||
Settings.llm = MistralAI(model=os.getenv("MODEL"))
|
||||
Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
|
||||
@@ -1,7 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { Message } from "./chat-messages";
|
||||
|
||||
export interface ChatInputProps {
|
||||
/** The current value of the input */
|
||||
input?: string;
|
||||
@@ -14,8 +12,7 @@ export interface ChatInputProps {
|
||||
/** Form submission handler to automatically reset input and append a user message */
|
||||
handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
|
||||
isLoading: boolean;
|
||||
messages: Message[];
|
||||
setInput?: (input: string) => void;
|
||||
multiModal?: boolean;
|
||||
}
|
||||
|
||||
export default function ChatInput(props: ChatInputProps) {
|
||||
|
||||
@@ -19,12 +19,8 @@ export default function ChatMessages({
|
||||
isLoading?: boolean;
|
||||
stop?: () => void;
|
||||
reload?: () => void;
|
||||
append?: (
|
||||
message: Message | Omit<Message, "id">,
|
||||
) => Promise<string | null | undefined>;
|
||||
}) {
|
||||
const scrollableChatContainerRef = useRef<HTMLDivElement>(null);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (scrollableChatContainerRef.current) {
|
||||
@@ -35,14 +31,14 @@ export default function ChatMessages({
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages.length, lastMessage]);
|
||||
}, [messages.length]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex-1 w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
<div className="flex flex-col gap-5 divide-y">
|
||||
<div className="w-full max-w-5xl p-4 bg-white rounded-xl shadow-xl">
|
||||
<div
|
||||
className="flex flex-col gap-5 divide-y h-[50vh] overflow-auto"
|
||||
ref={scrollableChatContainerRef}
|
||||
>
|
||||
{messages.map((m: Message) => (
|
||||
<ChatItem key={m.id} {...m} />
|
||||
))}
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"use client";
|
||||
|
||||
export interface ChatConfig {
|
||||
backend?: string;
|
||||
}
|
||||
|
||||
function getBackendOrigin(): string {
|
||||
const chatAPI = process.env.NEXT_PUBLIC_CHAT_API;
|
||||
if (chatAPI) {
|
||||
return new URL(chatAPI).origin;
|
||||
} else {
|
||||
if (typeof window !== "undefined") {
|
||||
// Use BASE_URL from window.ENV
|
||||
return (window as any).ENV?.BASE_URL || "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
export function useClientConfig(): ChatConfig {
|
||||
return {
|
||||
backend: getBackendOrigin(),
|
||||
};
|
||||
}
|
||||
@@ -3,18 +3,10 @@ from llama_index.vector_stores.astra_db import AstraDBVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
endpoint = os.getenv("ASTRA_DB_ENDPOINT")
|
||||
token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
|
||||
collection = os.getenv("ASTRA_DB_COLLECTION")
|
||||
if not endpoint or not token or not collection:
|
||||
raise ValueError(
|
||||
"Please config ASTRA_DB_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_COLLECTION"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = AstraDBVectorStore(
|
||||
token=token,
|
||||
api_endpoint=endpoint,
|
||||
collection_name=collection,
|
||||
embedding_dimension=int(os.getenv("EMBEDDING_DIM")),
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_ENDPOINT"],
|
||||
collection_name=os.environ["ASTRA_DB_COLLECTION"],
|
||||
embedding_dimension=int(os.environ["EMBEDDING_DIM"]),
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
import os
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("CHROMA_COLLECTION", "default")
|
||||
chroma_path = os.getenv("CHROMA_PATH")
|
||||
# if CHROMA_PATH is set, use a local ChromaVectorStore from the path
|
||||
# otherwise, use a remote ChromaVectorStore (ChromaDB Cloud is not supported yet)
|
||||
if chroma_path:
|
||||
store = ChromaVectorStore.from_params(
|
||||
persist_dir=chroma_path, collection_name=collection_name
|
||||
)
|
||||
else:
|
||||
if not os.getenv("CHROMA_HOST") or not os.getenv("CHROMA_PORT"):
|
||||
raise ValueError(
|
||||
"Please provide either CHROMA_PATH or CHROMA_HOST and CHROMA_PORT"
|
||||
)
|
||||
store = ChromaVectorStore.from_params(
|
||||
host=os.getenv("CHROMA_HOST"),
|
||||
port=int(os.getenv("CHROMA_PORT")),
|
||||
collection_name=collection_name,
|
||||
)
|
||||
return store
|
||||
@@ -1,50 +0,0 @@
|
||||
# flake8: noqa: E402
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.engine.index import get_index
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
from app.engine.service import LLamaCloudFileService
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
logger.info("Generate index for the provided data")
|
||||
|
||||
index = get_index()
|
||||
project_id = index._get_project_id()
|
||||
pipeline_id = index._get_pipeline_id()
|
||||
|
||||
# use SimpleDirectoryReader to retrieve the files to process
|
||||
reader = SimpleDirectoryReader(
|
||||
"data",
|
||||
recursive=True,
|
||||
)
|
||||
files_to_process = reader.input_files
|
||||
|
||||
# add each file to the LlamaCloud pipeline
|
||||
for input_file in files_to_process:
|
||||
with open(input_file, "rb") as f:
|
||||
logger.info(
|
||||
f"Adding file {input_file} to pipeline {index.name} in project {index.project_name}"
|
||||
)
|
||||
LLamaCloudFileService.add_file_to_pipeline(
|
||||
project_id,
|
||||
pipeline_id,
|
||||
f,
|
||||
custom_metadata={
|
||||
# Set private=false to mark the document as public (required for filtering)
|
||||
"private": "false",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Finished generating the index")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,87 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from llama_index.core.ingestion.api_utils import (
|
||||
get_client as llama_cloud_get_client,
|
||||
)
|
||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LlamaCloudConfig(BaseModel):
|
||||
# Private attributes
|
||||
api_key: str = Field(
|
||||
default=os.getenv("LLAMA_CLOUD_API_KEY"),
|
||||
exclude=True, # Exclude from the model representation
|
||||
)
|
||||
base_url: Optional[str] = Field(
|
||||
default=os.getenv("LLAMA_CLOUD_BASE_URL"),
|
||||
exclude=True,
|
||||
)
|
||||
organization_id: Optional[str] = Field(
|
||||
default=os.getenv("LLAMA_CLOUD_ORGANIZATION_ID"),
|
||||
exclude=True,
|
||||
)
|
||||
# Configuration attributes, can be set by the user
|
||||
pipeline: str = Field(
|
||||
description="The name of the pipeline to use",
|
||||
default=os.getenv("LLAMA_CLOUD_INDEX_NAME"),
|
||||
)
|
||||
project: str = Field(
|
||||
description="The name of the LlamaCloud project",
|
||||
default=os.getenv("LLAMA_CLOUD_PROJECT_NAME"),
|
||||
)
|
||||
|
||||
# Validate and throw error if the env variables are not set before starting the app
|
||||
@validator("pipeline", "project", "api_key", pre=True, always=True)
|
||||
@classmethod
|
||||
def validate_env_vars(cls, value):
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
"Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
return value
|
||||
|
||||
def to_client_kwargs(self) -> dict:
|
||||
return {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.base_url,
|
||||
}
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
|
||||
llama_cloud_pipeline_config: LlamaCloudConfig = Field(
|
||||
default=LlamaCloudConfig(),
|
||||
alias="llamaCloudPipeline",
|
||||
)
|
||||
callback_manager: Optional[CallbackManager] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
def to_index_kwargs(self) -> dict:
|
||||
return {
|
||||
"name": self.llama_cloud_pipeline_config.pipeline,
|
||||
"project_name": self.llama_cloud_pipeline_config.project,
|
||||
"api_key": self.llama_cloud_pipeline_config.api_key,
|
||||
"base_url": self.llama_cloud_pipeline_config.base_url,
|
||||
"organization_id": self.llama_cloud_pipeline_config.organization_id,
|
||||
"callback_manager": self.callback_manager,
|
||||
}
|
||||
|
||||
|
||||
def get_index(config: IndexConfig = None):
|
||||
if config is None:
|
||||
config = IndexConfig()
|
||||
index = LlamaCloudIndex(**config.to_index_kwargs())
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def get_client():
|
||||
config = LlamaCloudConfig()
|
||||
return llama_cloud_get_client(**config.to_client_kwargs())
|
||||
@@ -1,35 +0,0 @@
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
|
||||
"""
|
||||
Generate public/private document filters based on the doc_ids and the vector store.
|
||||
"""
|
||||
# Using "is_empty" filter to include the documents don't have the "private" key because they're uploaded in LlamaCloud UI
|
||||
public_doc_filter = MetadataFilter(
|
||||
key="private",
|
||||
value=None,
|
||||
operator="is_empty", # type: ignore
|
||||
)
|
||||
selected_doc_filter = MetadataFilter(
|
||||
key="file_id", # Note: LLamaCloud uses "file_id" to reference private document ids as "doc_id" is a restricted field in LlamaCloud
|
||||
value=doc_ids,
|
||||
operator="in", # type: ignore
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# If doc_ids are provided, we will select both public and selected documents
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
public_doc_filter,
|
||||
selected_doc_filter,
|
||||
],
|
||||
condition="or", # type: ignore
|
||||
)
|
||||
else:
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
public_doc_filter,
|
||||
]
|
||||
)
|
||||
|
||||
return filters
|
||||
@@ -1,173 +0,0 @@
|
||||
from io import BytesIO
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
import typing
|
||||
|
||||
from fastapi import BackgroundTasks
|
||||
from llama_cloud import ManagedIngestionStatus, PipelineFileCreateCustomMetadataValue
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
from app.api.routers.models import SourceNodes
|
||||
from app.engine.index import get_client
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class LlamaCloudFile(BaseModel):
|
||||
file_name: str
|
||||
pipeline_id: str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LlamaCloudFile):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.file_name == other.file_name and self.pipeline_id == other.pipeline_id
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.file_name, self.pipeline_id))
|
||||
|
||||
|
||||
class LLamaCloudFileService:
|
||||
LOCAL_STORE_PATH = "output/llamacloud"
|
||||
DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
|
||||
|
||||
@classmethod
|
||||
def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
client = get_client()
|
||||
projects = client.projects.list_projects()
|
||||
pipelines = client.pipelines.search_pipelines()
|
||||
return [
|
||||
{
|
||||
**(project.dict()),
|
||||
"pipelines": [
|
||||
{"id": p.id, "name": p.name}
|
||||
for p in pipelines
|
||||
if p.project_id == project.id
|
||||
],
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
except Exception as error:
|
||||
logger.error(f"Error listing projects and pipelines: {error}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def add_file_to_pipeline(
|
||||
cls,
|
||||
project_id: str,
|
||||
pipeline_id: str,
|
||||
upload_file: Union[typing.IO, Tuple[str, BytesIO]],
|
||||
custom_metadata: Optional[Dict[str, PipelineFileCreateCustomMetadataValue]],
|
||||
) -> str:
|
||||
client = get_client()
|
||||
file = client.files.upload_file(project_id=project_id, upload_file=upload_file)
|
||||
files = [
|
||||
{
|
||||
"file_id": file.id,
|
||||
"custom_metadata": {"file_id": file.id, **(custom_metadata or {})},
|
||||
}
|
||||
]
|
||||
files = client.pipelines.add_files_to_pipeline(pipeline_id, request=files)
|
||||
|
||||
# Wait 2s for the file to be processed
|
||||
max_attempts = 20
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
result = client.pipelines.get_pipeline_file_status(pipeline_id, file.id)
|
||||
if result.status == ManagedIngestionStatus.ERROR:
|
||||
raise Exception(f"File processing failed: {str(result)}")
|
||||
if result.status == ManagedIngestionStatus.SUCCESS:
|
||||
# File is ingested - return the file id
|
||||
return file.id
|
||||
attempt += 1
|
||||
time.sleep(0.1) # Sleep for 100ms
|
||||
raise Exception(
|
||||
f"File processing did not complete after {max_attempts} attempts."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def download_pipeline_file(
|
||||
cls,
|
||||
file: LlamaCloudFile,
|
||||
force_download: bool = False,
|
||||
):
|
||||
client = get_client()
|
||||
file_name = file.file_name
|
||||
pipeline_id = file.pipeline_id
|
||||
|
||||
# Check is the file already exists
|
||||
downloaded_file_path = cls._get_file_path(file_name, pipeline_id)
|
||||
if os.path.exists(downloaded_file_path) and not force_download:
|
||||
logger.debug(f"File {file_name} already exists in local storage")
|
||||
return
|
||||
try:
|
||||
logger.info(f"Downloading file {file_name} for pipeline {pipeline_id}")
|
||||
files = client.pipelines.list_pipeline_files(pipeline_id)
|
||||
if not files or not isinstance(files, list):
|
||||
raise Exception("No files found in LlamaCloud")
|
||||
for file_entry in files:
|
||||
if file_entry.name == file_name:
|
||||
file_id = file_entry.file_id
|
||||
project_id = file_entry.project_id
|
||||
file_detail = client.files.read_file_content(
|
||||
file_id, project_id=project_id
|
||||
)
|
||||
cls._download_file(file_detail.url, downloaded_file_path)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.info(f"Error fetching file from LlamaCloud: {error}")
|
||||
|
||||
@classmethod
|
||||
def download_files_from_nodes(
|
||||
cls, nodes: List[NodeWithScore], background_tasks: BackgroundTasks
|
||||
):
|
||||
files = cls._get_files_to_download(nodes)
|
||||
for file in files:
|
||||
logger.info(f"Adding download of {file.file_name} to background tasks")
|
||||
background_tasks.add_task(
|
||||
LLamaCloudFileService.download_pipeline_file, file
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_files_to_download(cls, nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
|
||||
source_nodes = SourceNodes.from_source_nodes(nodes)
|
||||
llama_cloud_files = [
|
||||
LlamaCloudFile(
|
||||
file_name=node.metadata.get("file_name"),
|
||||
pipeline_id=node.metadata.get("pipeline_id"),
|
||||
)
|
||||
for node in source_nodes
|
||||
if (
|
||||
node.metadata.get("pipeline_id") is not None
|
||||
and node.metadata.get("file_name") is not None
|
||||
)
|
||||
]
|
||||
# Remove duplicates and return
|
||||
return set(llama_cloud_files)
|
||||
|
||||
@classmethod
|
||||
def _get_file_name(cls, name: str, pipeline_id: str) -> str:
|
||||
return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)
|
||||
|
||||
@classmethod
|
||||
def _get_file_path(cls, name: str, pipeline_id: str) -> str:
|
||||
return os.path.join(cls.LOCAL_STORE_PATH, cls._get_file_name(name, pipeline_id))
|
||||
|
||||
@classmethod
|
||||
def _download_file(cls, url: str, local_file_path: str):
|
||||
logger.info(f"Saving file to {local_file_path}")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(cls.LOCAL_STORE_PATH, exist_ok=True)
|
||||
# Download the file
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(local_file_path, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
@@ -3,18 +3,11 @@ from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
address = os.getenv("MILVUS_ADDRESS")
|
||||
collection = os.getenv("MILVUS_COLLECTION")
|
||||
if not address or not collection:
|
||||
raise ValueError(
|
||||
"Please set MILVUS_ADDRESS and MILVUS_COLLECTION to your environment variables"
|
||||
" or config them in the .env file"
|
||||
)
|
||||
store = MilvusVectorStore(
|
||||
uri=address,
|
||||
uri=os.environ["MILVUS_ADDRESS"],
|
||||
user=os.getenv("MILVUS_USERNAME"),
|
||||
password=os.getenv("MILVUS_PASSWORD"),
|
||||
collection_name=collection,
|
||||
collection_name=os.getenv("MILVUS_COLLECTION"),
|
||||
dim=int(os.getenv("EMBEDDING_DIM")),
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -3,18 +3,9 @@ from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
db_uri = os.getenv("MONGODB_URI")
|
||||
db_name = os.getenv("MONGODB_DATABASE")
|
||||
collection_name = os.getenv("MONGODB_VECTORS")
|
||||
index_name = os.getenv("MONGODB_VECTOR_INDEX")
|
||||
if not db_uri or not db_name or not collection_name or not index_name:
|
||||
raise ValueError(
|
||||
"Please set MONGODB_URI, MONGODB_DATABASE, MONGODB_VECTORS, and MONGODB_VECTOR_INDEX"
|
||||
" to your environment variables or config them in .env file"
|
||||
)
|
||||
store = MongoDBAtlasVectorSearch(
|
||||
db_name=db_name,
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
db_name=os.environ["MONGODB_DATABASE"],
|
||||
collection_name=os.environ["MONGODB_VECTORS"],
|
||||
index_name=os.environ["MONGODB_VECTOR_INDEX"],
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# flake8: noqa: E402
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from app.engine.loaders import get_documents
|
||||
from app.settings import init_settings
|
||||
from llama_index.core.indices import (
|
||||
VectorStoreIndex,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def generate_datasource():
|
||||
init_settings()
|
||||
logger.info("Creating new index")
|
||||
storage_dir = os.environ.get("STORAGE_DIR", "storage")
|
||||
# load the documents and create the index
|
||||
documents = get_documents()
|
||||
# Set private=false to mark the document as public (required for filtering)
|
||||
for doc in documents:
|
||||
doc.metadata["private"] = "false"
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
show_progress=True,
|
||||
)
|
||||
# store it for later
|
||||
index.storage_context.persist(storage_dir)
|
||||
logger.info(f"Finished creating new index. Stored in {storage_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_datasource()
|
||||
@@ -1,43 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from cachetools import TTLCache, cached
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from llama_index.core.indices import load_index_from_storage
|
||||
from llama_index.core.storage import StorageContext
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
|
||||
callback_manager: Optional[CallbackManager] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
def get_index(config: IndexConfig = None):
|
||||
if config is None:
|
||||
config = IndexConfig()
|
||||
storage_dir = os.getenv("STORAGE_DIR", "storage")
|
||||
# check if storage already exists
|
||||
if not os.path.exists(storage_dir):
|
||||
return None
|
||||
# load the existing index
|
||||
logger.info(f"Loading index from {storage_dir}...")
|
||||
storage_context = get_storage_context(storage_dir)
|
||||
index = load_index_from_storage(
|
||||
storage_context, callback_manager=config.callback_manager
|
||||
)
|
||||
logger.info(f"Finished loading index from {storage_dir}")
|
||||
return index
|
||||
|
||||
|
||||
@cached(
|
||||
TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
|
||||
key=lambda *args, **kwargs: "global_storage_context",
|
||||
)
|
||||
def get_storage_context(persist_dir: str) -> StorageContext:
|
||||
return StorageContext.from_defaults(persist_dir=persist_dir)
|
||||
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
from llama_index.core.vector_stores import SimpleVectorStore
|
||||
from app.constants import STORAGE_DIR
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
if not os.path.exists(STORAGE_DIR):
|
||||
# Vector store hasn't been persisted before, create a new one
|
||||
vector_store = SimpleVectorStore()
|
||||
else:
|
||||
# Vector store has already been persisted before at STORAGE_DIR - load it
|
||||
vector_store = SimpleVectorStore.from_persist_dir(
|
||||
STORAGE_DIR, namespace="default"
|
||||
)
|
||||
return vector_store
|
||||
@@ -2,36 +2,30 @@ import os
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
from urllib.parse import urlparse
|
||||
|
||||
STORAGE_DIR = "storage"
|
||||
PGVECTOR_SCHEMA = "public"
|
||||
PGVECTOR_TABLE = "llamaindex_embedding"
|
||||
|
||||
vector_store: PGVectorStore = None
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
global vector_store
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
|
||||
if vector_store is None:
|
||||
original_conn_string = os.environ.get("PG_CONNECTION_STRING")
|
||||
if original_conn_string is None or original_conn_string == "":
|
||||
raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
# The PGVectorStore requires both two connection strings, one for psycopg2 and one for asyncpg
|
||||
# Update the configured scheme with the psycopg2 and asyncpg schemes
|
||||
original_scheme = urlparse(original_conn_string).scheme + "://"
|
||||
conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+psycopg2://"
|
||||
)
|
||||
async_conn_string = original_conn_string.replace(
|
||||
original_scheme, "postgresql+asyncpg://"
|
||||
)
|
||||
|
||||
vector_store = PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 1024)),
|
||||
)
|
||||
|
||||
return vector_store
|
||||
return PGVectorStore(
|
||||
connection_string=conn_string,
|
||||
async_connection_string=async_conn_string,
|
||||
schema_name=PGVECTOR_SCHEMA,
|
||||
table_name=PGVECTOR_TABLE,
|
||||
embed_dim=int(os.environ.get("EMBEDDING_DIM", 768)),
|
||||
)
|
||||
|
||||
@@ -3,17 +3,9 @@ from llama_index.vector_stores.pinecone import PineconeVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
api_key = os.getenv("PINECONE_API_KEY")
|
||||
index_name = os.getenv("PINECONE_INDEX_NAME")
|
||||
environment = os.getenv("PINECONE_ENVIRONMENT")
|
||||
if not api_key or not index_name or not environment:
|
||||
raise ValueError(
|
||||
"Please set PINECONE_API_KEY, PINECONE_INDEX_NAME, and PINECONE_ENVIRONMENT"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = PineconeVectorStore(
|
||||
api_key=api_key,
|
||||
index_name=index_name,
|
||||
environment=environment,
|
||||
api_key=os.environ["PINECONE_API_KEY"],
|
||||
index_name=os.environ["PINECONE_INDEX_NAME"],
|
||||
environment=os.environ["PINECONE_ENVIRONMENT"],
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -3,17 +3,9 @@ from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
collection_name = os.getenv("QDRANT_COLLECTION")
|
||||
url = os.getenv("QDRANT_URL")
|
||||
api_key = os.getenv("QDRANT_API_KEY")
|
||||
if not collection_name or not url:
|
||||
raise ValueError(
|
||||
"Please set QDRANT_COLLECTION, QDRANT_URL"
|
||||
" to your environment variables or config them in the .env file"
|
||||
)
|
||||
store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
collection_name=os.getenv("QDRANT_COLLECTION"),
|
||||
url=os.getenv("QDRANT_URL"),
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
|
||||
|
||||
|
||||
def generate_filters(doc_ids):
|
||||
"""
|
||||
Generate public/private document filters based on the doc_ids and the vector store.
|
||||
"""
|
||||
public_doc_filter = MetadataFilter(
|
||||
key="private",
|
||||
value="true",
|
||||
operator="!=", # type: ignore
|
||||
)
|
||||
# Weaviate doesn't support "in" filter right now, so use "any" instead - it has the same behavior.
|
||||
# TODO: Use "in" operator, once Weaviate supports it
|
||||
selected_doc_filter = MetadataFilter(
|
||||
key="doc_id",
|
||||
value=doc_ids,
|
||||
operator="any", # type: ignore
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# If doc_ids are provided, we will select both public and selected documents
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
public_doc_filter,
|
||||
selected_doc_filter,
|
||||
],
|
||||
condition="or", # type: ignore
|
||||
)
|
||||
else:
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
public_doc_filter,
|
||||
]
|
||||
)
|
||||
|
||||
return filters
|
||||
@@ -1,35 +0,0 @@
|
||||
import os
|
||||
|
||||
import weaviate
|
||||
from llama_index.vector_stores.weaviate import WeaviateVectorStore
|
||||
|
||||
DEFAULT_INDEX_NAME = "LlamaIndex"
|
||||
|
||||
|
||||
def _create_weaviate_client():
|
||||
cluster_url = os.getenv("WEAVIATE_CLUSTER_URL")
|
||||
api_key = os.getenv("WEAVIATE_API_KEY")
|
||||
if not cluster_url or not api_key:
|
||||
raise ValueError(
|
||||
"Environment variables: WEAVIATE_CLUSTER_URL and WEAVIATE_API_KEY are required."
|
||||
)
|
||||
auth_credentials = weaviate.auth.AuthApiKey(api_key)
|
||||
client = weaviate.connect_to_weaviate_cloud(cluster_url, auth_credentials)
|
||||
return client
|
||||
|
||||
|
||||
# Global variable to store the Weaviate client
|
||||
client = None
|
||||
|
||||
|
||||
def get_vector_store():
|
||||
global client
|
||||
if client is None:
|
||||
client = _create_weaviate_client()
|
||||
|
||||
index_name = os.getenv("WEAVIATE_INDEX_NAME", DEFAULT_INDEX_NAME)
|
||||
vector_store = WeaviateVectorStore(
|
||||
weaviate_client=client,
|
||||
index_name=index_name,
|
||||
)
|
||||
return vector_store
|
||||
@@ -1,7 +1,10 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import * as dotenv from "dotenv";
|
||||
import { VectorStoreIndex, storageContextFromDefaults } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/vector-store/AstraDBVectorStore";
|
||||
import {
|
||||
AstraDBVectorStore,
|
||||
VectorStoreIndex,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
import { getDocuments } from "./loader";
|
||||
import { initSettings } from "./settings";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
@@ -15,12 +18,13 @@ async function loadAndIndex() {
|
||||
// create vector store and a collection
|
||||
const collectionName = process.env.ASTRA_DB_COLLECTION!;
|
||||
const vectorStore = new AstraDBVectorStore();
|
||||
await vectorStore.createAndConnect(collectionName, {
|
||||
await vectorStore.create(collectionName, {
|
||||
vector: {
|
||||
dimension: parseInt(process.env.EMBEDDING_DIM!),
|
||||
metric: "cosine",
|
||||
},
|
||||
});
|
||||
await vectorStore.connect(collectionName);
|
||||
|
||||
// create index from documents and store them in Astra
|
||||
console.log("Start creating embeddings...");
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
/* eslint-disable turbo/no-undeclared-env-vars */
|
||||
import { VectorStoreIndex } from "llamaindex";
|
||||
import { AstraDBVectorStore } from "llamaindex/vector-store/AstraDBVectorStore";
|
||||
import { AstraDBVectorStore, VectorStoreIndex } from "llamaindex";
|
||||
import { checkRequiredEnvVars } from "./shared";
|
||||
|
||||
export async function getDataSource(params?: any) {
|
||||
export async function getDataSource() {
|
||||
checkRequiredEnvVars();
|
||||
const store = new AstraDBVectorStore();
|
||||
await store.connect(process.env.ASTRA_DB_COLLECTION!);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user