mirror of
https://github.com/run-llama/LlamaIndexTS.git
synced 2026-07-02 20:13:52 -04:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c881c8fde | |||
| 815a3416f2 | |||
| 407069ca27 | |||
| 29d042175e | |||
| bbf936e9b4 | |||
| 2212793420 | |||
| d6c6aefd0d | |||
| 4516363097 | |||
| 69dd6d4efa | |||
| 9ea840142b | |||
| a1c45294b3 | |||
| c2ef5057b3 | |||
| b87e6d9ced | |||
| 8d618a6bc3 | |||
| 8d8bee5263 | |||
| ed924641ca | |||
| ce61f9660b | |||
| 072b13cff0 | |||
| ff274dde1d |
@@ -34,3 +34,5 @@ yarn-error.log*
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
storage/
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import { Document } from "@llamaindex/core/src/Node";
|
||||
import { ListIndex } from "@llamaindex/core/src/index/list";
|
||||
import essay from "./essay";
|
||||
|
||||
/**
 * Example script: wraps the bundled essay in a Document, builds a ListIndex
 * from it, asks a single question through the index's query engine and
 * prints the answer.
 */
async function main() {
  const index = await ListIndex.fromDocuments([new Document({ text: essay })]);
  const engine = index.asQueryEngine();

  const question = "What did the author do growing up?";
  const answer = await engine.aquery(question);

  console.log(answer.toString());
}

// Print any failure together with its stack trace.
main().catch((err: Error) => {
  console.error(err, err.stack);
});
|
||||
@@ -1,9 +0,0 @@
|
||||
Simple flow:
|
||||
|
||||
Get document list, in this case one document.
|
||||
Split each document into nodes, in this case sentences or lines.
|
||||
Embed each of the nodes and get vectors. Store them in memory for now.
|
||||
Embed query.
|
||||
Compare the query with the nodes and get the top n.
|
||||
Put the top n nodes into the prompt.
|
||||
Execute prompt, get result.
|
||||
@@ -0,0 +1,60 @@
|
||||
// from llama_index import SimpleDirectoryReader, VectorStoreIndex
|
||||
// from llama_index.query_engine import SubQuestionQueryEngine
|
||||
// from llama_index.tools import QueryEngineTool, ToolMetadata
|
||||
|
||||
// # load data
|
||||
// pg_essay = SimpleDirectoryReader(
|
||||
// input_dir="docs/examples/data/paul_graham/"
|
||||
// ).load_data()
|
||||
|
||||
// # build index and query engine
|
||||
// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine()
|
||||
|
||||
// # setup base query engine as tool
|
||||
// query_engine_tools = [
|
||||
// QueryEngineTool(
|
||||
// query_engine=query_engine,
|
||||
// metadata=ToolMetadata(
|
||||
// name="pg_essay", description="Paul Graham essay on What I Worked On"
|
||||
// ),
|
||||
// )
|
||||
// ]
|
||||
|
||||
// query_engine = SubQuestionQueryEngine.from_defaults(
|
||||
// query_engine_tools=query_engine_tools
|
||||
// )
|
||||
|
||||
// response = query_engine.query(
|
||||
// "How was Paul Grahams life different before and after YC?"
|
||||
// )
|
||||
|
||||
// print(response)
|
||||
|
||||
import { Document } from "@llamaindex/core/src/Node";
|
||||
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
|
||||
import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine";
|
||||
|
||||
import essay from "./essay";
|
||||
|
||||
/**
 * Example script: indexes the bundled essay in a VectorStoreIndex, exposes
 * that index as a single query-engine tool ("pg_essay") and asks a
 * SubQuestionQueryEngine a question about it.
 *
 * The work lives in a named `main` whose promise is given a rejection
 * handler, so a failure is reported through console.error instead of
 * surfacing as an unhandled promise rejection.
 */
async function main() {
  const document = new Document({ text: essay });
  const index = await VectorStoreIndex.fromDocuments([document]);

  const queryEngine = SubQuestionQueryEngine.fromDefaults({
    queryEngineTools: [
      {
        queryEngine: index.asQueryEngine(),
        metadata: {
          name: "pg_essay",
          description: "Paul Graham essay on What I Worked On",
        },
      },
    ],
  });

  const response = await queryEngine.aquery(
    "How was Paul Grahams life different before and after YC?"
  );

  console.log(response);
}

