Compare commits

...

5 Commits

Author SHA1 Message Date
github-actions[bot] 9a71382243 Release 0.5.5 (#1046)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-07-17 16:26:38 +07:00
Thuc Pham b974eea341 feat: add MetadataFilter for SimpleVectorStore and Milvus (#1030)
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
2024-07-17 16:21:21 +07:00
github-actions[bot] e82632f83d Release 0.5.4 (#1043)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-07-16 14:53:20 -07:00
Fabian Wimmer 1a65ead849 feat: add vendorMultiModal params to LlamaParseReader (#1042) 2024-07-16 14:20:34 -07:00
Alex Yang 50b7d1b7bb refactor: put embedding into core (#1041) 2024-07-16 10:49:03 -07:00
80 changed files with 1620 additions and 412 deletions
+14
View File
@@ -1,5 +1,19 @@
# docs
## 0.0.46
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.0.45
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.0.44
### Patch Changes
@@ -42,10 +42,13 @@ They can be divided into two groups.
- `fastMode?` Optional. Set to true to use the fast mode. This mode will skip OCR of images, and table/heading reconstruction. Note: Non-compatible with `gpt4oMode`.
- `doNotUnrollColumns?` Optional. Set to true to keep the text according to document layout. Reduce reconstruction accuracy, and LLMs/embeddings performances in most cases.
- `pageSeparator?` Optional. The page separator to use. Default is `\\n---\\n`.
- `gpt4oMode` set to true to use GPT-4o to extract content. Default is `false`.
- `gpt4oApiKey?` Optional. Set the GPT-4o API key. Lowers the cost of parsing by using your own API key. Your OpenAI account will be charged. Can also be set in the environment variable `LLAMA_CLOUD_GPT4O_API_KEY`.
- `gpt4oMode` Deprecated. Use vendorMultimodal params. Set to true to use GPT-4o to extract content. Default is `false`.
- `gpt4oApiKey?` Deprecated. Use vendorMultimodal params. Optional. Set the GPT-4o API key. Lowers the cost of parsing by using your own API key. Your OpenAI account will be charged. Can also be set in the environment variable `LLAMA_CLOUD_GPT4O_API_KEY`.
- `boundingBox?` Optional. Specify an area of the document to parse. Expects the bounding box margins as a string in clockwise order, e.g. `boundingBox = "0.1,0,0,0"` to not parse the top 10% of the document.
- `targetPages?` Optional. Specify which pages to parse by specifying them as a comma-separated list. First page is `0`.
- `useVendorMultimodalModel` set to true to use a multimodal model. Default is `false`.
- `vendorMultimodalModel?` Optional. Specify which multimodal model to use. Default is GPT4o. See [here](https://docs.cloud.llamaindex.ai/llamaparse/features/multimodal) for a list of available models and cost.
- `vendorMultimodalApiKey?` Optional. Set the multimodal model API key. Can also be set in the environment variable `LLAMA_CLOUD_VENDOR_MULTIMODAL_API_KEY`.
- `numWorkers` as in the python version, is set in `SimpleDirectoryReader`. Default is 1.
### LlamaParse with SimpleDirectoryReader
@@ -75,7 +75,7 @@ const queryEngine = index.asQueryEngine({
{
key: "dogId",
value: "2",
filterType: "ExactMatch",
operator: "==",
},
],
},
@@ -135,7 +135,7 @@ async function main() {
{
key: "dogId",
value: "2",
filterType: "ExactMatch",
operator: "==",
},
],
},
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "docs",
"version": "0.0.44",
"version": "0.0.46",
"private": true,
"scripts": {
"docusaurus": "docusaurus",
+1 -1
View File
@@ -40,7 +40,7 @@ async function main() {
{
key: "dogId",
value: "2",
filterType: "ExactMatch",
operator: "==",
},
],
},
+40
View File
@@ -0,0 +1,40 @@
import { MilvusVectorStore, VectorStoreIndex } from "llamaindex";
const collectionName = "movie_reviews";
/**
 * Queries an existing Milvus collection through a query engine that applies
 * metadata pre-filters, then prints the number of source nodes and the answer.
 * Any error raised along the way is logged instead of being rethrown.
 */
async function main() {
  try {
    const vectorStore = new MilvusVectorStore({ collection: collectionName });
    const index = await VectorStoreIndex.fromVectorStore(vectorStore);
    const retriever = index.asRetriever({ similarityTopK: 20 });

    console.log("\n=====\nQuerying the index with filters");
    const filteredEngine = index.asQueryEngine({
      retriever,
      preFilters: {
        // The two filters are combined with "or": equal to the id, or not equal to it.
        filters: [
          {
            key: "document_id",
            value: "./data/movie_reviews.csv_37",
            operator: "==",
          },
          {
            key: "document_id",
            value: "./data/movie_reviews.csv_37",
            operator: "!=",
          },
        ],
        condition: "or",
      },
    });

    const answer = await filteredEngine.query({
      query: "Get all movie titles.",
    });
    console.log(`Query from ${answer.sourceNodes?.length} nodes`);
    console.log(answer.response);
  } catch (error) {
    console.error(error);
  }
}
void main();
+143
View File
@@ -0,0 +1,143 @@
import {
Document,
Settings,
SimpleDocumentStore,
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
// Register a global listener for the "retrieve-end" event and log how many
// nodes the event carries in its detail payload.
Settings.callbackManager.on("retrieve-end", (event) => {
const { nodes } = event.detail;
console.log("Number of retrieved nodes:", nodes.length);
});
/**
 * Returns a vector index backed by the storage context persisted in "./cache".
 * When the document store is empty the three sample dog documents are indexed;
 * otherwise the index is initialised from the existing storage context.
 */
async function getDataSource() {
  // Sample corpus: one document per dog, tagged with an id and a privacy flag.
  const samples = [
    { text: "The dog is brown", dogId: "1", private: true },
    { text: "The dog is yellow", dogId: "2", private: false },
    { text: "The dog is red", dogId: "3", private: false },
  ];
  const docs = samples.map(
    ({ text, ...metadata }) => new Document({ text, metadata }),
  );

  const storageContext = await storageContextFromDefaults({
    persistDir: "./cache",
  });
  const docStore = storageContext.docStore as SimpleDocumentStore;
  const storeIsEmpty = Object.keys(docStore.toDict()).length === 0;

  // Generate the data source only if nothing is stored yet.
  return storeIsEmpty
    ? await VectorStoreIndex.fromDocuments(docs, { storageContext })
    : await VectorStoreIndex.init({ storageContext });
}
/**
 * Demonstrates metadata pre-filters on the sample index by running the same
 * question through four query engines: unfiltered, two "==" filters, a single
 * "in" filter, and two filters combined with the "or" condition.
 * Each response is printed to the console.
 */
async function main() {
  const index = await getDataSource();

  console.log(
    "=============\nQuerying index with no filters. The output should be any color.",
  );
  const queryEngineNoFilters = index.asQueryEngine({
    similarityTopK: 3,
  });
  const noFilterResponse = await queryEngineNoFilters.query({
    query: "What is the color of the dog?",
  });
  console.log("No filter response:", noFilterResponse.toString());

  // The filter below selects dogId "3" (the red dog), so the messages name
  // dogId 3 to match both the filter and the expected output.
  console.log(
    "\n=============\nQuerying index with dogId 3 and private false. The output always should be red.",
  );
  const queryEngineEQ = index.asQueryEngine({
    preFilters: {
      filters: [
        {
          // NOTE(review): the sample documents store `private` as a boolean
          // while this filter passes the string "false" - confirm the vector
          // store treats the two as equal.
          key: "private",
          value: "false",
          operator: "==",
        },
        {
          key: "dogId",
          value: "3",
          operator: "==",
        },
      ],
    },
    similarityTopK: 3,
  });
  const responseEQ = await queryEngineEQ.query({
    query: "What is the color of the dog?",
  });
  console.log("Filter with dogId 3 response:", responseEQ.toString());

  console.log(
    "\n=============\nQuerying index with dogId IN (1, 3). The output should be brown and red.",
  );
  const queryEngineIN = index.asQueryEngine({
    preFilters: {
      filters: [
        {
          key: "dogId",
          value: ["1", "3"],
          operator: "in",
        },
      ],
    },
    similarityTopK: 3,
  });
  const responseIN = await queryEngineIN.query({
    query: "What is the color of the dog?",
  });
  console.log("Filter with dogId IN (1, 3) response:", responseIN.toString());

  // Two filters joined with "or": a document matches when either one holds.
  console.log(
    "\n=============\nQuerying index with private false OR dogId IN (1, 3). The output should be any.",
  );
  const queryEngineOR = index.asQueryEngine({
    preFilters: {
      filters: [
        {
          key: "private",
          value: "false",
          operator: "==",
        },
        {
          key: "dogId",
          value: ["1", "3"],
          operator: "in",
        },
      ],
      condition: "or",
    },
    similarityTopK: 3,
  });
  const responseOR = await queryEngineOR.query({
    query: "What is the color of the dog?",
  });
  console.log(
    "Filter with dogId with OR operator response:",
    responseOR.toString(),
  );
}
void main();
+1 -1
View File
@@ -64,7 +64,7 @@ async function main() {
{
key: "dogId",
value: "2",
filterType: "ExactMatch",
operator: "==",
},
],
},
@@ -1,5 +1,21 @@
# @llamaindex/autotool-02-next-example
## 0.1.30
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
- @llamaindex/autotool@2.0.0
## 0.1.29
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
- @llamaindex/autotool@2.0.0
## 0.1.28
### Patch Changes
@@ -1,7 +1,7 @@
{
"name": "@llamaindex/autotool-02-next-example",
"private": true,
"version": "0.1.28",
"version": "0.1.30",
"scripts": {
"dev": "next dev",
"build": "next build",
+1 -1
View File
@@ -51,7 +51,7 @@
"unplugin": "^1.10.1"
},
"peerDependencies": {
"llamaindex": "^0.5.3",
"llamaindex": "^0.5.5",
"openai": "^4",
"typescript": "^4"
},
+7
View File
@@ -1,5 +1,12 @@
# @llamaindex/community
## 0.0.22
### Patch Changes
- Updated dependencies [b974eea]
- @llamaindex/core@0.1.2
## 0.0.21
### Patch Changes
+1 -1
View File
@@ -1,7 +1,7 @@
{
"name": "@llamaindex/community",
"description": "Community package for LlamaIndexTS",
"version": "0.0.21",
"version": "0.0.22",
"type": "module",
"types": "dist/type/index.d.ts",
"main": "dist/cjs/index.js",
+6
View File
@@ -1,5 +1,11 @@
# @llamaindex/core
## 0.1.2
### Patch Changes
- b974eea: Add support for Metadata filters
## 0.1.1
### Patch Changes
+15 -1
View File
@@ -1,7 +1,7 @@
{
"name": "@llamaindex/core",
"type": "module",
"version": "0.1.1",
"version": "0.1.2",
"description": "LlamaIndex Core Module",
"exports": {
"./llms": {
@@ -32,6 +32,20 @@
"default": "./dist/decorator/index.js"
}
},
"./embeddings": {
"require": {
"types": "./dist/embeddings/index.d.cts",
"default": "./dist/embeddings/index.cjs"
},
"import": {
"types": "./dist/embeddings/index.d.ts",
"default": "./dist/embeddings/index.js"
},
"default": {
"types": "./dist/embeddings/index.d.ts",
"default": "./dist/embeddings/index.js"
}
},
"./global": {
"require": {
"types": "./dist/global/index.d.cts",
@@ -1,9 +1,8 @@
import type { MessageContentDetail } from "@llamaindex/core/llms";
import type { BaseNode } from "@llamaindex/core/schema";
import { MetadataMode } from "@llamaindex/core/schema";
import { extractSingleText } from "@llamaindex/core/utils";
import { type Tokenizers } from "@llamaindex/env";
import type { TransformComponent } from "../ingestion/types.js";
import type { MessageContentDetail } from "../llms";
import type { TransformComponent } from "../schema";
import { BaseNode, MetadataMode } from "../schema";
import { extractSingleText } from "../utils";
import { truncateMaxTokens } from "./tokenizer.js";
import { SimilarityType, similarity } from "./utils.js";
@@ -17,7 +16,13 @@ export type EmbeddingInfo = {
tokenizer?: Tokenizers;
};
export abstract class BaseEmbedding implements TransformComponent {
export type BaseEmbeddingOptions = {
logProgress?: boolean;
};
export abstract class BaseEmbedding
implements TransformComponent<BaseEmbeddingOptions>
{
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
embedInfo?: EmbeddingInfo;
@@ -45,7 +50,7 @@ export abstract class BaseEmbedding implements TransformComponent {
* Optionally override this method to retrieve multiple embeddings in a single request
* @param texts
*/
async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
const embeddings: number[][] = [];
for (const text of texts) {
@@ -54,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent {
}
return embeddings;
}
};
/**
* Get embeddings for a batch of texts
@@ -63,22 +68,23 @@ export abstract class BaseEmbedding implements TransformComponent {
*/
async getTextEmbeddingsBatch(
texts: string[],
options?: {
logProgress?: boolean;
},
options?: BaseEmbeddingOptions,
): Promise<Array<number[]>> {
return await batchEmbeddings(
texts,
this.getTextEmbeddings.bind(this),
this.getTextEmbeddings,
this.embedBatchSize,
options,
);
}
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
async transform(
nodes: BaseNode[],
options?: BaseEmbeddingOptions,
): Promise<BaseNode[]> {
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
const embeddings = await this.getTextEmbeddingsBatch(texts, _options);
const embeddings = await this.getTextEmbeddingsBatch(texts, options);
for (let i = 0; i < nodes.length; i++) {
nodes[i].embedding = embeddings[i];
@@ -104,9 +110,7 @@ export async function batchEmbeddings<T>(
values: T[],
embedFunc: EmbedFunc<T>,
chunkSize: number,
options?: {
logProgress?: boolean;
},
options?: BaseEmbeddingOptions,
): Promise<Array<number[]>> {
const resultEmbeddings: Array<number[]> = [];
+4
View File
@@ -0,0 +1,4 @@
export { BaseEmbedding, batchEmbeddings } from "./base";
export type { BaseEmbeddingOptions, EmbeddingInfo } from "./base";
export { truncateMaxTokens } from "./tokenizer";
export { DEFAULT_SIMILARITY_TOP_K, SimilarityType, similarity } from "./utils";
+64
View File
@@ -0,0 +1,64 @@
export const DEFAULT_SIMILARITY_TOP_K = 2;
/**
 * Similarity type
 * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
 *
 * The member values are the string identifiers of each scoring mode; the
 * `similarity` function in this file switches on them.
 */
export enum SimilarityType {
/** Cosine similarity: dot product divided by the product of both norms. */
DEFAULT = "cosine",
/** Plain dot product of the two embeddings. */
DOT_PRODUCT = "dot_product",
/** Negated Euclidean distance, so that closer vectors score higher. */
EUCLIDEAN = "euclidean",
}
/**
 * Scores how alike two embedding vectors are.
 * @param embedding1 first vector
 * @param embedding2 second vector; must have the same length as embedding1
 * @param mode scoring method, cosine similarity by default
 * @returns a score where a larger number means the embeddings are more similar
 * @throws Error when the vectors differ in length or the mode is not supported
 */
export function similarity(
  embedding1: number[],
  embedding2: number[],
  mode: SimilarityType = SimilarityType.DEFAULT,
): number {
  if (embedding1.length !== embedding2.length) {
    throw new Error("Embedding length mismatch");
  }

  // Straightforward (uncompensated) summation: simple, but long vectors may
  // lose a little floating point precision compared to a Kahan-style sum.
  const dot = (a: number[], b: number[]): number => {
    let total = 0;
    for (let i = 0; i < a.length; i++) {
      total += a[i] * b[i];
    }
    return total;
  };
  const magnitude = (v: number[]): number => Math.sqrt(dot(v, v));

  if (mode === SimilarityType.EUCLIDEAN) {
    // Negated so that a smaller distance yields a larger score.
    const delta = embedding1.map((x, i) => x - embedding2[i]);
    return -magnitude(delta);
  }
  if (mode === SimilarityType.DOT_PRODUCT) {
    return dot(embedding1, embedding2);
  }
  if (mode === SimilarityType.DEFAULT) {
    return (
      dot(embedding1, embedding2) /
      (magnitude(embedding1) * magnitude(embedding2))
    );
  }
  throw new Error("Not implemented yet");
}
+1
View File
@@ -1,2 +1,3 @@
export * from "./node";
export type { TransformComponent } from "./type";
export * from "./zod";
+5
View File
@@ -0,0 +1,5 @@
import type { BaseNode } from "./node";
/**
 * Contract for a step that transforms a list of nodes.
 *
 * @typeParam Options shape of the optional settings object accepted by
 * `transform`; constrained to a string-keyed record.
 */
export interface TransformComponent<Options extends Record<string, unknown>> {
/**
 * @param nodes nodes to process
 * @param options optional settings for this transformation
 * @returns a promise resolving to the resulting nodes
 */
transform(nodes: BaseNode[], options?: Options): Promise<BaseNode[]>;
}
@@ -1,6 +1,6 @@
import { truncateMaxTokens } from "@llamaindex/core/embeddings";
import { Tokenizers, tokenizers } from "@llamaindex/env";
import { describe, expect, test } from "vitest";
import { truncateMaxTokens } from "../../src/embeddings/tokenizer.js";
describe("truncateMaxTokens", () => {
const tokenizer = tokenizers.tokenizer(Tokenizers.CL100K_BASE);
+14
View File
@@ -1,5 +1,19 @@
# @llamaindex/experimental
## 0.0.55
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.0.54
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.0.53
### Patch Changes
+1 -1
View File
@@ -1,7 +1,7 @@
{
"name": "@llamaindex/experimental",
"description": "Experimental package for LlamaIndexTS",
"version": "0.0.53",
"version": "0.0.55",
"type": "module",
"types": "dist/type/index.d.ts",
"main": "dist/cjs/index.js",
+14
View File
@@ -1,5 +1,19 @@
# llamaindex
## 0.5.5
### Patch Changes
- b974eea: Add support for Metadata filters
- Updated dependencies [b974eea]
- @llamaindex/core@0.1.2
## 0.5.4
### Patch Changes
- 1a65ead: feat: add vendorMultimodal params to LlamaParseReader
## 0.5.3
### Patch Changes
@@ -1,5 +1,19 @@
# @llamaindex/cloudflare-worker-agent-test
## 0.0.39
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.0.38
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.0.37
### Patch Changes
@@ -1,6 +1,6 @@
{
"name": "@llamaindex/cloudflare-worker-agent-test",
"version": "0.0.37",
"version": "0.0.39",
"type": "module",
"private": true,
"scripts": {
@@ -1,5 +1,19 @@
# @llamaindex/next-agent-test
## 0.1.39
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.1.38
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.1.37
### Patch Changes
@@ -1,6 +1,6 @@
{
"name": "@llamaindex/next-agent-test",
"version": "0.1.37",
"version": "0.1.39",
"private": true,
"scripts": {
"dev": "next dev",
@@ -1,5 +1,19 @@
# test-edge-runtime
## 0.1.38
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.1.37
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.1.36
### Patch Changes
@@ -1,6 +1,6 @@
{
"name": "@llamaindex/nextjs-edge-runtime-test",
"version": "0.1.36",
"version": "0.1.38",
"private": true,
"scripts": {
"dev": "next dev",
@@ -1,5 +1,19 @@
# @llamaindex/next-node-runtime
## 0.0.20
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.0.19
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.0.18
### Patch Changes
@@ -1,6 +1,6 @@
{
"name": "@llamaindex/next-node-runtime-test",
"version": "0.0.18",
"version": "0.0.20",
"private": true,
"scripts": {
"dev": "next dev",
@@ -1,5 +1,19 @@
# @llamaindex/waku-query-engine-test
## 0.0.39
### Patch Changes
- Updated dependencies [b974eea]
- llamaindex@0.5.5
## 0.0.38
### Patch Changes
- Updated dependencies [1a65ead]
- llamaindex@0.5.4
## 0.0.37
### Patch Changes
@@ -1,6 +1,6 @@
{
"name": "@llamaindex/waku-query-engine-test",
"version": "0.0.37",
"version": "0.0.39",
"type": "module",
"private": true,
"scripts": {
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "llamaindex",
"version": "0.5.3",
"version": "0.5.5",
"license": "MIT",
"type": "module",
"keywords": [
+1 -1
View File
@@ -1,7 +1,7 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { LLM } from "@llamaindex/core/llms";
import { PromptHelper } from "./PromptHelper.js";
import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js";
import type { BaseEmbedding } from "./embeddings/types.js";
import { OpenAI } from "./llm/openai.js";
import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js";
import type { NodeParser } from "./nodeParsers/types.js";
+1 -1
View File
@@ -7,10 +7,10 @@ import { OpenAI } from "./llm/openai.js";
import { PromptHelper } from "./PromptHelper.js";
import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js";
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { LLM } from "@llamaindex/core/llms";
import { AsyncLocalStorage, getEnv } from "@llamaindex/env";
import type { ServiceContext } from "./ServiceContext.js";
import type { BaseEmbedding } from "./embeddings/types.js";
import {
getEmbeddedModel,
setEmbeddedModel,
@@ -1,7 +1,6 @@
import type { Document } from "@llamaindex/core/schema";
import type { Document, TransformComponent } from "@llamaindex/core/schema";
import type { BaseRetriever } from "../Retriever.js";
import { RetrieverQueryEngine } from "../engines/query/RetrieverQueryEngine.js";
import type { TransformComponent } from "../ingestion/types.js";
import type { BaseNodePostprocessor } from "../postprocessors/types.js";
import type { BaseSynthesizer } from "../synthesizers/types.js";
import type { QueryEngine } from "../types.js";
@@ -148,11 +147,11 @@ export class LlamaCloudIndex {
static async fromDocuments(
params: {
documents: Document[];
transformations?: TransformComponent[];
transformations?: TransformComponent<any>[];
verbose?: boolean;
} & CloudConstructorParams,
): Promise<LlamaCloudIndex> {
const defaultTransformations: TransformComponent[] = [
const defaultTransformations: TransformComponent<any>[] = [
new SimpleNodeParser(),
new OpenAIEmbedding({
apiKey: getEnv("OPENAI_API_KEY"),
+3 -4
View File
@@ -3,20 +3,19 @@ import type {
PipelineCreate,
PipelineType,
} from "@llamaindex/cloud/api";
import { BaseNode } from "@llamaindex/core/schema";
import { BaseNode, type TransformComponent } from "@llamaindex/core/schema";
import { OpenAIEmbedding } from "../embeddings/OpenAIEmbedding.js";
import type { TransformComponent } from "../ingestion/types.js";
import { SimpleNodeParser } from "../nodeParsers/SimpleNodeParser.js";
export type GetPipelineCreateParams = {
pipelineName: string;
pipelineType: PipelineType;
transformations?: TransformComponent[];
transformations?: TransformComponent<any>[];
inputNodes?: BaseNode[];
};
function getTransformationConfig(
transformation: TransformComponent,
transformation: TransformComponent<any>,
): ConfiguredTransformationItem {
if (transformation instanceof SimpleNodeParser) {
return {
-1
View File
@@ -4,6 +4,5 @@ export const DEFAULT_NUM_OUTPUTS = 256;
export const DEFAULT_CHUNK_SIZE = 1024;
export const DEFAULT_CHUNK_OVERLAP = 20;
export const DEFAULT_CHUNK_OVERLAP_RATIO = 0.1;
export const DEFAULT_SIMILARITY_TOP_K = 2;
export const DEFAULT_PADDING = 5;
@@ -1,7 +1,7 @@
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { MessageContentDetail } from "@llamaindex/core/llms";
import { extractSingleText } from "@llamaindex/core/utils";
import { getEnv } from "@llamaindex/env";
import { BaseEmbedding } from "./types.js";
const DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32";
@@ -103,10 +103,10 @@ export class DeepInfraEmbedding extends BaseEmbedding {
}
}
async getTextEmbeddings(texts: string[]): Promise<number[][]> {
getTextEmbeddings = async (texts: string[]): Promise<number[][]> => {
const textsWithPrefix = mapPrefixWithInputs(this.textPrefix, texts);
return await this.getDeepInfraEmbedding(textsWithPrefix);
}
return this.getDeepInfraEmbedding(textsWithPrefix);
};
async getQueryEmbeddings(queries: string[]): Promise<number[][]> {
const queriesWithPrefix = mapPrefixWithInputs(this.queryPrefix, queries);
@@ -1,6 +1,6 @@
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { GeminiSession, GeminiSessionStore } from "../llm/gemini/base.js";
import { GEMINI_BACKENDS } from "../llm/gemini/types.js";
import { BaseEmbedding } from "./types.js";
export enum GEMINI_EMBEDDING_MODEL {
EMBEDDING_001 = "embedding-001",
@@ -1,6 +1,6 @@
import { HfInference } from "@huggingface/inference";
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { lazyLoadTransformers } from "../internal/deps/transformers.js";
import { BaseEmbedding } from "./types.js";
export enum HuggingFaceEmbeddingModelType {
XENOVA_ALL_MINILM_L6_V2 = "Xenova/all-MiniLM-L6-v2",
@@ -91,11 +91,11 @@ export class HuggingFaceInferenceAPIEmbedding extends BaseEmbedding {
return res as number[];
}
async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
const res = await this.hf.featureExtraction({
model: this.model,
inputs: texts,
});
return res as number[][];
}
};
}
@@ -1,5 +1,5 @@
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { MistralAISession } from "../llm/mistral.js";
import { BaseEmbedding } from "./types.js";
export enum MistralAIEmbeddingModelType {
MISTRAL_EMBED = "mistral-embed",
@@ -1,6 +1,6 @@
import { BaseEmbedding, type EmbeddingInfo } from "@llamaindex/core/embeddings";
import { getEnv } from "@llamaindex/env";
import { MixedbreadAI, MixedbreadAIClient } from "@mixedbread-ai/sdk";
import { BaseEmbedding, type EmbeddingInfo } from "./types.js";
type EmbeddingsRequestWithoutInput = Omit<
MixedbreadAI.EmbeddingsRequest,
@@ -153,7 +153,7 @@ export class MixedbreadAIEmbeddings extends BaseEmbedding {
* const result = await mxbai.getTextEmbeddings(texts);
* console.log(result);
*/
async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
if (texts.length === 0) {
return [];
}
@@ -166,5 +166,5 @@ export class MixedbreadAIEmbeddings extends BaseEmbedding {
this.requestOptions,
);
return response.data.map((d) => d.embedding as number[]);
}
};
}
@@ -1,3 +1,4 @@
import { BaseEmbedding, batchEmbeddings } from "@llamaindex/core/embeddings";
import type { MessageContentDetail } from "@llamaindex/core/llms";
import {
ImageNode,
@@ -8,7 +9,6 @@ import {
type ImageType,
} from "@llamaindex/core/schema";
import { extractImage, extractSingleText } from "@llamaindex/core/utils";
import { BaseEmbedding, batchEmbeddings } from "./types.js";
/*
* Base class for Multi Modal embeddings.
@@ -1,5 +1,5 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import { Ollama } from "../llm/ollama.js";
import type { BaseEmbedding } from "./types.js";
/**
* OllamaEmbedding is an alias for Ollama that implements the BaseEmbedding interface.
@@ -1,3 +1,4 @@
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { Tokenizers } from "@llamaindex/env";
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type { AzureOpenAIConfig } from "../llm/azure.js";
@@ -8,7 +9,6 @@ import {
} from "../llm/azure.js";
import type { OpenAISession } from "../llm/openai.js";
import { getOpenAISession } from "../llm/openai.js";
import { BaseEmbedding } from "./types.js";
export const ALL_OPENAI_EMBEDDING_MODELS = {
"text-embedding-ada-002": {
@@ -132,9 +132,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
* Get embeddings for a batch of texts
* @param texts
*/
async getTextEmbeddings(texts: string[]): Promise<number[][]> {
return await this.getOpenAIEmbedding(texts);
}
getTextEmbeddings = async (texts: string[]): Promise<number[][]> => {
return this.getOpenAIEmbedding(texts);
};
/**
* Get embeddings for a single text
+1 -2
View File
@@ -1,3 +1,4 @@
export * from "@llamaindex/core/embeddings";
export { DeepInfraEmbedding } from "./DeepInfraEmbedding.js";
export { FireworksEmbedding } from "./fireworks.js";
export * from "./GeminiEmbedding.js";
@@ -9,5 +10,3 @@ export * from "./MultiModalEmbedding.js";
export { OllamaEmbedding } from "./OllamaEmbedding.js";
export * from "./OpenAIEmbedding.js";
export { TogetherEmbedding } from "./together.js";
export * from "./types.js";
export * from "./utils.js";
-256
View File
@@ -1,256 +0,0 @@
import type { ImageType } from "@llamaindex/core/schema";
import { fs } from "@llamaindex/env";
import _ from "lodash";
import { filetypemime } from "magic-bytes.js";
import { DEFAULT_SIMILARITY_TOP_K } from "../constants.js";
import type { VectorStoreQueryMode } from "../storage/vectorStore/types.js";
/**
 * Similarity type
 * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
 *
 * The member values are the string identifiers of each scoring mode; the
 * `similarity` function in this file switches on them.
 */
export enum SimilarityType {
/** Cosine similarity: dot product divided by the product of both norms. */
DEFAULT = "cosine",
/** Plain dot product of the two embeddings. */
DOT_PRODUCT = "dot_product",
/** Negated Euclidean distance, so that closer vectors score higher. */
EUCLIDEAN = "euclidean",
}
/**
 * Scores how alike two embedding vectors are.
 * @param embedding1 first vector
 * @param embedding2 second vector; must have the same length as embedding1
 * @param mode scoring method, cosine similarity by default
 * @returns a score where a larger number means the embeddings are more similar
 * @throws Error when the vectors differ in length or the mode is not supported
 */
export function similarity(
  embedding1: number[],
  embedding2: number[],
  mode: SimilarityType = SimilarityType.DEFAULT,
): number {
  if (embedding1.length !== embedding2.length) {
    throw new Error("Embedding length mismatch");
  }

  // Straightforward (uncompensated) summation: simple, but long vectors may
  // lose a little floating point precision compared to a Kahan-style sum.
  const dot = (a: number[], b: number[]): number => {
    let total = 0;
    for (let i = 0; i < a.length; i++) {
      total += a[i] * b[i];
    }
    return total;
  };
  const magnitude = (v: number[]): number => Math.sqrt(dot(v, v));

  if (mode === SimilarityType.EUCLIDEAN) {
    // Negated so that a smaller distance yields a larger score.
    const delta = embedding1.map((x, i) => x - embedding2[i]);
    return -magnitude(delta);
  }
  if (mode === SimilarityType.DOT_PRODUCT) {
    return dot(embedding1, embedding2);
  }
  if (mode === SimilarityType.DEFAULT) {
    return (
      dot(embedding1, embedding2) /
      (magnitude(embedding1) * magnitude(embedding2))
    );
  }
  throw new Error("Not implemented yet");
}
/**
 * Get the top K embeddings from a list of embeddings ordered by similarity to the query.
 * @param queryEmbedding
 * @param embeddings list of embeddings to consider
 * @param similarityTopK max number of embeddings to return, default 2
 * @param embeddingIds ids of embeddings in the embeddings list; when omitted,
 * each embedding is identified by its position in `embeddings`
 * @param similarityCutoff minimum similarity score (exclusive); embeddings
 * scoring at or below it are dropped
 * @returns a pair `[similarities, ids]`, both ordered from most to least similar
 * @throws Error if `embeddingIds` and `embeddings` differ in length
 */
// eslint-disable-next-line max-params
export function getTopKEmbeddings(
  queryEmbedding: number[],
  embeddings: number[][],
  similarityTopK: number = DEFAULT_SIMILARITY_TOP_K,
  embeddingIds: any[] | null = null,
  similarityCutoff: number | null = null,
): [number[], any[]] {
  if (embeddingIds == null) {
    // Array(n).map(...) never invokes its callback because every slot of the
    // freshly created array is empty, which left all ids undefined.
    // Array.from visits every index and produces 0..n-1 as intended.
    embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
  }

  if (embeddingIds.length !== embeddings.length) {
    throw new Error(
      "getTopKEmbeddings: embeddings and embeddingIds length mismatch",
    );
  }

  const similarities: { similarity: number; id: any }[] = [];

  for (let i = 0; i < embeddings.length; i++) {
    const sim = similarity(queryEmbedding, embeddings[i]);
    if (similarityCutoff == null || sim > similarityCutoff) {
      similarities.push({ similarity: sim, id: embeddingIds[i] });
    }
  }

  similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort

  const resultSimilarities: number[] = [];
  const resultIds: any[] = [];

  // Take at most similarityTopK entries from the front of the sorted list.
  for (let i = 0; i < similarityTopK; i++) {
    if (i >= similarities.length) {
      break;
    }
    resultSimilarities.push(similarities[i].similarity);
    resultIds.push(similarities[i].id);
  }

  return [resultSimilarities, resultIds];
}
/**
 * Placeholder with no implementation: it ignores all of its arguments and
 * always throws.
 * @throws Error "Not implemented yet" on every call
 */
// eslint-disable-next-line max-params
export function getTopKEmbeddingsLearner(
queryEmbedding: number[],
embeddings: number[][],
similarityTopK?: number,
embeddingsIds?: any[],
queryMode?: VectorStoreQueryMode,
): [number[], any[]] {
throw new Error("Not implemented yet");
}
/**
 * Greedily selects embeddings by trading off similarity to the query against
 * overlap with the most recently selected embedding.
 *
 * @param queryEmbedding embedding of the query
 * @param embeddings candidate embeddings
 * @param similarityFn scoring function; falls back to `similarity` when null
 * @param similarityTopK max number of results; all embeddings when null or 0
 * @param embeddingIds ids matching `embeddings`; index positions are used when
 * null or empty
 * @param _similarityCutoff not read anywhere in this function
 * @param mmrThreshold weight given to query similarity (the overlap penalty is
 * weighted by `1 - threshold`); 0.5 when null or 0
 * @returns a pair `[scores, ids]` in selection order
 */
// eslint-disable-next-line max-params
export function getTopKMMREmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn: ((...args: any[]) => number) | null = null,
similarityTopK: number | null = null,
embeddingIds: any[] | null = null,
_similarityCutoff: number | null = null,
mmrThreshold: number | null = null,
): [number[], any[]] {
// NOTE(review): `||` also replaces an explicit 0 with 0.5 - confirm whether a
// threshold of 0 should be honoured (that would need `??`).
const threshold = mmrThreshold || 0.5;
similarityFn = similarityFn || similarity;
if (embeddingIds === null || embeddingIds.length === 0) {
embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
}
// fullEmbedMap keeps every id -> index pair; embedMap is the shrinking pool of
// candidates that have not been selected yet.
const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i]));
const embedMap = new Map(fullEmbedMap);
const embedSimilarity: Map<any, number> = new Map();
let score: number = Number.NEGATIVE_INFINITY;
let highScoreId: any | null = null;
// First pass: cache each candidate's similarity to the query and remember the
// one with the highest weighted similarity as the first pick.
for (let i = 0; i < embeddings.length; i++) {
const emb = embeddings[i];
const similarity = similarityFn(queryEmbedding, emb);
embedSimilarity.set(embeddingIds[i], similarity);
if (similarity * threshold > score) {
highScoreId = embeddingIds[i];
score = similarity * threshold;
}
}
const results: [number, any][] = [];
const embeddingLength = embeddings.length;
// NOTE(review): `||` means a similarityTopK of 0 selects every embedding.
const similarityTopKCount = similarityTopK || embeddingLength;
while (results.length < Math.min(similarityTopKCount, embeddingLength)) {
// Commit the current best candidate, then remove it from the pool.
results.push([score, highScoreId]);
embedMap.delete(highScoreId);
const recentEmbeddingId = highScoreId;
score = Number.NEGATIVE_INFINITY;
// Re-score the remaining candidates; the penalty only considers overlap with
// the embedding selected in this iteration, not with earlier selections.
for (const embedId of Array.from(embedMap.keys())) {
const overlapWithRecent = similarityFn(
embeddings[embedMap.get(embedId)!],
embeddings[fullEmbedMap.get(recentEmbeddingId)!],
);
if (
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent >
score
) {
score =
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent;
highScoreId = embedId;
}
}
}
// Split the [score, id] pairs into two parallel arrays.
const resultSimilarities = results.map(([s, _]) => s);
const resultIds = results.map(([_, n]) => n);
return [resultSimilarities, resultIds];
}
/**
 * Encodes a Blob as a base64 data URL, using the first MIME type detected
 * from its bytes.
 * @throws Error when no MIME type can be detected
 */
async function blobToDataUrl(input: Blob) {
  const bytes = Buffer.from(await input.arrayBuffer());
  const detectedMimes = filetypemime(bytes);
  if (detectedMimes.length === 0) {
    throw new Error("Unsupported image type");
  }
  return `data:${detectedMimes[0]};base64,${bytes.toString("base64")}`;
}
/**
 * Turn an image reference into a string.
 * - Blob: converted to a base64 data URL
 * - string: returned unchanged
 * - URL: serialized with `toString()`
 * Anything else raises an error.
 */
export async function imageToString(input: ImageType): Promise<string> {
  if (input instanceof Blob) {
    // Blobs have no string form of their own, so inline them as a data URL
    return await blobToDataUrl(input);
  }
  if (_.isString(input)) {
    return input;
  }
  if (input instanceof URL) {
    return input.toString();
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
/**
 * Turn a string back into an image reference.
 * - `data:` URLs: the part after the first comma is base64-decoded into a Blob
 * - `http://` / `https://` strings: wrapped in a URL object
 * - any other string: returned unchanged
 */
export function stringToImage(input: string): ImageType {
  if (input.startsWith("data:")) {
    // everything after the first comma is treated as the base64 payload
    const [, payload] = input.split(",");
    return new Blob([Buffer.from(payload, "base64")]);
  }
  if (/^https?:\/\//.test(input)) {
    return new URL(input);
  }
  if (_.isString(input)) {
    return input;
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
/**
 * Convert an image reference into a base64 data URL.
 * Strings and `file:` URLs are read from disk, Blobs are encoded directly.
 * URLs with any other protocol, and inputs of any other type, raise an error.
 */
export async function imageToDataUrl(input: ImageType): Promise<string> {
  if (
    (input instanceof URL && input.protocol === "file:") ||
    _.isString(input)
  ) {
    // string or file URL: load the bytes from the file system
    const filePath = input instanceof URL ? input.pathname : input;
    const dataBuffer = await fs.readFile(filePath);
    return await blobToDataUrl(new Blob([dataBuffer]));
  }
  if (input instanceof Blob) {
    return await blobToDataUrl(input);
  }
  if (input instanceof URL) {
    throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
+2 -3
View File
@@ -1,12 +1,11 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import { MetadataMode, TextNode } from "@llamaindex/core/schema";
import type { TransformComponent } from "../ingestion/types.js";
import { defaultNodeTextTemplate } from "./prompts.js";
/*
* Abstract class for all extractors.
*/
export abstract class BaseExtractor implements TransformComponent {
export abstract class BaseExtractor implements TransformComponent<any> {
isTextNodeOnly: boolean = true;
showProgress: boolean = true;
metadataMode: MetadataMode = MetadataMode.ALL;
@@ -1,3 +1,7 @@
import {
DEFAULT_SIMILARITY_TOP_K,
type BaseEmbedding,
} from "@llamaindex/core/embeddings";
import { Settings } from "@llamaindex/core/global";
import type { MessageContent } from "@llamaindex/core/llms";
import {
@@ -13,8 +17,6 @@ import { wrapEventCaller } from "@llamaindex/core/utils";
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { nodeParserFromSettingsOrContext } from "../../Settings.js";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants.js";
import type { BaseEmbedding } from "../../embeddings/index.js";
import { RetrieverQueryEngine } from "../../engines/query/RetrieverQueryEngine.js";
import {
addNodesToVectorStores,
@@ -1,12 +1,11 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import { MetadataMode } from "@llamaindex/core/schema";
import { createSHA256 } from "@llamaindex/env";
import { docToJson, jsonToDoc } from "../storage/docStore/utils.js";
import { SimpleKVStore } from "../storage/kvStore/SimpleKVStore.js";
import type { BaseKVStore } from "../storage/kvStore/types.js";
import type { TransformComponent } from "./types.js";
const transformToJSON = (obj: TransformComponent) => {
const transformToJSON = (obj: TransformComponent<any>) => {
const seen: any[] = [];
const replacer = (key: string, value: any) => {
@@ -27,7 +26,7 @@ const transformToJSON = (obj: TransformComponent) => {
export function getTransformationHash(
nodes: BaseNode[],
transform: TransformComponent,
transform: TransformComponent<any>,
) {
const nodesStr: string = nodes
.map((node) => node.getContent(MetadataMode.ALL))
@@ -1,3 +1,4 @@
import type { TransformComponent } from "@llamaindex/core/schema";
import {
ModalityType,
splitNodesByType,
@@ -16,7 +17,6 @@ import {
DocStoreStrategy,
createDocStoreStrategy,
} from "./strategies/index.js";
import type { TransformComponent } from "./types.js";
type IngestionRunArgs = {
documents?: Document[];
@@ -26,12 +26,12 @@ type IngestionRunArgs = {
type TransformRunArgs = {
inPlace?: boolean;
cache?: IngestionCache;
docStoreStrategy?: TransformComponent;
docStoreStrategy?: TransformComponent<any>;
};
export async function runTransformations(
nodesToRun: BaseNode[],
transformations: TransformComponent[],
transformations: TransformComponent<any>[],
transformOptions: any = {},
{ inPlace = true, cache, docStoreStrategy }: TransformRunArgs = {},
): Promise<BaseNode[]> {
@@ -60,7 +60,7 @@ export async function runTransformations(
}
export class IngestionPipeline {
transformations: TransformComponent[] = [];
transformations: TransformComponent<any>[] = [];
documents?: Document[];
reader?: BaseReader;
vectorStore?: VectorStore;
@@ -70,7 +70,7 @@ export class IngestionPipeline {
cache?: IngestionCache;
disableCache: boolean = false;
private _docStoreStrategy?: TransformComponent;
private _docStoreStrategy?: TransformComponent<any>;
constructor(init?: Partial<IngestionPipeline>) {
Object.assign(this, init);
@@ -112,10 +112,7 @@ export class IngestionPipeline {
return inputNodes.flat();
}
async run(
args: IngestionRunArgs & TransformRunArgs = {},
transformOptions?: any,
): Promise<BaseNode[]> {
async run(args: any = {}, transformOptions?: any): Promise<BaseNode[]> {
args.cache = args.cache ?? this.cache;
args.docStoreStrategy = args.docStoreStrategy ?? this._docStoreStrategy;
const inputNodes = await this.prepareInput(args.documents, args.nodes);
@@ -1,2 +1 @@
export * from "./IngestionPipeline.js";
export * from "./types.js";
@@ -1,11 +1,10 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { TransformComponent } from "../types.js";
/**
* Handle doc store duplicates by checking all hashes.
*/
export class DuplicatesStrategy implements TransformComponent {
export class DuplicatesStrategy implements TransformComponent<any> {
private docStore: BaseDocumentStore;
constructor(docStore: BaseDocumentStore) {
@@ -1,14 +1,13 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js";
import type { TransformComponent } from "../types.js";
import { classify } from "./classify.js";
/**
* Handle docstore upserts by checking hashes and ids.
* Identify missing docs and delete them from docstore and vector store
*/
export class UpsertsAndDeleteStrategy implements TransformComponent {
export class UpsertsAndDeleteStrategy implements TransformComponent<any> {
protected docStore: BaseDocumentStore;
protected vectorStores?: VectorStore[];
@@ -1,13 +1,12 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js";
import type { TransformComponent } from "../types.js";
import { classify } from "./classify.js";
/**
* Handles doc store upserts by checking hashes and ids.
*/
export class UpsertsStrategy implements TransformComponent {
export class UpsertsStrategy implements TransformComponent<any> {
protected docStore: BaseDocumentStore;
protected vectorStores?: VectorStore[];
@@ -1,6 +1,6 @@
import type { TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js";
import type { TransformComponent } from "../types.js";
import { DuplicatesStrategy } from "./DuplicatesStrategy.js";
import { UpsertsAndDeleteStrategy } from "./UpsertsAndDeleteStrategy.js";
import { UpsertsStrategy } from "./UpsertsStrategy.js";
@@ -19,7 +19,7 @@ export enum DocStoreStrategy {
NONE = "none", // no-op strategy
}
class NoOpStrategy implements TransformComponent {
class NoOpStrategy implements TransformComponent<any> {
async transform(nodes: any[]): Promise<any[]> {
return nodes;
}
@@ -29,7 +29,7 @@ export function createDocStoreStrategy(
docStoreStrategy: DocStoreStrategy,
docStore?: BaseDocumentStore,
vectorStores: VectorStore[] = [],
): TransformComponent {
): TransformComponent<any> {
if (docStoreStrategy === DocStoreStrategy.NONE) {
return new NoOpStrategy();
}
@@ -1,5 +0,0 @@
import type { BaseNode } from "@llamaindex/core/schema";
/**
 * A step that takes a list of nodes and asynchronously returns a list of nodes.
 */
export interface TransformComponent {
// `options` is untyped here; its contents are defined by each implementation.
transform(nodes: BaseNode[], options?: any): Promise<BaseNode[]>;
}
@@ -1,6 +1,6 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import { AsyncLocalStorage } from "@llamaindex/env";
import { OpenAIEmbedding } from "../../embeddings/OpenAIEmbedding.js";
import type { BaseEmbedding } from "../../embeddings/index.js";
const embeddedModelAsyncLocalStorage = new AsyncLocalStorage<BaseEmbedding>();
let globalEmbeddedModel: BaseEmbedding | null = null;
+178
View File
@@ -1,4 +1,8 @@
import { similarity } from "@llamaindex/core/embeddings";
import type { JSONValue } from "@llamaindex/core/global";
import type { ImageType } from "@llamaindex/core/schema";
import { fs } from "@llamaindex/env";
import { filetypemime } from "magic-bytes.js";
export const isAsyncIterable = (
obj: unknown,
@@ -24,3 +28,177 @@ export function prettifyError(error: unknown): string {
/**
 * Pretty-print a JSON value (2-space indent) and strip the double quotes
 * around every quoted run of characters in the result.
 */
export function stringifyJSONToMessageContent(value: JSONValue): string {
  const pretty = JSON.stringify(value, null, 2);
  return pretty.replace(/"([^"]*)"/g, (_match, inner: string) => inner);
}
/**
 * Get the top K embeddings from a list of embeddings ordered by similarity to the query.
 * @param queryEmbedding embedding of the query
 * @param embeddings list of embeddings to consider
 * @param similarityTopK max number of embeddings to return, default 2
 * @param embeddingIds ids of embeddings in the embeddings list; when null or
 *   undefined, the positional indices 0..n-1 are used as ids
 * @param similarityCutoff minimum similarity score; only embeddings whose
 *   similarity is strictly greater than the cutoff are kept. No cutoff when null.
 * @returns a tuple `[similarities, ids]`, both ordered by descending similarity
 * @throws Error when `embeddingIds` and `embeddings` differ in length
 */
// eslint-disable-next-line max-params
export function getTopKEmbeddings(
  queryEmbedding: number[],
  embeddings: number[][],
  similarityTopK: number = 2,
  embeddingIds: any[] | null = null,
  similarityCutoff: number | null = null,
): [number[], any[]] {
  if (embeddingIds == null) {
    // Array(n).map(...) would leave every slot empty because map skips holes;
    // Array.from actually materializes the indices.
    embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
  }

  if (embeddingIds.length !== embeddings.length) {
    throw new Error(
      "getTopKEmbeddings: embeddings and embeddingIds length mismatch",
    );
  }

  // ids are caller-defined, so they are not restricted to numbers
  const similarities: { similarity: number; id: any }[] = [];

  for (let i = 0; i < embeddings.length; i++) {
    const sim = similarity(queryEmbedding, embeddings[i]);
    if (similarityCutoff == null || sim > similarityCutoff) {
      similarities.push({ similarity: sim, id: embeddingIds[i] });
    }
  }

  similarities.sort((a, b) => b.similarity - a.similarity); // Reverse sort

  const resultSimilarities: number[] = [];
  const resultIds: any[] = [];

  for (let i = 0; i < similarityTopK; i++) {
    if (i >= similarities.length) {
      break;
    }
    resultSimilarities.push(similarities[i].similarity);
    resultIds.push(similarities[i].id);
  }

  return [resultSimilarities, resultIds];
}
/**
 * Greedily select embeddings by trading off similarity to the query against
 * similarity to the most recently selected embedding (MMR-style, per the name).
 *
 * @param queryEmbedding embedding of the query
 * @param embeddings candidate embeddings
 * @param similarityFn similarity function; falls back to the imported `similarity` when null
 * @param similarityTopK max number of results; falls back to `embeddings.length` when null or 0
 * @param embeddingIds ids for the candidates; positional indices are used when null or empty
 * @param _similarityCutoff not used by this function
 * @param mmrThreshold weight of the query similarity; falls back to 0.5 when null or 0
 * @returns a tuple `[scores, ids]` in selection order. The scores are the
 *   weighted selection scores computed below, not the raw query similarities.
 *
 * Note: no check is made that `embeddingIds` and `embeddings` have the same length.
 */
// eslint-disable-next-line max-params
export function getTopKMMREmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn: ((...args: any[]) => number) | null = null,
similarityTopK: number | null = null,
embeddingIds: any[] | null = null,
_similarityCutoff: number | null = null,
mmrThreshold: number | null = null,
): [number[], any[]] {
// `||` means a threshold of 0 also falls back to 0.5
const threshold = mmrThreshold || 0.5;
similarityFn = similarityFn || similarity;
if (embeddingIds === null || embeddingIds.length === 0) {
embeddingIds = Array.from({ length: embeddings.length }, (_, i) => i);
}
// id -> index into `embeddings`; kept intact for lookups of already-selected ids
const fullEmbedMap = new Map(embeddingIds.map((value, i) => [value, i]));
// working copy holding only the candidates that have not been selected yet
const embedMap = new Map(fullEmbedMap);
// id -> similarity to the query
const embedSimilarity: Map<any, number> = new Map();
let score: number = Number.NEGATIVE_INFINITY;
let highScoreId: any | null = null;
// First pass: score every candidate against the query and remember the best one
for (let i = 0; i < embeddings.length; i++) {
const emb = embeddings[i];
const similarity = similarityFn(queryEmbedding, emb);
embedSimilarity.set(embeddingIds[i], similarity);
if (similarity * threshold > score) {
highScoreId = embeddingIds[i];
score = similarity * threshold;
}
}
const results: [number, any][] = [];
const embeddingLength = embeddings.length;
const similarityTopKCount = similarityTopK || embeddingLength;
// Each iteration commits the current best candidate, then re-scores the rest
while (results.length < Math.min(similarityTopKCount, embeddingLength)) {
results.push([score, highScoreId]);
embedMap.delete(highScoreId);
const recentEmbeddingId = highScoreId;
score = Number.NEGATIVE_INFINITY;
for (const embedId of Array.from(embedMap.keys())) {
// Penalty term: similarity to the most recently selected embedding only
const overlapWithRecent = similarityFn(
embeddings[embedMap.get(embedId)!],
embeddings[fullEmbedMap.get(recentEmbeddingId)!],
);
if (
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent >
score
) {
score =
threshold * embedSimilarity.get(embedId)! -
(1 - threshold) * overlapWithRecent;
highScoreId = embedId;
}
}
}
const resultSimilarities = results.map(([s, _]) => s);
const resultIds = results.map(([_, n]) => n);
return [resultSimilarities, resultIds];
}
/**
 * Encode a Blob as a base64 `data:` URL.
 * The MIME type is the first entry `filetypemime` returns for the blob's
 * bytes; when it returns none, the blob is rejected as unsupported.
 */
async function blobToDataUrl(input: Blob) {
  const bytes = Buffer.from(await input.arrayBuffer());
  const detected = filetypemime(bytes);
  if (detected.length === 0) {
    throw new Error("Unsupported image type");
  }
  return `data:${detected[0]};base64,${bytes.toString("base64")}`;
}
/**
 * Turn an image reference into a string.
 * - string: returned unchanged
 * - URL: serialized with `toString()`
 * - Blob: converted to a base64 data URL
 * Anything else raises an error.
 */
export async function imageToString(input: ImageType): Promise<string> {
  if (typeof input === "string") {
    return input;
  }
  if (input instanceof URL) {
    return input.toString();
  }
  if (input instanceof Blob) {
    // Blobs have no string form of their own, so inline them as a data URL
    return await blobToDataUrl(input);
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
/**
 * Turn a string back into an image reference.
 * - `data:` URLs: the part after the first comma is base64-decoded into a Blob
 * - `http://` / `https://` strings: wrapped in a URL object
 * - any other string: returned unchanged
 */
export function stringToImage(input: string): ImageType {
  if (input.startsWith("data:")) {
    // everything after the first comma is treated as the base64 payload
    const [, payload] = input.split(",");
    return new Blob([Buffer.from(payload, "base64")]);
  }
  const isHttpUrl = /^https?:\/\//.test(input);
  return isHttpUrl ? new URL(input) : input;
}
/**
 * Convert an image reference into a base64 data URL.
 *
 * - string: treated as a file path and read from disk
 * - `file:` URL: its path is percent-decoded and read from disk
 * - Blob: encoded directly
 *
 * @param input image reference to convert
 * @returns the image as a `data:<mime>;base64,...` URL
 * @throws Error for URLs with a protocol other than `file:` and for any other input type
 */
export async function imageToDataUrl(input: ImageType): Promise<string> {
  if (
    (input instanceof URL && input.protocol === "file:") ||
    typeof input === "string"
  ) {
    // string or file URL
    let filePath: string;
    if (input instanceof URL) {
      // URL.pathname is percent-encoded ("my%20image.png"); decode it so the
      // path matches the real file name. Fall back to the raw pathname when it
      // contains a malformed escape sequence.
      // NOTE(review): Windows drive-letter URLs (file:///C:/...) still yield
      // a leading slash here — confirm whether that platform must be supported.
      try {
        filePath = decodeURIComponent(input.pathname);
      } catch {
        filePath = input.pathname;
      }
    } else {
      filePath = input;
    }
    const dataBuffer = await fs.readFile(filePath);
    return await blobToDataUrl(new Blob([dataBuffer]));
  }
  if (input instanceof Blob) {
    return await blobToDataUrl(input);
  }
  if (input instanceof URL) {
    throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}
+1 -1
View File
@@ -1,3 +1,4 @@
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import type {
ChatResponse,
ChatResponseChunk,
@@ -10,7 +11,6 @@ import type {
LLMMetadata,
} from "@llamaindex/core/llms";
import { extractText, streamConverter } from "@llamaindex/core/utils";
import { BaseEmbedding } from "../embeddings/types.js";
import {
Ollama as OllamaBase,
type Config,
+2 -3
View File
@@ -1,10 +1,9 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { TransformComponent } from "../ingestion/types.js";
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema";
/**
* A NodeParser generates Nodes from Documents
*/
export interface NodeParser extends TransformComponent {
export interface NodeParser extends TransformComponent<any> {
/**
* Generates an array of nodes from an array of documents.
* @param documents - The documents to generate nodes from.
@@ -129,9 +129,9 @@ export class LlamaParseReader extends FileReader {
doNotUnrollColumns?: boolean;
// The page separator to use to split the text. Default is None, which means the parser will use the default separator '\\n---\\n'.
pageSeparator?: string;
// Whether to use gpt-4o to extract text from documents.
// Deprecated. Use vendorMultimodal params. Whether to use gpt-4o to extract text from documents.
gpt4oMode: boolean = false;
// The API key for the GPT-4o API. Optional, lowers the cost of parsing. Can be set as an env variable: LLAMA_CLOUD_GPT4O_API_KEY.
// Deprecated. Use vendorMultimodal params. The API key for the GPT-4o API. Optional, lowers the cost of parsing. Can be set as an env variable: LLAMA_CLOUD_GPT4O_API_KEY.
gpt4oApiKey?: string;
// The bounding box to use to extract text from documents. Describe as a string containing the bounding box margins.
boundingBox?: string;
@@ -139,6 +139,12 @@ export class LlamaParseReader extends FileReader {
targetPages?: string;
// Whether or not to ignore and skip errors raised during parsing.
ignoreErrors: boolean = true;
// Whether to use the vendor multimodal API.
useVendorMultimodalModel: boolean = false;
// The model name for the vendor multimodal API
vendorMultimodalModelName?: string;
// The API key for the multimodal API. Can also be set as an env variable: LLAMA_CLOUD_VENDOR_MULTIMODAL_API_KEY
vendorMultimodalApiKey?: string;
// numWorkers is implemented in SimpleDirectoryReader
constructor(params: Partial<LlamaParseReader> = {}) {
@@ -158,6 +164,13 @@ export class LlamaParseReader extends FileReader {
this.gpt4oApiKey = params.gpt4oApiKey;
}
if (params.useVendorMultimodalModel) {
params.vendorMultimodalApiKey =
params.vendorMultimodalApiKey ??
getEnv("LLAMA_CLOUD_VENDOR_MULTIMODAL_API_KEY");
this.vendorMultimodalApiKey = params.vendorMultimodalApiKey;
}
}
// Create a job for the LlamaParse API
@@ -189,6 +202,9 @@ export class LlamaParseReader extends FileReader {
gpt4o_api_key: this.gpt4oApiKey,
bounding_box: this.boundingBox,
target_pages: this.targetPages,
use_vendor_multimodal_model: this.useVendorMultimodalModel?.toString(),
vendor_multimodal_model_name: this.vendorMultimodalModelName,
vendor_multimodal_api_key: this.vendorMultimodalApiKey,
};
// Appends body with any defined LlamaParseBodyParams
@@ -11,11 +11,66 @@ import {
import {
VectorStoreBase,
type IEmbedModel,
type MetadataFilters,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
type VectorStoreQueryResult,
} from "./types.js";
import { metadataDictToNode, nodeToMetadata } from "./utils.js";
import {
metadataDictToNode,
nodeToMetadata,
parseArrayValue,
parseNumberValue,
parsePrimitiveValue,
} from "./utils.js";
/**
 * Render MetadataFilters as a filter expression string over the `metadata` field.
 * Each filter becomes one clause; clauses are joined with the filters'
 * condition ("and" when none is given).
 * Throws for operators other than ==, !=, in, nin, <, <=, >, >=.
 */
function parseScalarFilters(scalarFilters: MetadataFilters): string {
  const joiner = ` ${scalarFilters.condition ?? "and"} `;

  const clauses = scalarFilters.filters.map((filter) => {
    const op = filter.operator;

    if (op === "==" || op === "!=") {
      return `metadata["${filter.key}"] ${op} "${parsePrimitiveValue(filter.value)}"`;
    }

    if (op === "in") {
      const quoted = parseArrayValue(filter.value).map((v) => `"${v}"`);
      return `metadata["${filter.key}"] ${op} [${quoted.join(", ")}]`;
    }

    if (op === "nin") {
      // `nin` is expanded by hand into one inequality per value:
      // metadata["key"] != "value1" && metadata["key"] != "value2"
      return parseArrayValue(filter.value)
        .map((v) => `metadata["${filter.key}"] != "${v}"`)
        .join(" && ");
    }

    if (op === "<" || op === "<=" || op === ">" || op === ">=") {
      return `metadata["${filter.key}"] ${op} ${parseNumberValue(filter.value)}`;
    }

    throw new Error(`Operator ${op} is not supported.`);
  });

  return clauses.join(joiner);
}
export class MilvusVectorStore
extends VectorStoreBase
@@ -183,6 +238,12 @@ export class MilvusVectorStore
});
}
/**
 * Build the filter expression for a query, or `undefined` when no filters are given.
 */
public toMilvusFilter(filters?: MetadataFilters): string | undefined {
  // TODO: Milvus also support standard filters, we can add it later
  return filters ? parseScalarFilters(filters) : undefined;
}
public async query(
query: VectorStoreQuery,
_options?: any,
@@ -193,6 +254,7 @@ export class MilvusVectorStore
collection_name: this.collectionName,
limit: query.similarityTopK,
vector: query.queryEmbedding,
filter: this.toMilvusFilter(query.filters),
});
const nodes: BaseNode<Metadata>[] = [];
@@ -1,9 +1,9 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { BaseNode } from "@llamaindex/core/schema";
import { MetadataMode } from "@llamaindex/core/schema";
import { getEnv } from "@llamaindex/env";
import type { BulkWriteOptions, Collection } from "mongodb";
import { MongoClient } from "mongodb";
import { BaseEmbedding } from "../../embeddings/types.js";
import {
VectorStoreBase,
type MetadataFilters,
@@ -272,7 +272,10 @@ export class PGVectorStore
query.filters?.filters.forEach((filter, index) => {
const paramIndex = params.length + 1;
whereClauses.push(`metadata->>'${filter.key}' = $${paramIndex}`);
params.push(filter.value);
// TODO: support filter with other operators
if (!Array.isArray(filter.value)) {
params.push(filter.value);
}
});
const where =
@@ -1,7 +1,7 @@
import {
VectorStoreBase,
type ExactMatchFilter,
type IEmbedModel,
type MetadataFilter,
type MetadataFilters,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
@@ -199,8 +199,12 @@ export class PineconeVectorStore
}
toPineconeFilter(stdFilters?: MetadataFilters) {
return stdFilters?.filters?.reduce((carry: any, item: ExactMatchFilter) => {
carry[item.key] = item.value;
return stdFilters?.filters?.reduce((carry: any, item: MetadataFilter) => {
// Use MetadataFilter with EQ operator to replace ExactMatchFilter
// TODO: support filter with other operators
if (item.operator === "==") {
carry[item.key] = item.value;
}
return carry;
}, {});
}
@@ -1,21 +1,29 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { BaseNode } from "@llamaindex/core/schema";
import { fs, path } from "@llamaindex/env";
import { BaseEmbedding } from "../../embeddings/index.js";
import {
getTopKEmbeddings,
getTopKEmbeddingsLearner,
getTopKMMREmbeddings,
} from "../../embeddings/utils.js";
} from "../../internal/utils.js";
import { exists } from "../FileSystem.js";
import { DEFAULT_PERSIST_DIR } from "../constants.js";
import {
FilterOperator,
VectorStoreBase,
VectorStoreQueryMode,
type IEmbedModel,
type MetadataFilter,
type MetadataFilters,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
type VectorStoreQueryResult,
} from "./types.js";
import {
nodeToMetadata,
parseArrayValue,
parseNumberValue,
parsePrimitiveValue,
} from "./utils.js";
const LEARNER_MODES = new Set<VectorStoreQueryMode>([
VectorStoreQueryMode.SVM,
@@ -25,9 +33,85 @@ const LEARNER_MODES = new Set<VectorStoreQueryMode>([
const MMR_MODE = VectorStoreQueryMode.MMR;
type MetadataValue = Record<string, any>;
// Mapping of filter operators to metadata filter functions.
// Each function receives the filter ({ key, value }) and one node's metadata,
// and returns whether that metadata satisfies the filter.
// Values are normalized through the parse* helpers imported from ./utils.js before comparing.
// NOTE(review): those helpers appear to throw on values of an unexpected type, so a
// metadata record that lacks `key` would raise instead of simply not matching — confirm intent.
const OPERATOR_TO_FILTER: {
[key in FilterOperator]: (
{ key, value }: MetadataFilter,
metadata: MetadataValue,
) => boolean;
} = {
// equality on the parsed primitive forms of both sides
[FilterOperator.EQ]: ({ key, value }, metadata) => {
return parsePrimitiveValue(metadata[key]) === parsePrimitiveValue(value);
},
// inequality on the parsed primitive forms of both sides
[FilterOperator.NE]: ({ key, value }, metadata) => {
return parsePrimitiveValue(metadata[key]) !== parsePrimitiveValue(value);
},
// metadata value is one of the filter's array values
[FilterOperator.IN]: ({ key, value }, metadata) => {
return parseArrayValue(value).includes(parsePrimitiveValue(metadata[key]));
},
// metadata value is none of the filter's array values
[FilterOperator.NIN]: ({ key, value }, metadata) => {
return !parseArrayValue(value).includes(parsePrimitiveValue(metadata[key]));
},
// metadata array shares at least one element with the filter's array
[FilterOperator.ANY]: ({ key, value }, metadata) => {
return parseArrayValue(value).some((v) =>
parseArrayValue(metadata[key]).includes(v),
);
},
// metadata array includes every element of the filter's array
[FilterOperator.ALL]: ({ key, value }, metadata) => {
return parseArrayValue(value).every((v) =>
parseArrayValue(metadata[key]).includes(v),
);
},
// substring test: parsed metadata value includes the parsed filter value
[FilterOperator.TEXT_MATCH]: ({ key, value }, metadata) => {
return parsePrimitiveValue(metadata[key]).includes(
parsePrimitiveValue(value),
);
},
// metadata array includes the filter's single value
[FilterOperator.CONTAINS]: ({ key, value }, metadata) => {
return parseArrayValue(metadata[key]).includes(parsePrimitiveValue(value));
},
// numeric comparisons; both sides go through parseNumberValue
[FilterOperator.GT]: ({ key, value }, metadata) => {
return parseNumberValue(metadata[key]) > parseNumberValue(value);
},
[FilterOperator.LT]: ({ key, value }, metadata) => {
return parseNumberValue(metadata[key]) < parseNumberValue(value);
},
[FilterOperator.GTE]: ({ key, value }, metadata) => {
return parseNumberValue(metadata[key]) >= parseNumberValue(value);
},
[FilterOperator.LTE]: ({ key, value }, metadata) => {
return parseNumberValue(metadata[key]) <= parseNumberValue(value);
},
};
// Decide whether one node's metadata passes the preFilters.
// No preFilters: everything passes. preFilters but no metadata: nothing passes.
// Otherwise each filter is evaluated through OPERATOR_TO_FILTER and the
// results are combined with the filters' condition.
const buildFilterFn = (
  metadata: MetadataValue | undefined,
  preFilters: MetadataFilters | undefined,
) => {
  if (!preFilters) return true;
  if (!metadata) return false;

  const matches = (filter: MetadataFilter) => {
    const predicate = OPERATOR_TO_FILTER[filter.operator];
    if (!predicate) {
      throw new Error(`Unsupported operator: ${filter.operator}`);
    }
    return predicate(filter, metadata);
  };

  // default to and
  const combineWithAnd = (preFilters.condition || "and") === "and";
  return combineWithAnd
    ? preFilters.filters.every(matches)
    : preFilters.filters.some(matches);
};
// Plain data container holding the state of a SimpleVectorStore.
class SimpleVectorStoreData {
// node id -> embedding vector
embeddingDict: Record<string, number[]> = {};
// node id -> id of the node's source node
textIdToRefDocId: Record<string, string> = {};
// node id -> metadata record consulted when applying metadata filters
metadataDict: Record<string, MetadataValue> = {};
}
export class SimpleVectorStore
@@ -68,6 +152,11 @@ export class SimpleVectorStore
}
this.data.textIdToRefDocId[node.id_] = node.sourceNode?.nodeId;
// Add metadata to the metadataDict
const metadata = nodeToMetadata(node, true, undefined, false);
delete metadata["_node_content"];
this.data.metadataDict[node.id_] = metadata;
}
if (this.persistPath) {
@@ -84,6 +173,7 @@ export class SimpleVectorStore
for (const textId of textIdsToDelete) {
delete this.data.embeddingDict[textId];
delete this.data.textIdToRefDocId[textId];
if (this.data.metadataDict) delete this.data.metadataDict[textId];
}
if (this.persistPath) {
await this.persist(this.persistPath);
@@ -91,36 +181,40 @@ export class SimpleVectorStore
return Promise.resolve();
}
async query(query: VectorStoreQuery): Promise<VectorStoreQueryResult> {
if (!(query.filters == null)) {
throw new Error(
"Metadata filters not implemented for SimpleVectorStore yet.",
);
}
private async filterNodes(query: VectorStoreQuery): Promise<{
nodeIds: string[];
embeddings: number[][];
}> {
const items = Object.entries(this.data.embeddingDict);
const queryFilterFn = (nodeId: string) => {
const metadata = this.data.metadataDict[nodeId];
return buildFilterFn(metadata, query.filters);
};
let nodeIds: string[], embeddings: number[][];
if (query.docIds) {
const nodeFilterFn = (nodeId: string) => {
if (!query.docIds) return true;
const availableIds = new Set(query.docIds);
const queriedItems = items.filter((item) => availableIds.has(item[0]));
nodeIds = queriedItems.map((item) => item[0]);
embeddings = queriedItems.map((item) => item[1]);
} else {
// No docIds specified, so use all available items
nodeIds = items.map((item) => item[0]);
embeddings = items.map((item) => item[1]);
}
return availableIds.has(nodeId);
};
const queriedItems = items.filter(
(item) => nodeFilterFn(item[0]) && queryFilterFn(item[0]),
);
const nodeIds = queriedItems.map((item) => item[0]);
const embeddings = queriedItems.map((item) => item[1]);
return { nodeIds, embeddings };
}
async query(query: VectorStoreQuery): Promise<VectorStoreQueryResult> {
const { nodeIds, embeddings } = await this.filterNodes(query);
const queryEmbedding = query.queryEmbedding!;
let topSimilarities: number[], topIds: string[];
if (LEARNER_MODES.has(query.mode)) {
[topSimilarities, topIds] = getTopKEmbeddingsLearner(
queryEmbedding,
embeddings,
query.similarityTopK,
nodeIds,
// fixme: unfinished
throw new Error(
"Learner modes not implemented for SimpleVectorStore yet.",
);
} else if (query.mode === MMR_MODE) {
const mmrThreshold = query.mmrThreshold;
@@ -194,6 +288,7 @@ export class SimpleVectorStore
const data = new SimpleVectorStoreData();
data.embeddingDict = dataDict.embeddingDict ?? {};
data.textIdToRefDocId = dataDict.textIdToRefDocId ?? {};
data.metadataDict = dataDict.metadataDict ?? {};
const store = new SimpleVectorStore({ data, embedModel });
store.persistPath = persistPath;
return store;
@@ -206,6 +301,7 @@ export class SimpleVectorStore
const data = new SimpleVectorStoreData();
data.embeddingDict = saveDict.embeddingDict;
data.textIdToRefDocId = saveDict.textIdToRefDocId;
data.metadataDict = saveDict.metadataDict;
return new SimpleVectorStore({ data, embedModel });
}
@@ -213,6 +309,7 @@ export class SimpleVectorStore
return {
embeddingDict: this.data.embeddingDict,
textIdToRefDocId: this.data.textIdToRefDocId,
metadataDict: this.data.metadataDict,
};
}
}
@@ -1,5 +1,5 @@
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import type { BaseNode, ModalityType } from "@llamaindex/core/schema";
import type { BaseEmbedding } from "../../embeddings/types.js";
import { getEmbeddedModel } from "../../internal/settings/EmbedModel.js";
export interface VectorStoreQueryResult {
@@ -20,20 +20,37 @@ export enum VectorStoreQueryMode {
MMR = "mmr",
}
export interface ExactMatchFilter {
filterType: "ExactMatch";
/**
 * Operators that a metadata filter can apply.
 * The string values are what callers write in a filter's `operator` field.
 */
export enum FilterOperator {
EQ = "==", // default operator (string, number)
IN = "in", // In array (string or number)
GT = ">", // greater than (number)
LT = "<", // less than (number)
NE = "!=", // not equal to (string, number)
GTE = ">=", // greater than or equal to (number)
LTE = "<=", // less than or equal to (number)
NIN = "nin", // Not in array (string or number)
ANY = "any", // Contains any (array of strings)
ALL = "all", // Contains all (array of strings)
TEXT_MATCH = "text_match", // full text match (allows you to search for a specific substring, token or phrase within the text field)
CONTAINS = "contains", // metadata array contains value (string or number)
}
/**
 * How the individual filters of a filter set are combined.
 */
export enum FilterCondition {
AND = "and", // every filter must match
OR = "or", // at least one filter must match
}
export type MetadataFilterValue = string | number | string[] | number[];
export interface MetadataFilter {
key: string;
value: string | number;
value: MetadataFilterValue;
operator: `${FilterOperator}`; // ==, any, all,...
}
export interface MetadataFilters {
filters: ExactMatchFilter[];
}
export interface VectorStoreQuerySpec {
query: string;
filters: ExactMatchFilter[];
topK?: number;
filters: Array<MetadataFilter>;
condition?: `${FilterCondition}`; // and, or
}
export interface MetadataInfo {
@@ -1,5 +1,6 @@
import type { BaseNode, Metadata } from "@llamaindex/core/schema";
import { ObjectType, jsonToNode } from "@llamaindex/core/schema";
import type { MetadataFilterValue } from "./types.js";
const DEFAULT_TEXT_KEY = "text";
@@ -77,3 +78,25 @@ export function metadataDictToNode(
return jsonToNode(nodeObj, ObjectType.TEXT);
}
}
/**
 * Return the value unchanged when it is a number; throw otherwise.
 */
export const parseNumberValue = (value: MetadataFilterValue): number => {
  if (typeof value === "number") {
    return value;
  }
  throw new Error("Value must be a number");
};
/**
 * Return the string form of a string or number value; throw for anything else.
 */
export const parsePrimitiveValue = (value: MetadataFilterValue): string => {
  switch (typeof value) {
    case "string":
    case "number":
      return String(value);
    default:
      throw new Error("Value must be a string or number");
  }
};
/**
 * Return the elements of an array of strings/numbers as strings.
 * Throws when the value is not an array or holds any other element type.
 */
export const parseArrayValue = (value: MetadataFilterValue): string[] => {
  if (
    !Array.isArray(value) ||
    value.some((v) => typeof v !== "string" && typeof v !== "number")
  ) {
    throw new Error("Value must be an array of strings or numbers");
  }
  return value.map((item) => String(item));
};
@@ -7,7 +7,7 @@ import {
type BaseNode,
} from "@llamaindex/core/schema";
import type { SimplePrompt } from "../Prompt.js";
import { imageToDataUrl } from "../embeddings/utils.js";
import { imageToDataUrl } from "../internal/utils.js";
export async function createMessageContent(
prompt: SimplePrompt,
+6
View File
@@ -1,5 +1,11 @@
# @llamaindex/core-test
## 0.0.5
### Patch Changes
- b974eea: Add support for Metadata filters
## 0.0.4
### Patch Changes
@@ -1,10 +1,10 @@
import type { BaseNode } from "@llamaindex/core/schema";
import { TextNode } from "@llamaindex/core/schema";
import type { TransformComponent } from "llamaindex";
import {
IngestionCache,
getTransformationHash,
} from "llamaindex/ingestion/IngestionCache";
import type { TransformComponent } from "llamaindex/ingestion/index";
import { SimpleNodeParser } from "llamaindex/nodeParsers/index";
import { beforeAll, describe, expect, test } from "vitest";
@@ -28,7 +28,7 @@ describe("IngestionCache", () => {
});
describe("getTransformationHash", () => {
let nodes: BaseNode[], transform: TransformComponent;
let nodes: BaseNode[], transform: TransformComponent<any>;
beforeAll(() => {
nodes = [new TextNode({ text: "some text", id_: "some id" })];
@@ -0,0 +1,24 @@
import type { BaseNode } from "@llamaindex/core/schema";
import type { MilvusClient } from "@zilliz/milvus2-sdk-node";
import { MilvusVectorStore } from "llamaindex";
import { type Mocked } from "vitest";
/**
 * MilvusVectorStore subclass for unit tests.
 *
 * It is constructed with an empty object cast to a mocked MilvusClient, and
 * `add` only records the nodes in memory (after a short artificial delay)
 * and returns their ids.
 */
export class TestableMilvusVectorStore extends MilvusVectorStore {
  /** Every node handed to `add`, in insertion order. */
  public nodes: BaseNode[] = [];

  constructor() {
    // An empty object stands in for the Milvus client.
    super({ milvusClient: {} as Mocked<MilvusClient> });
  }

  /** Records the nodes locally and resolves with their ids. */
  public async add(nodes: BaseNode[]): Promise<string[]> {
    this.nodes.push(...nodes);
    await this.fakeTimeout(100);
    return nodes.map(({ id_ }) => id_);
  }

  /** Resolves after `ms` milliseconds. */
  private fakeTimeout(ms: number) {
    return new Promise((resolve) => {
      setTimeout(resolve, ms);
    });
  }
}
+1 -1
View File
@@ -1,7 +1,7 @@
{
"name": "@llamaindex/llamaindex-test",
"private": true,
"version": "0.0.4",
"version": "0.0.5",
"type": "module",
"scripts": {
"test": "vitest run"
@@ -0,0 +1,333 @@
import type { BaseNode } from "@llamaindex/core/schema";
import { TextNode } from "@llamaindex/core/schema";
import {
MilvusVectorStore,
VectorStoreQueryMode,
type MetadataFilters,
} from "llamaindex";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { TestableMilvusVectorStore } from "../mocks/TestableMilvusVectorStore.js";
// One row of the table-driven filter tests below.
type FilterTestCase = {
// Label interpolated into the generated test name.
title: string;
// Filters handed to `toMilvusFilter` and to `query`; omitted for the unfiltered case.
filters?: MetadataFilters;
// Number of ids the query result is expected to contain.
expected: number;
// Expression `toMilvusFilter` is expected to return for `filters`.
expectedFilterStr: string | undefined;
// Ids the stubbed `query` call resolves with.
mockResultIds: string[];
};
// Tests for the metadata-filter support of MilvusVectorStore.
// The store under test is a TestableMilvusVectorStore and `query` is stubbed
// in every case, so the assertions that exercise real code are the ones on
// the filter expression returned by `toMilvusFilter`.
describe("MilvusVectorStore", () => {
  // Element type of `MetadataFilters["filters"]`, derived to avoid an extra import.
  type SingleFilter = MetadataFilters["filters"][number];

  let store: MilvusVectorStore;
  let nodes: BaseNode[];

  // Builds one filter entry; argument order reads like the expression itself.
  const where = (
    key: SingleFilter["key"],
    operator: SingleFilter["operator"],
    value: SingleFilter["value"],
  ): SingleFilter => ({ key, value, operator });

  // Builds a fixture node; every fixture shares the same embedding.
  const dogNode = (
    id_: string,
    text: string,
    metadata: BaseNode["metadata"],
  ): BaseNode => new TextNode({ id_, embedding: [0.1, 0.2], text, metadata });

  beforeEach(() => {
    store = new TestableMilvusVectorStore();
    nodes = [
      dogNode("1", "The dog is brown", {
        name: "Anakin",
        dogId: "1",
        private: "true",
        weight: 1.2,
        type: ["husky", "puppy"],
      }),
      dogNode("2", "The dog is yellow", {
        name: "Luke",
        dogId: "2",
        private: "false",
        weight: 2.3,
        type: ["puppy"],
      }),
      dogNode("3", "The dog is red", {
        name: "Leia",
        dogId: "3",
        private: "false",
        weight: 3.4,
        type: ["husky"],
      }),
    ];
  });

  describe("[MilvusVectorStore] manage nodes", () => {
    it("able to add nodes to store", async () => {
      const ids = await store.add(nodes);
      expect(ids).length(3);
    });
  });

  describe("[MilvusVectorStore] filter nodes with supported operators", () => {
    const testcases: FilterTestCase[] = [
      {
        title: "No filter",
        expected: 3,
        mockResultIds: ["1", "2", "3"],
        expectedFilterStr: undefined,
      },
      {
        title: "Filter EQ",
        filters: { filters: [where("private", "==", "false")] },
        expected: 2,
        mockResultIds: ["2", "3"],
        expectedFilterStr: 'metadata["private"] == "false"',
      },
      {
        title: "Filter NE",
        filters: { filters: [where("private", "!=", "false")] },
        expected: 1,
        mockResultIds: ["1"],
        expectedFilterStr: 'metadata["private"] != "false"',
      },
      {
        title: "Filter GT",
        filters: { filters: [where("weight", ">", 2.3)] },
        expected: 1,
        mockResultIds: ["3"],
        expectedFilterStr: 'metadata["weight"] > 2.3',
      },
      {
        title: "Filter GTE",
        filters: { filters: [where("weight", ">=", 2.3)] },
        expected: 2,
        mockResultIds: ["2", "3"],
        expectedFilterStr: 'metadata["weight"] >= 2.3',
      },
      {
        title: "Filter LT",
        filters: { filters: [where("weight", "<", 2.3)] },
        expected: 1,
        mockResultIds: ["1"],
        expectedFilterStr: 'metadata["weight"] < 2.3',
      },
      {
        title: "Filter LTE",
        filters: { filters: [where("weight", "<=", 2.3)] },
        expected: 2,
        mockResultIds: ["1", "2"],
        expectedFilterStr: 'metadata["weight"] <= 2.3',
      },
      {
        title: "Filter IN",
        filters: { filters: [where("dogId", "in", ["1", "3"])] },
        expected: 2,
        mockResultIds: ["1", "3"],
        expectedFilterStr: 'metadata["dogId"] in ["1", "3"]',
      },
      {
        title: "Filter NIN",
        filters: { filters: [where("name", "nin", ["Anakin", "Leia"])] },
        expected: 1,
        mockResultIds: ["2"],
        expectedFilterStr:
          'metadata["name"] != "Anakin" && metadata["name"] != "Leia"',
      },
      {
        title: "Filter OR",
        filters: {
          filters: [
            where("private", "==", "false"),
            where("dogId", "in", ["1", "3"]),
          ],
          condition: "or",
        },
        expected: 3,
        mockResultIds: ["1", "2", "3"],
        expectedFilterStr:
          'metadata["private"] == "false" or metadata["dogId"] in ["1", "3"]',
      },
      {
        title: "Filter AND",
        filters: {
          filters: [
            where("private", "==", "false"),
            where("dogId", "==", "10"),
          ],
          condition: "and",
        },
        expected: 0,
        mockResultIds: [],
        expectedFilterStr:
          'metadata["private"] == "false" and metadata["dogId"] == "10"',
      },
    ];

    for (const tc of testcases) {
      it(`[${tc.title}] should return ${tc.expected} nodes`, async () => {
        // The filter translation is the part that runs real code.
        expect(store.toMilvusFilter(tc.filters)).toBe(tc.expectedFilterStr);

        // No Milvus server is available, so the query result is stubbed.
        vi.spyOn(store, "query").mockResolvedValue({
          ids: tc.mockResultIds,
          // Keep the stubbed scores parallel to the stubbed ids.
          similarities: [0.1, 0.2, 0.3].slice(0, tc.mockResultIds.length),
        });

        await store.add(nodes);
        const result = await store.query({
          queryEmbedding: [0.1, 0.2],
          similarityTopK: 3,
          mode: VectorStoreQueryMode.DEFAULT,
          filters: tc.filters,
        });
        expect(result.ids).length(tc.expected);
      });
    }
  });

  describe("[MilvusVectorStore] filter nodes with unsupported operators", () => {
    const testcases: Array<
      Omit<FilterTestCase, "expectedFilterStr" | "mockResultIds">
    > = [
      {
        title: "Filter ANY",
        filters: { filters: [where("type", "any", ["husky", "puppy"])] },
        expected: 3,
      },
      {
        title: "Filter ALL",
        filters: { filters: [where("type", "all", ["husky", "puppy"])] },
        expected: 1,
      },
      {
        title: "Filter CONTAINS",
        filters: { filters: [where("type", "contains", "puppy")] },
        expected: 2,
      },
      {
        title: "Filter TEXT_MATCH",
        filters: { filters: [where("name", "text_match", "Luk")] },
        expected: 1,
      },
    ];

    for (const tc of testcases) {
      // Synchronous on purpose: nothing is awaited, `toMilvusFilter` throws directly.
      it(`[Unsupported Operator] [${tc.title}] should throw error`, () => {
        const errorMsg = `Operator ${tc.filters?.filters[0].operator} is not supported.`;
        expect(() => store.toMilvusFilter(tc.filters)).toThrow(errorMsg);
      });
    }
  });
});
@@ -0,0 +1,299 @@
import {
BaseEmbedding,
BaseNode,
SimpleVectorStore,
TextNode,
VectorStoreQueryMode,
type Metadata,
type MetadataFilters,
} from "llamaindex";
import { beforeEach, describe, expect, it } from "vitest";
// One row of the table-driven query tests below.
type FilterTestCase = {
// Label interpolated into the generated test name.
title: string;
// Filters passed to `store.query`; omitted for the unfiltered case.
filters?: MetadataFilters;
// Number of ids the query result is expected to contain.
expected: number;
};
// Tests for metadata filtering in SimpleVectorStore.
// Three fixture nodes share one embedding, so the number of returned ids
// depends only on which nodes pass the metadata filters.
describe("SimpleVectorStore", () => {
  // Element type of `MetadataFilters["filters"]`, derived to avoid an extra import.
  type SingleFilter = MetadataFilters["filters"][number];

  let nodes: BaseNode[];
  let store: SimpleVectorStore;

  // Builds one filter entry; argument order reads like the expression itself.
  const where = (
    key: SingleFilter["key"],
    operator: SingleFilter["operator"],
    value: SingleFilter["value"],
  ): SingleFilter => ({ key, value, operator });

  // Builds a fixture node; every fixture shares the same embedding.
  const dogNode = (id_: string, text: string, metadata: Metadata): BaseNode =>
    new TextNode({ id_, embedding: [0.1, 0.2], text, metadata });

  beforeEach(() => {
    nodes = [
      dogNode("1", "The dog is brown", {
        name: "Anakin",
        dogId: "1",
        private: "true",
        weight: 1.2,
        type: ["husky", "puppy"],
      }),
      dogNode("2", "The dog is yellow", {
        name: "Luke",
        dogId: "2",
        private: "false",
        weight: 2.3,
        type: ["puppy"],
      }),
      dogNode("3", "The dog is red", {
        name: "Leia",
        dogId: "3",
        private: "false",
        weight: 3.4,
        type: ["husky"],
      }),
    ];

    // Seed the store's metadata dictionary with the fixtures, keyed by node id.
    const metadataDict: Record<string, Metadata> = {};
    for (const node of nodes) {
      metadataDict[node.id_] = node.metadata;
    }

    store = new SimpleVectorStore({
      embedModel: {} as BaseEmbedding, // Mocking the embedModel
      data: {
        embeddingDict: {},
        textIdToRefDocId: {},
        metadataDict,
      },
    });
  });

  describe("[SimpleVectorStore] manage nodes", () => {
    it("able to add nodes to store", async () => {
      const ids = await store.add(nodes);
      expect(ids).length(3);
    });
  });

  describe("[SimpleVectorStore] query nodes", () => {
    const testcases: FilterTestCase[] = [
      { title: "No filter", expected: 3 },
      {
        title: "Filter EQ",
        filters: { filters: [where("private", "==", "false")] },
        expected: 2,
      },
      {
        title: "Filter NE",
        filters: { filters: [where("private", "!=", "false")] },
        expected: 1,
      },
      {
        title: "Filter GT",
        filters: { filters: [where("weight", ">", 2.3)] },
        expected: 1,
      },
      {
        title: "Filter GTE",
        filters: { filters: [where("weight", ">=", 2.3)] },
        expected: 2,
      },
      {
        title: "Filter LT",
        filters: { filters: [where("weight", "<", 2.3)] },
        expected: 1,
      },
      {
        title: "Filter LTE",
        filters: { filters: [where("weight", "<=", 2.3)] },
        expected: 2,
      },
      {
        title: "Filter IN",
        filters: { filters: [where("dogId", "in", ["1", "3"])] },
        expected: 2,
      },
      {
        title: "Filter NIN",
        filters: { filters: [where("name", "nin", ["Anakin", "Leia"])] },
        expected: 1,
      },
      {
        title: "Filter ANY",
        filters: { filters: [where("type", "any", ["husky", "puppy"])] },
        expected: 3,
      },
      {
        title: "Filter ALL",
        filters: { filters: [where("type", "all", ["husky", "puppy"])] },
        expected: 1,
      },
      {
        title: "Filter CONTAINS",
        filters: { filters: [where("type", "contains", "puppy")] },
        expected: 2,
      },
      {
        title: "Filter TEXT_MATCH",
        filters: { filters: [where("name", "text_match", "Luk")] },
        expected: 1,
      },
      {
        title: "Filter OR",
        filters: {
          filters: [
            where("private", "==", "false"),
            where("dogId", "in", ["1", "3"]),
          ],
          condition: "or",
        },
        expected: 3,
      },
      {
        title: "Filter AND",
        filters: {
          filters: [
            where("private", "==", "false"),
            where("dogId", "==", "10"),
          ],
          condition: "and",
        },
        expected: 0,
      },
    ];

    for (const tc of testcases) {
      it(`[${tc.title}] should return ${tc.expected} nodes`, async () => {
        await store.add(nodes);
        const result = await store.query({
          queryEmbedding: [0.1, 0.2],
          similarityTopK: 3,
          mode: VectorStoreQueryMode.DEFAULT,
          filters: tc.filters,
        });
        expect(result.ids).length(tc.expected);
      });
    }
  });
});