main().catch(console.error);
|
||||
@@ -2,7 +2,7 @@ import { Document } from "@llamaindex/core/src/Node";
|
||||
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
|
||||
import essay from "./essay";
|
||||
|
||||
(async () => {
|
||||
async function main() {
|
||||
const document = new Document({ text: essay });
|
||||
const index = await VectorStoreIndex.fromDocuments([document]);
|
||||
const queryEngine = index.asQueryEngine();
|
||||
@@ -10,4 +10,6 @@ import essay from "./essay";
|
||||
"What did the author do growing up?"
|
||||
);
|
||||
console.log(response.toString());
|
||||
})();
|
||||
}
|
||||
|
||||
main().catch(console.error);
|
||||
+1
-1
@@ -18,7 +18,7 @@
|
||||
"prettier": "^2.8.8",
|
||||
"prettier-plugin-tailwindcss": "^0.3.0",
|
||||
"ts-jest": "^29.1.0",
|
||||
"turbo": "^1.10.5"
|
||||
"turbo": "^1.10.7"
|
||||
},
|
||||
"packageManager": "pnpm@7.15.0",
|
||||
"name": "llamascript"
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
"uuid": "^9.0.0",
|
||||
"wink-nlp": "^1.14.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"main": "src/index.ts",
|
||||
"types": "src/index.ts",
|
||||
"scripts": {
|
||||
|
||||
+165
-42
@@ -1,22 +1,19 @@
|
||||
import { Document, TextNode } from "./Node";
|
||||
import { SimpleNodeParser } from "./NodeParser";
|
||||
import { Document, BaseNode, MetadataMode, NodeWithEmbedding } from "./Node";
|
||||
import { BaseQueryEngine, RetrieverQueryEngine } from "./QueryEngine";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { VectorIndexRetriever } from "./Retriever";
|
||||
import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
|
||||
export class BaseIndex {
|
||||
nodes: TextNode[] = [];
|
||||
import { BaseRetriever, VectorIndexRetriever } from "./Retriever";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
|
||||
import {
|
||||
StorageContext,
|
||||
storageContextFromDefaults,
|
||||
} from "./storage/StorageContext";
|
||||
import { BaseDocumentStore } from "./storage/docStore/types";
|
||||
import { VectorStore } from "./storage/vectorStore/types";
|
||||
import { BaseIndexStore } from "./storage/indexStore/types";
|
||||
|
||||
constructor(nodes?: TextNode[]) {
|
||||
this.nodes = nodes ?? [];
|
||||
}
|
||||
}
|
||||
|
||||
export class IndexDict {
|
||||
export abstract class IndexStruct {
|
||||
indexId: string;
|
||||
summary?: string;
|
||||
nodesDict: Record<string, TextNode> = {};
|
||||
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
|
||||
|
||||
constructor(indexId = uuidv4(), summary = undefined) {
|
||||
this.indexId = indexId;
|
||||
@@ -29,57 +26,183 @@ export class IndexDict {
|
||||
}
|
||||
return this.summary;
|
||||
}
|
||||
}
|
||||
|
||||
addNode(node: TextNode, textId?: string) {
|
||||
export class IndexDict extends IndexStruct {
|
||||
nodesDict: Record<string, BaseNode> = {};
|
||||
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
|
||||
|
||||
getSummary(): string {
|
||||
if (this.summary === undefined) {
|
||||
throw new Error("summary field of the index dict is not set");
|
||||
}
|
||||
return this.summary;
|
||||
}
|
||||
|
||||
addNode(node: BaseNode, textId?: string) {
|
||||
const vectorId = textId ?? node.id_;
|
||||
this.nodesDict[vectorId] = node;
|
||||
}
|
||||
}
|
||||
|
||||
export class VectorStoreIndex extends BaseIndex {
|
||||
indexStruct: IndexDict;
|
||||
embeddingService: BaseEmbedding; // FIXME replace with service context
|
||||
export class IndexList extends IndexStruct {
|
||||
nodes: string[] = [];
|
||||
|
||||
constructor(nodes: TextNode[]) {
|
||||
super(nodes);
|
||||
this.indexStruct = new IndexDict();
|
||||
addNode(node: BaseNode) {
|
||||
this.nodes.push(node.id_);
|
||||
}
|
||||
}
|
||||
|
||||
if (nodes !== undefined) {
|
||||
this.buildIndexFromNodes();
|
||||
}
|
||||
export interface BaseIndexInit<T> {
|
||||
serviceContext: ServiceContext;
|
||||
storageContext: StorageContext;
|
||||
docStore: BaseDocumentStore;
|
||||
vectorStore?: VectorStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
indexStruct: T;
|
||||
}
|
||||
export abstract class BaseIndex<T> {
|
||||
serviceContext: ServiceContext;
|
||||
storageContext: StorageContext;
|
||||
docStore: BaseDocumentStore;
|
||||
vectorStore?: VectorStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
indexStruct: T;
|
||||
|
||||
this.embeddingService = new OpenAIEmbedding();
|
||||
constructor(init: BaseIndexInit<T>) {
|
||||
this.serviceContext = init.serviceContext;
|
||||
this.storageContext = init.storageContext;
|
||||
this.docStore = init.docStore;
|
||||
this.vectorStore = init.vectorStore;
|
||||
this.indexStore = init.indexStore;
|
||||
this.indexStruct = init.indexStruct;
|
||||
}
|
||||
|
||||
async getNodeEmbeddingResults(logProgress = false) {
|
||||
for (let i = 0; i < this.nodes.length; ++i) {
|
||||
const node = this.nodes[i];
|
||||
if (logProgress) {
|
||||
console.log(`getting embedding for node ${i}/${this.nodes.length}`);
|
||||
abstract asRetriever(): BaseRetriever;
|
||||
}
|
||||
|
||||
export interface VectorIndexOptions {
|
||||
nodes?: BaseNode[];
|
||||
indexStruct?: IndexDict;
|
||||
serviceContext?: ServiceContext;
|
||||
storageContext?: StorageContext;
|
||||
}
|
||||
|
||||
interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
|
||||
vectorStore: VectorStore;
|
||||
}
|
||||
|
||||
export class VectorStoreIndex extends BaseIndex<IndexDict> {
|
||||
vectorStore: VectorStore;
|
||||
|
||||
private constructor(init: VectorIndexConstructorProps) {
|
||||
super(init);
|
||||
this.vectorStore = init.vectorStore;
|
||||
}
|
||||
|
||||
static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> {
|
||||
const storageContext =
|
||||
options.storageContext ?? (await storageContextFromDefaults({}));
|
||||
const serviceContext =
|
||||
options.serviceContext ?? serviceContextFromDefaults({});
|
||||
const docStore = storageContext.docStore;
|
||||
const vectorStore = storageContext.vectorStore;
|
||||
|
||||
let indexStruct: IndexDict;
|
||||
if (options.indexStruct) {
|
||||
if (options.nodes) {
|
||||
throw new Error(
|
||||
"Cannot initialize VectorStoreIndex with both nodes and indexStruct"
|
||||
);
|
||||
}
|
||||
const embedding = await this.embeddingService.aGetTextEmbedding(
|
||||
node.getText()
|
||||
indexStruct = options.indexStruct;
|
||||
} else {
|
||||
if (!options.nodes) {
|
||||
throw new Error(
|
||||
"Cannot initialize VectorStoreIndex without nodes or indexStruct"
|
||||
);
|
||||
}
|
||||
indexStruct = await VectorStoreIndex.buildIndexFromNodes(
|
||||
options.nodes,
|
||||
serviceContext,
|
||||
vectorStore
|
||||
);
|
||||
node.embedding = embedding;
|
||||
}
|
||||
|
||||
return new VectorStoreIndex({
|
||||
storageContext,
|
||||
serviceContext,
|
||||
docStore,
|
||||
vectorStore,
|
||||
indexStruct,
|
||||
});
|
||||
}
|
||||
|
||||
buildIndexFromNodes() {
|
||||
for (const node of this.nodes) {
|
||||
this.indexStruct.addNode(node);
|
||||
static async agetNodeEmbeddingResults(
|
||||
nodes: BaseNode[],
|
||||
serviceContext: ServiceContext,
|
||||
logProgress = false
|
||||
) {
|
||||
const nodesWithEmbeddings: NodeWithEmbedding[] = [];
|
||||
|
||||
for (let i = 0; i < nodes.length; ++i) {
|
||||
const node = nodes[i];
|
||||
if (logProgress) {
|
||||
console.log(`getting embedding for node ${i}/${nodes.length}`);
|
||||
}
|
||||
const embedding = await serviceContext.embedModel.aGetTextEmbedding(
|
||||
node.getContent(MetadataMode.EMBED)
|
||||
);
|
||||
nodesWithEmbeddings.push({ node, embedding });
|
||||
}
|
||||
|
||||
return nodesWithEmbeddings;
|
||||
}
|
||||
|
||||
static async fromDocuments(documents: Document[]): Promise<VectorStoreIndex> {
|
||||
const nodeParser = new SimpleNodeParser(); // FIXME use service context
|
||||
const nodes = nodeParser.getNodesFromDocuments(documents);
|
||||
const index = new VectorStoreIndex(nodes);
|
||||
await index.getNodeEmbeddingResults();
|
||||
static async buildIndexFromNodes(
|
||||
nodes: BaseNode[],
|
||||
serviceContext: ServiceContext,
|
||||
vectorStore: VectorStore
|
||||
): Promise<IndexDict> {
|
||||
const embeddingResults = await this.agetNodeEmbeddingResults(
|
||||
nodes,
|
||||
serviceContext
|
||||
);
|
||||
|
||||
vectorStore.add(embeddingResults);
|
||||
|
||||
const indexDict = new IndexDict();
|
||||
for (const { node } of embeddingResults) {
|
||||
indexDict.addNode(node);
|
||||
}
|
||||
|
||||
return indexDict;
|
||||
}
|
||||
|
||||
static async fromDocuments(
|
||||
documents: Document[],
|
||||
storageContext?: StorageContext,
|
||||
serviceContext?: ServiceContext
|
||||
): Promise<VectorStoreIndex> {
|
||||
storageContext = storageContext ?? (await storageContextFromDefaults({}));
|
||||
serviceContext = serviceContext ?? serviceContextFromDefaults({});
|
||||
const docStore = storageContext.docStore;
|
||||
|
||||
for (const doc of documents) {
|
||||
docStore.setDocumentHash(doc.id_, doc.hash);
|
||||
}
|
||||
|
||||
const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
|
||||
const index = await VectorStoreIndex.init({
|
||||
nodes,
|
||||
storageContext,
|
||||
serviceContext,
|
||||
});
|
||||
return index;
|
||||
}
|
||||
|
||||
asRetriever(): VectorIndexRetriever {
|
||||
return new VectorIndexRetriever(this, this.embeddingService);
|
||||
return new VectorIndexRetriever(this);
|
||||
}
|
||||
|
||||
asQueryEngine(): BaseQueryEngine {
|
||||
|
||||
@@ -174,24 +174,12 @@ export function getTopKMMREmbeddings(
|
||||
}
|
||||
|
||||
export abstract class BaseEmbedding {
|
||||
static similarity(
|
||||
similarity(
|
||||
embedding1: number[],
|
||||
embedding2: number[],
|
||||
mode: SimilarityType = SimilarityType.DOT_PRODUCT
|
||||
mode: SimilarityType = SimilarityType.DEFAULT
|
||||
): number {
|
||||
if (embedding1.length !== embedding2.length) {
|
||||
throw new Error("Embedding length mismatch");
|
||||
}
|
||||
|
||||
if (mode === SimilarityType.DOT_PRODUCT) {
|
||||
let result = 0;
|
||||
for (let i = 0; i < embedding1.length; i++) {
|
||||
result += embedding1[i] * embedding2[i];
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
throw new Error("Not implemented yet");
|
||||
}
|
||||
return similarity(embedding1, embedding2, mode);
|
||||
}
|
||||
|
||||
abstract aGetTextEmbedding(text: string): Promise<number[]>;
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
class GlobalsHelper {
|
||||
defaultTokenizer: ((text: string) => string[]) | null = null;
|
||||
|
||||
tokenizer() {
|
||||
if (this.defaultTokenizer) {
|
||||
return this.defaultTokenizer;
|
||||
}
|
||||
|
||||
const tiktoken = require("tiktoken-node");
|
||||
let enc = new tiktoken.getEncoding("gpt2");
|
||||
const defaultTokenizer = (text: string) => {
|
||||
this.defaultTokenizer = (text: string) => {
|
||||
return enc.encode(text);
|
||||
};
|
||||
return defaultTokenizer;
|
||||
return this.defaultTokenizer;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ export class TextNode extends BaseNode {
|
||||
endCharIdx?: number;
|
||||
// textTemplate: NOTE write your own formatter if needed
|
||||
// metadataTemplate: NOTE write your own formatter if needed
|
||||
metadataSeperator: string = "\n";
|
||||
metadataSeparator: string = "\n";
|
||||
|
||||
constructor(init?: Partial<TextNode>) {
|
||||
super(init);
|
||||
@@ -174,7 +174,7 @@ export class TextNode extends BaseNode {
|
||||
|
||||
return [...usableMetadataKeys]
|
||||
.map((key) => `${key}: ${this.metadata[key]}`)
|
||||
.join(this.metadataSeperator);
|
||||
.join(this.metadataSeparator);
|
||||
}
|
||||
|
||||
setContent(value: string) {
|
||||
@@ -206,11 +206,6 @@ export class IndexNode extends TextNode {
|
||||
}
|
||||
}
|
||||
|
||||
export interface NodeWithScore {
|
||||
node: TextNode;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export class Document extends TextNode {
|
||||
constructor(init?: Partial<Document>) {
|
||||
super(init);
|
||||
@@ -229,3 +224,13 @@ export class Document extends TextNode {
|
||||
export class ImageDocument extends Document {
|
||||
image?: string;
|
||||
}
|
||||
|
||||
export interface NodeWithScore {
|
||||
node: BaseNode;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export interface NodeWithEmbedding {
|
||||
node: BaseNode;
|
||||
embedding: number[];
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ export function getNodesFromDocument(
|
||||
|
||||
const textSplits = getTextSplitsFromDocument(document, textSplitter);
|
||||
|
||||
textSplits.forEach((textSplit, index) => {
|
||||
textSplits.forEach((textSplit) => {
|
||||
const node = new TextNode({ text: textSplit });
|
||||
node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo();
|
||||
nodes.push(node);
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
import { SubQuestion } from "./QuestionGenerator";
|
||||
|
||||
/**
 * Contract for turning raw text output into a value of type `T`.
 */
export interface BaseOutputParser<T> {
  /** Converts the raw `output` text into a `T`. */
  parse(output: string): T;

  /** Takes a string and returns a string; the exact transformation is up to the implementation. */
  format(output: string): string;
}
|
||||
|
||||
/**
 * Pairs a raw output string with the value parsed from it.
 */
export interface StructuredOutput<T> {
  /** The unparsed text. */
  rawOutput: string;

  /** The parsed value. */
  parsedOutput: T;
}
|
||||
|
||||
/**
 * Error thrown when output text cannot be parsed.
 *
 * Carries the offending text in `output` and, when available, the underlying
 * error in `cause`.
 */
class OutputParserError extends Error {
  /** Underlying error that triggered this one, if any. */
  cause: Error | undefined;

  /** The text that failed to parse, if the thrower supplied it. */
  output: string | undefined;

  /**
   * @param message Human-readable description of the failure.
   * @param options Optional `cause` (underlying error) and `output` (the text
   *   that could not be parsed).
   */
  constructor(
    message: string,
    options: { cause?: Error; output?: string } = {}
  ) {
    // @ts-ignore
    super(message, options); // https://github.com/tc39/proposal-error-cause

    // When compiled down to ES5, `super()` on a built-in returns a plain Error
    // and `instanceof OutputParserError` would be false; restore the chain.
    Object.setPrototypeOf(this, new.target.prototype);

    this.name = "OutputParserError";

    if (!this.cause) {
      // Runtimes without the error-cause proposal ignore the second argument
      // to Error, so the cause has to be attached by hand.
      this.cause = options.cause;
    }
    this.output = options.output;

    // This line is to maintain proper stack trace in V8
    // (https://v8.dev/docs/stack-trace-api)
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, OutputParserError);
    }
  }
}
|
||||
|
||||
/**
 * Extracts the first ```json fenced block from `text` and parses it.
 *
 * @param text Text expected to contain a ```json ... ``` block.
 * @returns Whatever `JSON.parse` produces for the fenced content.
 * @throws OutputParserError "Not a json markdown" when the opening ```json
 *   fence or a closing ``` fence after it is missing, and "Not a valid json"
 *   (with the parse error as `cause`) when the fenced content is not valid
 *   JSON. In both cases `output` is the trimmed input text.
 */
function parseJsonMarkdown(text: string) {
  const trimmed = text.trim();

  const fenceOpen = "```json";
  const fenceClose = "```";

  const openAt = trimmed.indexOf(fenceOpen);
  const contentStart = openAt + fenceOpen.length;
  const closeAt = trimmed.indexOf(fenceClose, contentStart);

  const fencesFound = openAt !== -1 && closeAt !== -1;
  if (!fencesFound) {
    throw new OutputParserError("Not a json markdown", { output: trimmed });
  }

  const payload = trimmed.slice(contentStart, closeAt);

  try {
    return JSON.parse(payload);
  } catch (e) {
    throw new OutputParserError("Not a valid json", {
      cause: e as Error,
      output: trimmed,
    });
  }
}
|
||||
|
||||
/**
 * Output parser that reads a list of sub-questions from a ```json fenced
 * block in the given text.
 */
export class SubQuestionOutputParser
  implements BaseOutputParser<StructuredOutput<SubQuestion[]>>
{
  /**
   * Parses the ```json block in `output`.
   *
   * @returns The original text as `rawOutput` and the parsed JSON as
   *   `parsedOutput`. The parsed value is not validated against the
   *   SubQuestion shape.
   */
  parse(output: string): StructuredOutput<SubQuestion[]> {
    // TODO add zod validation
    return {
      rawOutput: output,
      parsedOutput: parseJsonMarkdown(output),
    };
  }

  /** Returns `output` unchanged. */
  format(output: string): string {
    return output;
  }
}
|
||||
@@ -1,3 +1,6 @@
|
||||
import { SubQuestion } from "./QuestionGenerator";
|
||||
import { ToolMetadata } from "./Tool";
|
||||
|
||||
/**
|
||||
* A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
|
||||
* NOTE this is a different interface compared to LlamaIndex Python
|
||||
@@ -80,3 +83,186 @@ ${context}
|
||||
------------
|
||||
Given the new context, refine the original answer to better answer the question. If the context isn't useful, return the original answer.`;
|
||||
};
|
||||
|
||||
/**
 * Prompt asking for the numbers of the listed documents that are relevant to
 * a question, each with a 1-10 relevance score.
 *
 * Reads `context` and `query` from the input; either one defaults to the
 * empty string when it is undefined.
 */
export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
  const context = input.context === undefined ? "" : input.context;
  const query = input.query === undefined ? "" : input.query;

  const promptText = `A list of documents is shown below. Each document has a number next to it along
with a summary of the document. A question is also provided.
Respond with the numbers of the documents
you should consult to answer the question, in order of relevance, as well
as the relevance score. The relevance score is a number from 1-10 based on
how relevant you think the document is to the question.
Do not include any documents that are not relevant to the question.
Example format:
Document 1:
<summary of document 1>

Document 2:
<summary of document 2>

...

Document 10:\n<summary of document 10>

Question: <question>
Answer:
Doc: 9, Relevance: 7
Doc: 3, Relevance: 4
Doc: 7, Relevance: 3

Let's try this now:

${context}
Question: ${query}
Answer:`;

  return promptText;
};
|
||||
|
||||
/*
|
||||
PREFIX = """\
|
||||
Given a user question, and a list of tools, output a list of relevant sub-questions \
|
||||
that when composed can help answer the full user question:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
example_query_str = (
|
||||
"Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
|
||||
)
|
||||
example_tools = [
|
||||
ToolMetadata(
|
||||
name="uber_10k",
|
||||
description="Provides information about Uber financials for year 2021",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="lyft_10k",
|
||||
description="Provides information about Lyft financials for year 2021",
|
||||
),
|
||||
]
|
||||
example_tools_str = build_tools_text(example_tools)
|
||||
example_output = [
|
||||
SubQuestion(
|
||||
sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
|
||||
),
|
||||
SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
|
||||
SubQuestion(
|
||||
sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
|
||||
),
|
||||
SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
|
||||
]
|
||||
example_output_str = json.dumps([x.dict() for x in example_output], indent=4)
|
||||
|
||||
EXAMPLES = (
|
||||
"""\
|
||||
# Example 1
|
||||
<Tools>
|
||||
```json
|
||||
{tools_str}
|
||||
```
|
||||
|
||||
<User Question>
|
||||
{query_str}
|
||||
|
||||
|
||||
<Output>
|
||||
```json
|
||||
{output_str}
|
||||
```
|
||||
|
||||
""".format(
|
||||
query_str=example_query_str,
|
||||
tools_str=example_tools_str,
|
||||
output_str=example_output_str,
|
||||
)
|
||||
.replace("{", "{{")
|
||||
.replace("}", "}}")
|
||||
)
|
||||
|
||||
SUFFIX = """\
|
||||
# Example 2
|
||||
<Tools>
|
||||
```json
|
||||
{tools_str}
|
||||
```
|
||||
|
||||
<User Question>
|
||||
{query_str}
|
||||
|
||||
<Output>
|
||||
"""
|
||||
|
||||
DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
|
||||
*/
|
||||
|
||||
/**
 * Renders tool metadata as a JSON object mapping each tool name to its
 * description, pretty-printed with 4-space indentation.
 *
 * When two tools share a name, the later description wins.
 *
 * @param tools Tools to describe.
 * @returns The JSON text; `"{}"` for an empty list.
 */
export function buildToolsText(tools: ToolMetadata[]) {
  // Object.fromEntries defines own properties, so an unusual name such as
  // "__proto__" is kept as a regular key instead of being silently dropped
  // by an assignment to the prototype accessor.
  const toolsObj: Record<string, string> = Object.fromEntries(
    tools.map((tool): [string, string] => [tool.name, tool.description])
  );

  return JSON.stringify(toolsObj, null, 4);
}
|
||||
|
||||
// Worked example embedded in the sub-question prompt: two tools (Uber and
// Lyft 10-K filings), one user question, and the expected decomposition.

/** The two example tools, in the order they appear in the prompt. */
const exampleTools: ToolMetadata[] = ["Uber", "Lyft"].map((company) => ({
  name: `${company.toLowerCase()}_10k`,
  description: `Provides information about ${company} financials for year 2021`,
}));

/** The example user question. */
const exampleQueryStr =
  "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021";

/** Expected sub-questions: revenue growth then EBITDA, for Uber then Lyft. */
const exampleOutput: SubQuestion[] = ["Uber", "Lyft"].flatMap((company) =>
  ["revenue growth", "EBITDA"].map((metric) => ({
    subQuestion: `What is the ${metric} of ${company}`,
    toolName: `${company.toLowerCase()}_10k`,
  }))
);
|
||||
|
||||
/**
 * Prompt asking for a list of sub-questions, each tied to a tool, that
 * together help answer the user's question.
 *
 * The prompt shows one worked example ("Example 1", built from the
 * module-level example tools, question and output) and then the caller's own
 * tools and question as "Example 2", ending on an open `<Output>` section.
 *
 * Reads from the input:
 * - `toolsStr`: text placed inside the ```json block under `<Tools>`.
 * - `queryStr`: the user question.
 */
export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
  const { toolsStr, queryStr } = input;

  // NOTE: `${toolsStr}` must not be followed by an extra "}", otherwise a
  // stray brace ends up after the tools JSON in the prompt.
  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:

# Example 1
<Tools>
\`\`\`json
${buildToolsText(exampleTools)}
\`\`\`

<User Question>
${exampleQueryStr}

<Output>
\`\`\`json
${JSON.stringify(exampleOutput, null, 4)}
\`\`\`

# Example 2
<Tools>
\`\`\`json
${toolsStr}
\`\`\`

<User Question>
${queryStr}

<Output>
`;
};
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
import { NodeWithScore, TextNode } from "./Node";
|
||||
import {
|
||||
BaseQuestionGenerator,
|
||||
LLMQuestionGenerator,
|
||||
SubQuestion,
|
||||
} from "./QuestionGenerator";
|
||||
import { Response } from "./Response";
|
||||
import { ResponseSynthesizer } from "./ResponseSynthesizer";
|
||||
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
|
||||
import { BaseRetriever } from "./Retriever";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
|
||||
import { QueryEngineTool, ToolMetadata } from "./Tool";
|
||||
|
||||
export interface BaseQueryEngine {
|
||||
aquery(query: string): Promise<Response>;
|
||||
}
|
||||
|
||||
export class RetrieverQueryEngine {
|
||||
export class RetrieverQueryEngine implements BaseQueryEngine {
|
||||
retriever: BaseRetriever;
|
||||
responseSynthesizer: ResponseSynthesizer;
|
||||
|
||||
@@ -20,3 +28,77 @@ export class RetrieverQueryEngine {
|
||||
return this.responseSynthesizer.asynthesize(query, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Query engine that splits a query into sub-questions, answers each one with
 * the query engine of the tool it names, and synthesizes a final response
 * from the sub-answers.
 */
export class SubQuestionQueryEngine implements BaseQueryEngine {
  responseSynthesizer: ResponseSynthesizer;
  questionGen: BaseQuestionGenerator;
  /** Query engines keyed by tool name. */
  queryEngines: Record<string, BaseQueryEngine>;
  /** Metadata of every registered tool, in the order the tools were given. */
  metadatas: ToolMetadata[];

  /**
   * @param init.questionGen Generator that produces the sub-questions.
   * @param init.responseSynthesizer Synthesizer for the final answer; a
   *   default ResponseSynthesizer is created when it is null or undefined.
   * @param init.queryEngineTools Tools to route sub-questions to. When two
   *   tools share a name, the later one's query engine is used.
   */
  constructor(init: {
    questionGen: BaseQuestionGenerator;
    responseSynthesizer: ResponseSynthesizer;
    queryEngineTools: QueryEngineTool[];
  }) {
    this.questionGen = init.questionGen;
    this.responseSynthesizer =
      init.responseSynthesizer ?? new ResponseSynthesizer();
    this.queryEngines = init.queryEngineTools.reduce<
      Record<string, BaseQueryEngine>
    >((acc, tool) => {
      acc[tool.metadata.name] = tool.queryEngine;
      return acc;
    }, {});
    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
  }

  /**
   * Builds an engine, filling in whatever is not supplied: an
   * LLMQuestionGenerator for `questionGen`, and a ResponseSynthesizer backed
   * by CompactAndRefine (using `serviceContext`, itself defaulted) for
   * `responseSynthesizer`.
   */
  static fromDefaults(init: {
    queryEngineTools: QueryEngineTool[];
    questionGen?: BaseQuestionGenerator;
    responseSynthesizer?: ResponseSynthesizer;
    serviceContext?: ServiceContext;
  }) {
    const serviceContext =
      init.serviceContext ?? serviceContextFromDefaults({});

    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
    const responseSynthesizer =
      init.responseSynthesizer ??
      new ResponseSynthesizer(new CompactAndRefine(serviceContext));

    return new SubQuestionQueryEngine({
      questionGen,
      responseSynthesizer,
      queryEngineTools: init.queryEngineTools,
    });
  }

  /**
   * Answers `query` by generating sub-questions, running them all
   * concurrently, discarding the ones that failed, and synthesizing a
   * response from the remaining sub-answers.
   */
  async aquery(query: string): Promise<Response> {
    const subQuestions = await this.questionGen.agenerate(
      this.metadatas,
      query
    );
    const subQNodes = await Promise.all(
      subQuestions.map((subQ) => this.aquerySubQ(subQ))
    );
    const nodes = subQNodes.filter(
      (node): node is NodeWithScore => node !== null
    );
    return this.responseSynthesizer.asynthesize(query, nodes);
  }

  /**
   * Runs one sub-question against the query engine of the tool it names.
   *
   * @returns A node holding the sub-question and its answer (score 0), or
   *   `null` when the tool is unknown or the query fails. Failures are
   *   deliberately swallowed so one bad sub-question does not abort the
   *   whole query.
   */
  private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> {
    try {
      const question = subQ.subQuestion;
      const queryEngine = this.queryEngines[subQ.toolName];
      if (!queryEngine) {
        // The generator named a tool that was never registered.
        return null;
      }

      const response = await queryEngine.aquery(question);
      const responseText = response.response;
      const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
      const node = new TextNode({ text: nodeText });
      return { node, score: 0 };
    } catch (error) {
      return null;
    }
  }
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
|
||||
import {
|
||||
BaseOutputParser,
|
||||
StructuredOutput,
|
||||
SubQuestionOutputParser,
|
||||
} from "./OutputParser";
|
||||
import {
|
||||
SimplePrompt,
|
||||
buildToolsText,
|
||||
defaultSubQuestionPrompt,
|
||||
} from "./Prompt";
|
||||
import { ToolMetadata } from "./Tool";
|
||||
|
||||
/**
 * One sub-question together with the name of the tool meant to answer it.
 */
export interface SubQuestion {
  /** Text of the sub-question. */
  subQuestion: string;

  /** Name of the tool the sub-question is addressed to. */
  toolName: string;
}
|
||||
|
||||
/**
 * Contract for producing sub-questions from a query and a set of tools.
 */
export interface BaseQuestionGenerator {
  /** Resolves to the sub-questions generated for `query` given `tools`. */
  agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
}
|
||||
|
||||
/**
 * Question generator that prompts an LLM predictor and parses its reply into
 * sub-questions.
 */
export class LLMQuestionGenerator implements BaseQuestionGenerator {
  llmPredictor: BaseLLMPredictor;
  prompt: SimplePrompt;
  outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

  /**
   * Any field missing from `init` gets a default: ChatGPTLLMPredictor,
   * defaultSubQuestionPrompt and SubQuestionOutputParser respectively.
   */
  constructor(init?: Partial<LLMQuestionGenerator>) {
    this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor();
    this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
    this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
  }

  /**
   * Fills the prompt with the tools (rendered by buildToolsText) and the
   * query, sends it to the predictor and returns the parsed sub-questions.
   */
  async agenerate(
    tools: ToolMetadata[],
    query: string
  ): Promise<SubQuestion[]> {
    const promptInput = {
      toolsStr: buildToolsText(tools),
      queryStr: query,
    };

    const rawPrediction = await this.llmPredictor.apredict(
      this.prompt,
      promptInput
    );

    const { parsedOutput } = this.outputParser.parse(rawPrediction);
    return parsedOutput;
  }
}
|
||||
@@ -1,10 +1,10 @@
|
||||
import { TextNode } from "./Node";
|
||||
import { BaseNode } from "./Node";
|
||||
|
||||
export class Response {
|
||||
response?: string;
|
||||
sourceNodes: TextNode[];
|
||||
sourceNodes: BaseNode[];
|
||||
|
||||
constructor(response?: string, sourceNodes?: TextNode[]) {
|
||||
constructor(response?: string, sourceNodes?: BaseNode[]) {
|
||||
this.response = response;
|
||||
this.sourceNodes = sourceNodes || [];
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { ChatGPTLLMPredictor } from "./LLMPredictor";
|
||||
import { NodeWithScore } from "./Node";
|
||||
import { MetadataMode, NodeWithScore } from "./Node";
|
||||
import {
|
||||
SimplePrompt,
|
||||
defaultRefinePrompt,
|
||||
@@ -182,15 +182,18 @@ export function getResponseBuilder(): BaseResponseBuilder {
|
||||
return new SimpleResponseBuilder();
|
||||
}
|
||||
|
||||
// TODO replace with Logan's new response_synthesizers/factory.py
|
||||
export class ResponseSynthesizer {
|
||||
responseBuilder: BaseResponseBuilder;
|
||||
|
||||
constructor() {
|
||||
this.responseBuilder = getResponseBuilder();
|
||||
constructor(responseBuilder?: BaseResponseBuilder) {
|
||||
this.responseBuilder = responseBuilder ?? getResponseBuilder();
|
||||
}
|
||||
|
||||
async asynthesize(query: string, nodes: NodeWithScore[]) {
|
||||
let textChunks: string[] = nodes.map((node) => node.node.text);
|
||||
let textChunks: string[] = nodes.map((node) =>
|
||||
node.node.getContent(MetadataMode.NONE)
|
||||
);
|
||||
const response = await this.responseBuilder.agetResponse(query, textChunks);
|
||||
return new Response(
|
||||
response,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import { VectorStoreIndex } from "./BaseIndex";
|
||||
import { BaseEmbedding, getTopKEmbeddings } from "./Embedding";
|
||||
import { NodeWithScore } from "./Node";
|
||||
import { ServiceContext } from "./ServiceContext";
|
||||
import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
|
||||
import {
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryMode,
|
||||
} from "./storage/vectorStore/types";
|
||||
|
||||
export interface BaseRetriever {
|
||||
aretrieve(query: string): Promise<any>;
|
||||
@@ -10,31 +14,30 @@ export interface BaseRetriever {
|
||||
export class VectorIndexRetriever implements BaseRetriever {
|
||||
index: VectorStoreIndex;
|
||||
similarityTopK = DEFAULT_SIMILARITY_TOP_K;
|
||||
embeddingService: BaseEmbedding;
|
||||
private serviceContext: ServiceContext;
|
||||
|
||||
constructor(index: VectorStoreIndex, embeddingService: BaseEmbedding) {
|
||||
constructor(index: VectorStoreIndex) {
|
||||
this.index = index;
|
||||
this.embeddingService = embeddingService;
|
||||
this.serviceContext = this.index.serviceContext;
|
||||
}
|
||||
|
||||
async aretrieve(query: string): Promise<NodeWithScore[]> {
|
||||
const queryEmbedding = await this.embeddingService.aGetQueryEmbedding(
|
||||
query
|
||||
);
|
||||
const [similarities, ids] = getTopKEmbeddings(
|
||||
queryEmbedding,
|
||||
this.index.nodes.map((node) => node.getEmbedding()),
|
||||
undefined,
|
||||
this.index.nodes.map((node) => node.id_)
|
||||
);
|
||||
const queryEmbedding =
|
||||
await this.serviceContext.embedModel.aGetQueryEmbedding(query);
|
||||
|
||||
const q: VectorStoreQuery = {
|
||||
queryEmbedding: queryEmbedding,
|
||||
mode: VectorStoreQueryMode.DEFAULT,
|
||||
similarityTopK: this.similarityTopK,
|
||||
};
|
||||
const result = this.index.vectorStore.query(q);
|
||||
|
||||
let nodesWithScores: NodeWithScore[] = [];
|
||||
|
||||
for (let i = 0; i < ids.length; i++) {
|
||||
const node = this.index.indexStruct.nodesDict[ids[i]];
|
||||
for (let i = 0; i < result.ids.length; i++) {
|
||||
const node = this.index.indexStruct.nodesDict[result.ids[i]];
|
||||
nodesWithScores.push({
|
||||
node: node,
|
||||
score: similarities[i],
|
||||
score: result.similarities[i],
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -21,15 +21,15 @@ export interface ServiceContextOptions {
|
||||
nodeParser?: NodeParser;
|
||||
// NodeParser arguments
|
||||
chunkSize?: number;
|
||||
chunkOverlap: number;
|
||||
chunkOverlap?: number;
|
||||
}
|
||||
|
||||
export function serviceContextFromDefaults(options: ServiceContextOptions) {
|
||||
export function serviceContextFromDefaults(options?: ServiceContextOptions) {
|
||||
const serviceContext: ServiceContext = {
|
||||
llmPredictor: options.llmPredictor ?? new ChatGPTLLMPredictor(),
|
||||
embedModel: options.embedModel ?? new OpenAIEmbedding(),
|
||||
nodeParser: options.nodeParser ?? new SimpleNodeParser(),
|
||||
promptHelper: options.promptHelper ?? new PromptHelper(),
|
||||
llmPredictor: options?.llmPredictor ?? new ChatGPTLLMPredictor(),
|
||||
embedModel: options?.embedModel ?? new OpenAIEmbedding(),
|
||||
nodeParser: options?.nodeParser ?? new SimpleNodeParser(),
|
||||
promptHelper: options?.promptHelper ?? new PromptHelper(),
|
||||
};
|
||||
|
||||
return serviceContext;
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { BaseQueryEngine } from "./QueryEngine";
|
||||
|
||||
/**
 * Descriptive metadata attached to a tool (see `BaseTool.metadata`).
 * Both fields are plain strings; presumably `description` is the text used to
 * explain the tool's purpose — confirm against the code that consumes it.
 */
export interface ToolMetadata {
|
||||
description: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
/**
 * Minimal shape shared by all tools: a tool only has to expose its metadata.
 */
export interface BaseTool {
|
||||
metadata: ToolMetadata;
|
||||
}
|
||||
|
||||
/**
 * A tool that additionally carries a query engine alongside its metadata.
 */
export interface QueryEngineTool extends BaseTool {
|
||||
queryEngine: BaseQueryEngine;
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
import _ from "lodash";
|
||||
import { VectorStoreQueryMode } from "./storage/vectorStore/types";
|
||||
|
||||
export function getTopKEmbeddings(
|
||||
queryEmbedding: number[],
|
||||
embeddings: number[][],
|
||||
similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
|
||||
similarityTopK?: number,
|
||||
embeddingIds?: number[],
|
||||
similarityCutoff?: number
|
||||
): [number[], number[]] {
|
||||
throw new Error("Not implemented");
|
||||
}
|
||||
|
||||
export function getTopKEmbeddingsLearner(
|
||||
queryEmbedding: number[],
|
||||
embeddings: number[][],
|
||||
similarityTopK?: number,
|
||||
embeddingIds?: number[],
|
||||
queryMode: VectorStoreQueryMode = VectorStoreQueryMode.SVM
|
||||
): [number[], number[]] {
|
||||
throw new Error("Not implemented");
|
||||
}
|
||||
|
||||
export function getTopKMMREmbeddings(
|
||||
queryEmbedding: number[],
|
||||
embeddings: number[][],
|
||||
similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
|
||||
similarityTopK?: number,
|
||||
embeddingIds?: number[],
|
||||
similarityCutoff?: number,
|
||||
mmrThreshold?: number
|
||||
): [number[], number[]] {
|
||||
throw new Error("Not implemented");
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
import { BaseNode, Document } from "../../Node";
|
||||
import { BaseIndex, BaseIndexInit, IndexList } from "../../BaseIndex";
|
||||
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
|
||||
import {
|
||||
StorageContext,
|
||||
storageContextFromDefaults,
|
||||
} from "../../storage/StorageContext";
|
||||
import { BaseRetriever } from "../../Retriever";
|
||||
import { ListIndexRetriever } from "./ListIndexRetriever";
|
||||
import {
|
||||
ServiceContext,
|
||||
serviceContextFromDefaults,
|
||||
} from "../../ServiceContext";
|
||||
import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
|
||||
import _ from "lodash";
|
||||
|
||||
/**
 * Retrieval strategies accepted by `ListIndex.asRetriever`.
 */
export enum ListRetrieverMode {
|
||||
// Served by ListIndexRetriever in ListIndex.asRetriever.
DEFAULT = "default",
|
||||
// EMBEDDING = "embedding",
|
||||
// ListIndex.asRetriever currently throws "not implemented" for this mode.
LLM = "llm",
|
||||
}
|
||||
|
||||
/**
 * Options accepted by `ListIndex.init`.
 *
 * Exactly one of `nodes` or `indexStruct` must be supplied: `init` throws when
 * both or neither are given. `serviceContext` and `storageContext` fall back
 * to the library defaults when omitted.
 */
export interface ListIndexOptions {
|
||||
nodes?: BaseNode[];
|
||||
indexStruct?: IndexList;
|
||||
serviceContext?: ServiceContext;
|
||||
storageContext?: StorageContext;
|
||||
}
|
||||
|
||||
/**
 * An index that keeps the ids of its nodes in an ordered list (`IndexList`)
 * and stores the nodes themselves in the storage context's document store.
 */
export class ListIndex extends BaseIndex<IndexList> {
  constructor(init: BaseIndexInit<IndexList>) {
    super(init);
  }

  /**
   * Builds a ListIndex either from a list of nodes or from an existing index
   * struct.
   *
   * @param options - see `ListIndexOptions`; missing storage/service contexts
   *   are replaced by defaults.
   * @returns the constructed index.
   * @throws Error when both `nodes` and `indexStruct` are given, or neither.
   */
  static async init(options: ListIndexOptions): Promise<ListIndex> {
    const storageContext =
      options.storageContext ?? (await storageContextFromDefaults({}));
    const serviceContext =
      options.serviceContext ?? serviceContextFromDefaults({});
    const { docStore, indexStore } = storageContext;

    let indexStruct: IndexList;
    if (options.indexStruct) {
      if (options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex with both nodes and indexStruct"
        );
      }
      indexStruct = options.indexStruct;
    } else {
      if (!options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex without nodes or indexStruct"
        );
      }
      indexStruct = ListIndex._buildIndexFromNodes(options.nodes, docStore);
    }

    return new ListIndex({
      storageContext,
      serviceContext,
      docStore,
      indexStore,
      indexStruct,
    });
  }

  /**
   * Stores the documents, splits them into nodes with the service context's
   * node parser and builds an index over those nodes.
   *
   * @param documents - source documents to index.
   * @param storageContext - optional storage context; defaults are used when omitted.
   * @param serviceContext - optional service context; defaults are used when omitted.
   * @returns the index built over the parsed nodes.
   */
  static async fromDocuments(
    documents: Document[],
    storageContext?: StorageContext,
    serviceContext?: ServiceContext
  ): Promise<ListIndex> {
    storageContext = storageContext ?? (await storageContextFromDefaults({}));
    serviceContext = serviceContext ?? serviceContextFromDefaults({});
    const docStore = storageContext.docStore;

    // Await the store writes so a failure rejects this call instead of being
    // lost. Awaiting is a no-op if the store implementation is synchronous.
    await docStore.addDocuments(documents, true);
    for (const doc of documents) {
      await docStore.setDocumentHash(doc.id_, doc.hash);
    }

    const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
    return ListIndex.init({
      nodes,
      storageContext,
      serviceContext,
    });
  }

  /**
   * Returns a retriever for the requested mode.
   *
   * @throws Error for `ListRetrieverMode.LLM` (not implemented here) and for
   *   any unknown mode.
   */
  asRetriever(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseRetriever {
    switch (mode) {
      case ListRetrieverMode.DEFAULT:
        return new ListIndexRetriever(this);
      case ListRetrieverMode.LLM:
        throw new Error(`Support for LLM retriever mode is not implemented`);
      default:
        throw new Error(`Unknown retriever mode: ${mode}`);
    }
  }

  /**
   * Returns a query engine backed by the retriever for `mode`.
   *
   * The mode is forwarded to `asRetriever`, so a mode that `asRetriever`
   * rejects is rejected here as well.
   */
  asQueryEngine(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseQueryEngine {
    return new RetrieverQueryEngine(this.asRetriever(mode));
  }

  /**
   * Adds the nodes to the document store and records them in the index struct.
   *
   * @param nodes - nodes to index.
   * @param docStore - store that receives the nodes.
   * @param indexStruct - struct to extend; a fresh `IndexList` is created when omitted.
   * @returns the (possibly newly created) index struct.
   */
  static _buildIndexFromNodes(
    nodes: BaseNode[],
    docStore: BaseDocumentStore,
    indexStruct?: IndexList
  ): IndexList {
    indexStruct = indexStruct || new IndexList();

    // NOTE(review): the result of addDocuments is not awaited because this
    // method is synchronous; if the store is asynchronous, callers cannot
    // observe a failed write here — confirm and consider an async variant.
    docStore.addDocuments(nodes, true);
    for (const node of nodes) {
      indexStruct.addNode(node);
    }

    return indexStruct;
  }

  /** Records the given nodes in the index struct. */
  protected _insert(nodes: BaseNode[]): void {
    for (const node of nodes) {
      this.indexStruct.addNode(node);
    }
  }

  /** Removes `nodeId` from the index struct's node id list. */
  protected _deleteNode(nodeId: string): void {
    this.indexStruct.nodes = this.indexStruct.nodes.filter(
      (existingNodeId: string) => existingNodeId !== nodeId
    );
  }

  /**
   * Collects the ref-doc info of every source document referenced by the
   * indexed nodes.
   *
   * @returns a map keyed by the source node id; nodes without a source node,
   *   or whose source has no ref-doc info in the store, are skipped.
   */
  async getRefDocInfo(): Promise<Record<string, RefDocInfo>> {
    const nodeDocIds = this.indexStruct.nodes;
    const nodes = await this.docStore.getNodes(nodeDocIds);

    const refDocInfoMap: Record<string, RefDocInfo> = {};

    for (const node of nodes) {
      const refNode = node.sourceNode;
      if (_.isNil(refNode)) {
        continue;
      }

      const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId);
      if (_.isNil(refDocInfo)) {
        continue;
      }

      refDocInfoMap[refNode.nodeId] = refDocInfo;
    }

    return refDocInfoMap;
  }
}
|
||||
|
||||
// Legacy
|
||||
export type GPTListIndex = ListIndex;
|
||||
@@ -0,0 +1,96 @@
|
||||
import { BaseRetriever } from "../../Retriever";
|
||||
import { NodeWithScore } from "../../Node";
|
||||
import { ListIndex } from "./ListIndex";
|
||||
import { ServiceContext } from "../../ServiceContext";
|
||||
import {
|
||||
NodeFormatterFunction,
|
||||
ChoiceSelectParserFunction,
|
||||
defaultFormatNodeBatchFn,
|
||||
defaultParseChoiceSelectAnswerFn,
|
||||
} from "./utils";
|
||||
import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
|
||||
import _ from "lodash";
|
||||
|
||||
/**
 * Default retriever for a ListIndex: the query text is not consulted, and
 * every node listed in the index is returned with a fixed score of 1.
 */
export class ListIndexRetriever implements BaseRetriever {
  index: ListIndex;

  constructor(index: ListIndex) {
    this.index = index;
  }

  /**
   * Loads all indexed nodes from the document store.
   *
   * @param query - unused by this retriever.
   * @returns one entry per stored node, each scored 1.
   */
  async aretrieve(query: string): Promise<NodeWithScore[]> {
    const allNodes = await this.index.docStore.getNodes(
      this.index.indexStruct.nodes
    );

    const scored: NodeWithScore[] = [];
    for (const node of allNodes) {
      scored.push({ node: node, score: 1 });
    }
    return scored;
  }
}
|
||||
|
||||
/**
 * LLM retriever for ListIndex.
 *
 * Walks the index in batches of `choiceBatchSize` nodes, formats each batch
 * into a prompt, asks the LLM predictor which documents are relevant and
 * returns the chosen nodes together with the parsed relevance scores.
 */
export class ListIndexLLMRetriever implements BaseRetriever {
  index: ListIndex;
  choiceSelectPrompt: SimplePrompt;
  choiceBatchSize: number;
  formatNodeBatchFn: NodeFormatterFunction;
  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
  serviceContext: ServiceContext;

  /**
   * @param index - index whose nodes are considered.
   * @param choiceSelectPrompt - prompt used to ask for relevant documents; defaults to `defaultChoiceSelectPrompt`.
   * @param choiceBatchSize - number of nodes sent to the LLM per call.
   * @param formatNodeBatchFn - renders a batch of nodes into prompt context; defaults to `defaultFormatNodeBatchFn`.
   * @param parseChoiceSelectAnswerFn - parses the LLM answer; defaults to `defaultParseChoiceSelectAnswerFn`.
   * @param serviceContext - defaults to the index's service context.
   */
  constructor(
    index: ListIndex,
    choiceSelectPrompt?: SimplePrompt,
    choiceBatchSize: number = 10,
    formatNodeBatchFn?: NodeFormatterFunction,
    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
    serviceContext?: ServiceContext
  ) {
    this.index = index;
    this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt;
    this.choiceBatchSize = choiceBatchSize;
    this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn;
    this.parseChoiceSelectAnswerFn =
      parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn;
    this.serviceContext = serviceContext || index.serviceContext;
  }

  /**
   * Retrieves the nodes the LLM selects as relevant to `query`.
   *
   * @param query - the user query inserted into the choice-select prompt.
   * @returns the selected nodes in index order, each with the relevance
   *   score parsed from the LLM answer.
   */
  async aretrieve(query: string): Promise<NodeWithScore[]> {
    const nodeIds = this.index.indexStruct.nodes;
    const results: NodeWithScore[] = [];

    for (
      let batchStart = 0;
      batchStart < nodeIds.length;
      batchStart += this.choiceBatchSize
    ) {
      const nodeIdsBatch = nodeIds.slice(
        batchStart,
        batchStart + this.choiceBatchSize
      );
      const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);

      const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
      const input = { context: fmtBatchStr, query: query };
      const rawResponse = await this.serviceContext.llmPredictor.apredict(
        this.choiceSelectPrompt,
        input
      );

      // parseResult is a map from doc number to relevance score
      const parseResult = this.parseChoiceSelectAnswerFn(
        rawResponse,
        nodesBatch.length
      );

      // Document numbers are 1-based positions within the batch handed to the
      // formatter, so the node at position p answers to document number p + 1.
      nodesBatch.forEach((node, position) => {
        const relevance = parseResult[position + 1];
        if (relevance !== undefined) {
          results.push({ node: node, score: relevance });
        }
      });
    }

    return results;
  }
}
|
||||
@@ -0,0 +1,5 @@
|
||||
export { ListIndex, ListRetrieverMode } from "./ListIndex";
|
||||
export {
|
||||
ListIndexRetriever,
|
||||
ListIndexLLMRetriever,
|
||||
} from "./ListIndexRetriever";
|
||||
@@ -0,0 +1,73 @@
|
||||
import { BaseNode, MetadataMode } from "../../Node";
|
||||
import _ from "lodash";
|
||||
|
||||
/** Signature of a function that renders a batch of nodes into a single string. */
export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;

/**
 * Default batch formatter.
 *
 * Each node becomes a "Document <n>:" header (n counted from 1) followed by
 * the node's content in `MetadataMode.LLM`; sections are separated by a blank
 * line.
 */
export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
  summaryNodes: BaseNode[]
): string => {
  const sections: string[] = [];

  summaryNodes.forEach((node, idx) => {
    const section = `
Document ${idx + 1}:
${node.getContent(MetadataMode.LLM)}
`;
    sections.push(section.trim());
  });

  return sections.join("\n\n");
};
|
||||
|
||||
// Map from 1-based document number to its relevance score.
export type ChoiceSelectParseResult = { [docNumber: number]: number };

/**
 * Signature of a parser for the LLM's choice-select answer.
 * `numChoices` is the number of documents that were offered; `raiseErr`
 * selects between throwing on bad input and silently skipping it.
 */
export type ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr?: boolean
) => ChoiceSelectParseResult;

/** Returns the trimmed text between the first and second ":" of `token`, if any. */
const valueAfterColon = (token: string | undefined): string | undefined =>
  token?.split(":")[1]?.trim();

/**
 * Default parser for choice-select answers.
 *
 * Every line is expected to look like `<label>: <int>, <label>: <float>`,
 * i.e. two comma-separated `key: value` fields holding the document number
 * and its relevance.
 *
 * @param answer - raw LLM answer, one choice per line.
 * @param numChoices - number of documents offered; valid document numbers are 1..numChoices.
 * @param raiseErr - when true, malformed lines and out-of-range document
 *   numbers throw; when false (default) they are skipped.
 * @returns a map from document number to relevance for every valid line.
 * @throws Error on the first invalid line when `raiseErr` is true.
 */
export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr: boolean = false
): ChoiceSelectParseResult => {
  const parseResult: ChoiceSelectParseResult = {};

  for (const line of answer.split("\n")) {
    const invalidLineMessage = `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`;

    // split the line into the answer number and relevance score portions
    const lineTokens = line.split(",");
    if (lineTokens.length !== 2) {
      if (raiseErr) {
        throw new Error(invalidLineMessage);
      }
      continue;
    }

    // parse the answer number and relevance score
    const docNum = Number.parseInt(valueAfterColon(lineTokens[0]) ?? "", 10);
    const answerRelevance = Number.parseFloat(
      valueAfterColon(lineTokens[1]) ?? ""
    );
    if (Number.isNaN(docNum) || Number.isNaN(answerRelevance)) {
      // A field without ":" or with a non-numeric value is a malformed line.
      if (raiseErr) {
        throw new Error(invalidLineMessage);
      }
      continue;
    }

    if (docNum < 1 || docNum > numChoices) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
        );
      }
      continue;
    }

    // Only in-range document numbers are recorded.
    parseResult[docNum] = answerRelevance;
  }

  return parseResult;
};
|
||||
@@ -12,9 +12,9 @@ import {
|
||||
} from "./constants";
|
||||
|
||||
export interface StorageContext {
|
||||
docStore?: BaseDocumentStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
vectorStore?: VectorStore;
|
||||
docStore: BaseDocumentStore;
|
||||
indexStore: BaseIndexStore;
|
||||
vectorStore: VectorStore;
|
||||
}
|
||||
|
||||
type BuilderParams = {
|
||||
@@ -32,21 +32,24 @@ export async function storageContextFromDefaults({
|
||||
persistDir,
|
||||
fs,
|
||||
}: BuilderParams): Promise<StorageContext> {
|
||||
persistDir = persistDir || DEFAULT_PERSIST_DIR;
|
||||
|
||||
fs = fs || DEFAULT_FS;
|
||||
|
||||
docStore =
|
||||
docStore ||
|
||||
(await SimpleDocumentStore.fromPersistDir(
|
||||
persistDir,
|
||||
DEFAULT_NAMESPACE,
|
||||
fs
|
||||
));
|
||||
indexStore =
|
||||
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
|
||||
vectorStore =
|
||||
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
|
||||
if (!persistDir) {
|
||||
docStore = docStore || new SimpleDocumentStore();
|
||||
indexStore = indexStore || new SimpleIndexStore();
|
||||
vectorStore = vectorStore || new SimpleVectorStore();
|
||||
} else {
|
||||
fs = fs || DEFAULT_FS;
|
||||
docStore =
|
||||
docStore ||
|
||||
(await SimpleDocumentStore.fromPersistDir(
|
||||
persistDir,
|
||||
DEFAULT_NAMESPACE,
|
||||
fs
|
||||
));
|
||||
indexStore =
|
||||
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
|
||||
vectorStore =
|
||||
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
|
||||
}
|
||||
|
||||
return {
|
||||
docStore,
|
||||
|
||||
@@ -77,7 +77,7 @@ export class KVDocumentStore extends BaseDocumentStore {
|
||||
let json = await this.kvstore.get(docId, this.nodeCollection);
|
||||
if (_.isNil(json)) {
|
||||
if (raiseError) {
|
||||
throw new Error(`doc_id ${docId} not found.`);
|
||||
throw new Error(`docId ${docId} not found.`);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -23,12 +23,11 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode {
|
||||
hash: dataDict.hash,
|
||||
});
|
||||
} else if (docType === ObjectType.TEXT) {
|
||||
const relationships = dataDict.relationships;
|
||||
console.log({ dataDict });
|
||||
doc = new TextNode({
|
||||
text: relationships.text,
|
||||
id_: relationships.id_,
|
||||
embedding: relationships.embedding,
|
||||
hash: relationships.hash,
|
||||
text: dataDict.text,
|
||||
id_: dataDict.id_,
|
||||
hash: dataDict.hash,
|
||||
});
|
||||
} else {
|
||||
throw new Error(`Unknown doc type: ${docType}`);
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
getTopKMMREmbeddings,
|
||||
} from "../../Embedding";
|
||||
import { DEFAULT_PERSIST_DIR, DEFAULT_FS } from "../constants";
|
||||
import { TextNode } from "../../Node";
|
||||
import { NodeWithEmbedding } from "../../Node";
|
||||
|
||||
const LEARNER_MODES = new Set<VectorStoreQueryMode>([
|
||||
VectorStoreQueryMode.SVM,
|
||||
@@ -53,18 +53,19 @@ export class SimpleVectorStore implements VectorStore {
|
||||
return this.data.embeddingDict[textId];
|
||||
}
|
||||
|
||||
add(embeddingResults: TextNode[]): string[] {
|
||||
add(embeddingResults: NodeWithEmbedding[]): string[] {
|
||||
for (let result of embeddingResults) {
|
||||
this.data.embeddingDict[result.id_] = result.getEmbedding();
|
||||
this.data.embeddingDict[result.node.id_] = result.embedding;
|
||||
|
||||
if (!result.sourceNode) {
|
||||
if (!result.node.sourceNode) {
|
||||
console.error("Missing source node from TextNode.");
|
||||
continue;
|
||||
}
|
||||
|
||||
this.data.textIdToRefDocId[result.id_] = result.sourceNode?.nodeId;
|
||||
this.data.textIdToRefDocId[result.node.id_] =
|
||||
result.node.sourceNode?.nodeId;
|
||||
}
|
||||
return embeddingResults.map((result) => result.id_);
|
||||
return embeddingResults.map((result) => result.node.id_);
|
||||
}
|
||||
|
||||
delete(refDocId: string): void {
|
||||
|
||||
@@ -1,18 +1,11 @@
|
||||
import { TextNode } from "../../Node";
|
||||
import { BaseNode } from "../../Node";
|
||||
import { GenericFileSystem } from "../FileSystem";
|
||||
|
||||
export interface NodeWithEmbedding {
|
||||
node: TextNode;
|
||||
embedding: number[];
|
||||
|
||||
id(): string;
|
||||
refDocId(): string;
|
||||
}
|
||||
import { NodeWithEmbedding } from "../../Node";
|
||||
|
||||
export interface VectorStoreQueryResult {
|
||||
nodes?: TextNode[];
|
||||
similarities?: number[];
|
||||
ids?: string[];
|
||||
nodes?: BaseNode[];
|
||||
similarities: number[];
|
||||
ids: string[];
|
||||
}
|
||||
|
||||
export enum VectorStoreQueryMode {
|
||||
@@ -68,7 +61,7 @@ export interface VectorStore {
|
||||
storesText: boolean;
|
||||
isEmbeddingQuery?: boolean;
|
||||
client(): any;
|
||||
add(embeddingResults: TextNode[]): string[];
|
||||
add(embeddingResults: NodeWithEmbedding[]): string[];
|
||||
delete(refDocId: string, deleteKwargs?: any): void;
|
||||
query(query: VectorStoreQuery, kwargs?: any): VectorStoreQueryResult;
|
||||
persist(persistPath: string, fs?: GenericFileSystem): void;
|
||||
|
||||
Generated
+22
-22
@@ -36,8 +36,8 @@ importers:
|
||||
specifier: ^29.1.0
|
||||
version: 29.1.0(@babel/core@7.22.5)(jest@29.5.0)(typescript@4.9.5)
|
||||
turbo:
|
||||
specifier: ^1.10.5
|
||||
version: 1.10.5
|
||||
specifier: ^1.10.7
|
||||
version: 1.10.7
|
||||
|
||||
apps/docs:
|
||||
dependencies:
|
||||
@@ -5444,65 +5444,65 @@ packages:
|
||||
resolution: {integrity: sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==}
|
||||
dev: true
|
||||
|
||||
/turbo-darwin-64@1.10.5:
|
||||
resolution: {integrity: sha512-fIHu+fcW7upaZEfeneoRbZjdrcsj/NxUg7IjZZmlCjgbS9Ofl8RhRid5A1L31AUK3kkqRxzagHc4WZ5x4quBgg==}
|
||||
/turbo-darwin-64@1.10.7:
|
||||
resolution: {integrity: sha512-N2MNuhwrl6g7vGuz4y3fFG2aR1oCs0UZ5HKl8KSTn/VC2y2YIuLGedQ3OVbo0TfEvygAlF3QGAAKKtOCmGPNKA==}
|
||||
cpu: [x64]
|
||||
os: [darwin]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-darwin-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-uv0sDWizuxVvdSjaKvWdPdX4aZ8IZeYJwTJRZwLNRxZV56/1LZD65gyQIqsSNVRHuXI199yahmB+7PMJNpZFdw==}
|
||||
/turbo-darwin-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-WbJkvjU+6qkngp7K4EsswOriO3xrNQag7YEGRtfLoDdMTk4O4QTeU6sfg2dKfDsBpTidTvEDwgIYJhYVGzrz9Q==}
|
||||
cpu: [arm64]
|
||||
os: [darwin]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-linux-64@1.10.5:
|
||||
resolution: {integrity: sha512-hI0rErgwxNmuBCNGldhJkjSbb+mT+vjfmBVKcMI/bnBmu/KU7irCrKMe5Vas280teqBrC33GgVfXndJo2cJ1DA==}
|
||||
/turbo-linux-64@1.10.7:
|
||||
resolution: {integrity: sha512-x1CF2CDP1pDz/J8/B2T0hnmmOQI2+y11JGIzNP0KtwxDM7rmeg3DDTtDM/9PwGqfPotN9iVGgMiMvBuMFbsLhg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-linux-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-JAygWZjTuD6e7w0KSGzy7UxYqeLIpGfZDne+4MGRc8I5VeWZ6i0HWTqhhIu2/A8AuklYcoj8LkOZxCnMOF3odQ==}
|
||||
/turbo-linux-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-JtnBmaBSYbs7peJPkXzXxsRGSGBmBEIb6/kC8RRmyvPAMyqF8wIex0pttsI+9plghREiGPtRWv/lfQEPRlXnNQ==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-windows-64@1.10.5:
|
||||
resolution: {integrity: sha512-6w2GOKmlWEAl6QkC4c2j2ZLTwB+RK6oIDRT2KqF1m07KkY6pebEzbPZLHuP08QV+SE0t+prAn+kn7hkHYkwM+Q==}
|
||||
/turbo-windows-64@1.10.7:
|
||||
resolution: {integrity: sha512-7A/4CByoHdolWS8dg3DPm99owfu1aY/W0V0+KxFd0o2JQMTQtoBgIMSvZesXaWM57z3OLsietFivDLQPuzE75w==}
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-windows-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-3eeHRJPU+5zWa/iiikoBoPlNd74Y+L9lrG6ZsDZdzUYxNRTMrZbto1Bu1UF77t10TXeT9BsZRXjquKqrA7R7tg==}
|
||||
/turbo-windows-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-D36K/3b6+hqm9IBAymnuVgyePktwQ+F0lSXr2B9JfAdFPBktSqGmp50JNC7pahxhnuCLj0Vdpe9RqfnJw5zATA==}
|
||||
cpu: [arm64]
|
||||
os: [win32]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo@1.10.5:
|
||||
resolution: {integrity: sha512-4yxHTrlugJhef4eXuyrPJtrgUZWlbcwmSb8iZL/5UzNjCmx+anOm1nfW2XFrZFKy4v0+/fUlqw8LkTgGVsOKaQ==}
|
||||
/turbo@1.10.7:
|
||||
resolution: {integrity: sha512-xm0MPM28TWx1e6TNC3wokfE5eaDqlfi0G24kmeHupDUZt5Wd0OzHFENEHMPqEaNKJ0I+AMObL6nbSZonZBV2HA==}
|
||||
hasBin: true
|
||||
requiresBuild: true
|
||||
optionalDependencies:
|
||||
turbo-darwin-64: 1.10.5
|
||||
turbo-darwin-arm64: 1.10.5
|
||||
turbo-linux-64: 1.10.5
|
||||
turbo-linux-arm64: 1.10.5
|
||||
turbo-windows-64: 1.10.5
|
||||
turbo-windows-arm64: 1.10.5
|
||||
turbo-darwin-64: 1.10.7
|
||||
turbo-darwin-arm64: 1.10.7
|
||||
turbo-linux-64: 1.10.7
|
||||
turbo-linux-arm64: 1.10.7
|
||||
turbo-windows-64: 1.10.7
|
||||
turbo-windows-arm64: 1.10.7
|
||||
dev: true
|
||||
|
||||
/type-check@0.4.0:
|
||||
|
||||
Reference in New Issue
Block a user