mirror of
https://github.com/run-llama/LlamaIndexTS.git
synced 2026-07-01 22:14:03 -04:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0bec460937 | |||
| 5d8d344e1f | |||
| 81e22587eb | |||
| 539ec0fe3d | |||
| ebf3bc19fd | |||
| 204b8f5316 | |||
| 71dd461a47 | |||
| 0c881c8fde | |||
| 35a6795559 | |||
| 968109455d | |||
| 815a3416f2 | |||
| 407069ca27 | |||
| 29d042175e | |||
| bbf936e9b4 | |||
| 2212793420 | |||
| 5487de8c37 | |||
| 2a038c00ec | |||
| d6c6aefd0d | |||
| 4516363097 | |||
| 69dd6d4efa | |||
| 9ea840142b | |||
| a1c45294b3 | |||
| c2ef5057b3 | |||
| b87e6d9ced | |||
| 8d618a6bc3 | |||
| 8d8bee5263 | |||
| ed924641ca | |||
| ce61f9660b | |||
| 072b13cff0 | |||
| ff274dde1d |
@@ -34,3 +34,5 @@ yarn-error.log*
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
storage/
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
// @ts-ignore
|
||||
import * as readline from "node:readline/promises";
|
||||
// @ts-ignore
|
||||
import { stdin as input, stdout as output } from "node:process";
|
||||
import { Document } from "@llamaindex/core/src/Node";
|
||||
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
|
||||
import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine";
|
||||
import essay from "./essay";
|
||||
import { serviceContextFromDefaults } from "@llamaindex/core/src/ServiceContext";
|
||||
|
||||
/**
 * Interactive chat example.
 *
 * Builds a VectorStoreIndex over the bundled essay (chunk size 512), wraps a
 * top-5 retriever in a ContextChatEngine, and then answers queries read from
 * stdin until the process is interrupted or a call throws.
 */
async function main() {
  const document = new Document({ text: essay });
  const serviceContext = serviceContextFromDefaults({ chunkSize: 512 });
  const index = await VectorStoreIndex.fromDocuments(
    [document],
    undefined,
    serviceContext
  );

  const retriever = index.asRetriever();
  retriever.similarityTopK = 5;
  const chatEngine = new ContextChatEngine({ retriever });

  const rl = readline.createInterface({ input, output });
  try {
    while (true) {
      const query = await rl.question("Query: ");
      const response = await chatEngine.achat(query);
      console.log(response);
    }
  } finally {
    // The loop only exits by throwing (closed stdin, failed request, ...);
    // release the readline interface so the error can propagate cleanly.
    rl.close();
  }
}
|
||||
|
||||
main().catch(console.error);
|
||||
@@ -0,0 +1,17 @@
|
||||
import { Document } from "@llamaindex/core/src/Node";
|
||||
import { ListIndex } from "@llamaindex/core/src/index/list";
|
||||
import essay from "./essay";
|
||||
|
||||
/**
 * ListIndex example: indexes the bundled essay as a single document, asks one
 * question through the index's query engine, and prints the answer.
 */
async function main() {
  const document = new Document({ text: essay });
  const index = await ListIndex.fromDocuments([document]);
  const queryEngine = index.asQueryEngine();
  const response = await queryEngine.aquery(
    "What did the author do growing up?"
  );
  console.log(response.toString());
}
|
||||
|
||||
main().catch((e: Error) => {
|
||||
console.error(e, e.stack);
|
||||
});
|
||||
@@ -1,3 +1,4 @@
|
||||
// @ts-ignore
|
||||
import process from "node:process";
|
||||
import { Configuration, OpenAIWrapper } from "@llamaindex/core/src/openai";
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
Simple flow:
|
||||
|
||||
Get document list, in this case one document.
|
||||
Split each document into nodes, in this case sentences or lines.
|
||||
Embed each of the nodes and get vectors. Store them in memory for now.
|
||||
Embed query.
|
||||
Compare the query embedding with the node embeddings and get the top n nodes.
|
||||
Put the top n nodes into the prompt.
|
||||
Execute prompt, get result.
|
||||
@@ -0,0 +1,60 @@
|
||||
// from llama_index import SimpleDirectoryReader, VectorStoreIndex
|
||||
// from llama_index.query_engine import SubQuestionQueryEngine
|
||||
// from llama_index.tools import QueryEngineTool, ToolMetadata
|
||||
|
||||
// # load data
|
||||
// pg_essay = SimpleDirectoryReader(
|
||||
// input_dir="docs/examples/data/paul_graham/"
|
||||
// ).load_data()
|
||||
|
||||
// # build index and query engine
|
||||
// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine()
|
||||
|
||||
// # setup base query engine as tool
|
||||
// query_engine_tools = [
|
||||
// QueryEngineTool(
|
||||
// query_engine=query_engine,
|
||||
// metadata=ToolMetadata(
|
||||
// name="pg_essay", description="Paul Graham essay on What I Worked On"
|
||||
// ),
|
||||
// )
|
||||
// ]
|
||||
|
||||
// query_engine = SubQuestionQueryEngine.from_defaults(
|
||||
// query_engine_tools=query_engine_tools
|
||||
// )
|
||||
|
||||
// response = query_engine.query(
|
||||
// "How was Paul Grahams life different before and after YC?"
|
||||
// )
|
||||
|
||||
// print(response)
|
||||
|
||||
import { Document } from "@llamaindex/core/src/Node";
|
||||
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
|
||||
import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine";
|
||||
|
||||
import essay from "./essay";
|
||||
|
||||
// Sub-question example: index the essay, expose the index's query engine as a
// single tool named "pg_essay", and let SubQuestionQueryEngine answer a
// comparative question with it.
(async () => {
  const document = new Document({ text: essay });
  const index = await VectorStoreIndex.fromDocuments([document]);

  const queryEngine = SubQuestionQueryEngine.fromDefaults({
    queryEngineTools: [
      {
        queryEngine: index.asQueryEngine(),
        metadata: {
          name: "pg_essay",
          description: "Paul Graham essay on What I Worked On",
        },
      },
    ],
  });

  const response = await queryEngine.aquery(
    "How was Paul Grahams life different before and after YC?"
  );

  console.log(response);
})().catch(console.error); // report failures instead of leaving an unhandled rejection
|
||||
@@ -2,7 +2,7 @@ import { Document } from "@llamaindex/core/src/Node";
|
||||
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
|
||||
import essay from "./essay";
|
||||
|
||||
(async () => {
|
||||
async function main() {
|
||||
const document = new Document({ text: essay });
|
||||
const index = await VectorStoreIndex.fromDocuments([document]);
|
||||
const queryEngine = index.asQueryEngine();
|
||||
@@ -10,4 +10,6 @@ import essay from "./essay";
|
||||
"What did the author do growing up?"
|
||||
);
|
||||
console.log(response.toString());
|
||||
})();
|
||||
}
|
||||
|
||||
main().catch(console.error);
|
||||
+1
-1
@@ -18,7 +18,7 @@
|
||||
"prettier": "^2.8.8",
|
||||
"prettier-plugin-tailwindcss": "^0.3.0",
|
||||
"ts-jest": "^29.1.0",
|
||||
"turbo": "^1.10.5"
|
||||
"turbo": "^1.10.6"
|
||||
},
|
||||
"packageManager": "pnpm@7.15.0",
|
||||
"name": "llamascript"
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
"uuid": "^9.0.0",
|
||||
"wink-nlp": "^1.14.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"main": "src/index.ts",
|
||||
"types": "src/index.ts",
|
||||
"scripts": {
|
||||
|
||||
+165
-42
@@ -1,22 +1,19 @@
|
||||
import { Document, TextNode } from "./Node";
|
||||
import { SimpleNodeParser } from "./NodeParser";
|
||||
import { Document, BaseNode, MetadataMode, NodeWithEmbedding } from "./Node";
|
||||
import { BaseQueryEngine, RetrieverQueryEngine } from "./QueryEngine";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { VectorIndexRetriever } from "./Retriever";
|
||||
import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
|
||||
export class BaseIndex {
|
||||
nodes: TextNode[] = [];
|
||||
import { BaseRetriever, VectorIndexRetriever } from "./Retriever";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
|
||||
import {
|
||||
StorageContext,
|
||||
storageContextFromDefaults,
|
||||
} from "./storage/StorageContext";
|
||||
import { BaseDocumentStore } from "./storage/docStore/types";
|
||||
import { VectorStore } from "./storage/vectorStore/types";
|
||||
import { BaseIndexStore } from "./storage/indexStore/types";
|
||||
|
||||
constructor(nodes?: TextNode[]) {
|
||||
this.nodes = nodes ?? [];
|
||||
}
|
||||
}
|
||||
|
||||
export class IndexDict {
|
||||
export abstract class IndexStruct {
|
||||
indexId: string;
|
||||
summary?: string;
|
||||
nodesDict: Record<string, TextNode> = {};
|
||||
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
|
||||
|
||||
constructor(indexId = uuidv4(), summary = undefined) {
|
||||
this.indexId = indexId;
|
||||
@@ -29,57 +26,183 @@ export class IndexDict {
|
||||
}
|
||||
return this.summary;
|
||||
}
|
||||
}
|
||||
|
||||
addNode(node: TextNode, textId?: string) {
|
||||
export class IndexDict extends IndexStruct {
|
||||
nodesDict: Record<string, BaseNode> = {};
|
||||
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
|
||||
|
||||
getSummary(): string {
|
||||
if (this.summary === undefined) {
|
||||
throw new Error("summary field of the index dict is not set");
|
||||
}
|
||||
return this.summary;
|
||||
}
|
||||
|
||||
addNode(node: BaseNode, textId?: string) {
|
||||
const vectorId = textId ?? node.id_;
|
||||
this.nodesDict[vectorId] = node;
|
||||
}
|
||||
}
|
||||
|
||||
export class VectorStoreIndex extends BaseIndex {
|
||||
indexStruct: IndexDict;
|
||||
embeddingService: BaseEmbedding; // FIXME replace with service context
|
||||
export class IndexList extends IndexStruct {
|
||||
nodes: string[] = [];
|
||||
|
||||
constructor(nodes: TextNode[]) {
|
||||
super(nodes);
|
||||
this.indexStruct = new IndexDict();
|
||||
addNode(node: BaseNode) {
|
||||
this.nodes.push(node.id_);
|
||||
}
|
||||
}
|
||||
|
||||
if (nodes !== undefined) {
|
||||
this.buildIndexFromNodes();
|
||||
}
|
||||
export interface BaseIndexInit<T> {
|
||||
serviceContext: ServiceContext;
|
||||
storageContext: StorageContext;
|
||||
docStore: BaseDocumentStore;
|
||||
vectorStore?: VectorStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
indexStruct: T;
|
||||
}
|
||||
export abstract class BaseIndex<T> {
|
||||
serviceContext: ServiceContext;
|
||||
storageContext: StorageContext;
|
||||
docStore: BaseDocumentStore;
|
||||
vectorStore?: VectorStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
indexStruct: T;
|
||||
|
||||
this.embeddingService = new OpenAIEmbedding();
|
||||
constructor(init: BaseIndexInit<T>) {
|
||||
this.serviceContext = init.serviceContext;
|
||||
this.storageContext = init.storageContext;
|
||||
this.docStore = init.docStore;
|
||||
this.vectorStore = init.vectorStore;
|
||||
this.indexStore = init.indexStore;
|
||||
this.indexStruct = init.indexStruct;
|
||||
}
|
||||
|
||||
async getNodeEmbeddingResults(logProgress = false) {
|
||||
for (let i = 0; i < this.nodes.length; ++i) {
|
||||
const node = this.nodes[i];
|
||||
if (logProgress) {
|
||||
console.log(`getting embedding for node ${i}/${this.nodes.length}`);
|
||||
abstract asRetriever(): BaseRetriever;
|
||||
}
|
||||
|
||||
export interface VectorIndexOptions {
|
||||
nodes?: BaseNode[];
|
||||
indexStruct?: IndexDict;
|
||||
serviceContext?: ServiceContext;
|
||||
storageContext?: StorageContext;
|
||||
}
|
||||
|
||||
interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
|
||||
vectorStore: VectorStore;
|
||||
}
|
||||
|
||||
export class VectorStoreIndex extends BaseIndex<IndexDict> {
|
||||
vectorStore: VectorStore;
|
||||
|
||||
private constructor(init: VectorIndexConstructorProps) {
|
||||
super(init);
|
||||
this.vectorStore = init.vectorStore;
|
||||
}
|
||||
|
||||
static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> {
|
||||
const storageContext =
|
||||
options.storageContext ?? (await storageContextFromDefaults({}));
|
||||
const serviceContext =
|
||||
options.serviceContext ?? serviceContextFromDefaults({});
|
||||
const docStore = storageContext.docStore;
|
||||
const vectorStore = storageContext.vectorStore;
|
||||
|
||||
let indexStruct: IndexDict;
|
||||
if (options.indexStruct) {
|
||||
if (options.nodes) {
|
||||
throw new Error(
|
||||
"Cannot initialize VectorStoreIndex with both nodes and indexStruct"
|
||||
);
|
||||
}
|
||||
const embedding = await this.embeddingService.aGetTextEmbedding(
|
||||
node.getText()
|
||||
indexStruct = options.indexStruct;
|
||||
} else {
|
||||
if (!options.nodes) {
|
||||
throw new Error(
|
||||
"Cannot initialize VectorStoreIndex without nodes or indexStruct"
|
||||
);
|
||||
}
|
||||
indexStruct = await VectorStoreIndex.buildIndexFromNodes(
|
||||
options.nodes,
|
||||
serviceContext,
|
||||
vectorStore
|
||||
);
|
||||
node.embedding = embedding;
|
||||
}
|
||||
|
||||
return new VectorStoreIndex({
|
||||
storageContext,
|
||||
serviceContext,
|
||||
docStore,
|
||||
vectorStore,
|
||||
indexStruct,
|
||||
});
|
||||
}
|
||||
|
||||
buildIndexFromNodes() {
|
||||
for (const node of this.nodes) {
|
||||
this.indexStruct.addNode(node);
|
||||
static async agetNodeEmbeddingResults(
|
||||
nodes: BaseNode[],
|
||||
serviceContext: ServiceContext,
|
||||
logProgress = false
|
||||
) {
|
||||
const nodesWithEmbeddings: NodeWithEmbedding[] = [];
|
||||
|
||||
for (let i = 0; i < nodes.length; ++i) {
|
||||
const node = nodes[i];
|
||||
if (logProgress) {
|
||||
console.log(`getting embedding for node ${i}/${nodes.length}`);
|
||||
}
|
||||
const embedding = await serviceContext.embedModel.aGetTextEmbedding(
|
||||
node.getContent(MetadataMode.EMBED)
|
||||
);
|
||||
nodesWithEmbeddings.push({ node, embedding });
|
||||
}
|
||||
|
||||
return nodesWithEmbeddings;
|
||||
}
|
||||
|
||||
static async fromDocuments(documents: Document[]): Promise<VectorStoreIndex> {
|
||||
const nodeParser = new SimpleNodeParser(); // FIXME use service context
|
||||
const nodes = nodeParser.getNodesFromDocuments(documents);
|
||||
const index = new VectorStoreIndex(nodes);
|
||||
await index.getNodeEmbeddingResults();
|
||||
static async buildIndexFromNodes(
|
||||
nodes: BaseNode[],
|
||||
serviceContext: ServiceContext,
|
||||
vectorStore: VectorStore
|
||||
): Promise<IndexDict> {
|
||||
const embeddingResults = await this.agetNodeEmbeddingResults(
|
||||
nodes,
|
||||
serviceContext
|
||||
);
|
||||
|
||||
vectorStore.add(embeddingResults);
|
||||
|
||||
const indexDict = new IndexDict();
|
||||
for (const { node } of embeddingResults) {
|
||||
indexDict.addNode(node);
|
||||
}
|
||||
|
||||
return indexDict;
|
||||
}
|
||||
|
||||
static async fromDocuments(
|
||||
documents: Document[],
|
||||
storageContext?: StorageContext,
|
||||
serviceContext?: ServiceContext
|
||||
): Promise<VectorStoreIndex> {
|
||||
storageContext = storageContext ?? (await storageContextFromDefaults({}));
|
||||
serviceContext = serviceContext ?? serviceContextFromDefaults({});
|
||||
const docStore = storageContext.docStore;
|
||||
|
||||
for (const doc of documents) {
|
||||
docStore.setDocumentHash(doc.id_, doc.hash);
|
||||
}
|
||||
|
||||
const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
|
||||
const index = await VectorStoreIndex.init({
|
||||
nodes,
|
||||
storageContext,
|
||||
serviceContext,
|
||||
});
|
||||
return index;
|
||||
}
|
||||
|
||||
asRetriever(): VectorIndexRetriever {
|
||||
return new VectorIndexRetriever(this, this.embeddingService);
|
||||
return new VectorIndexRetriever(this);
|
||||
}
|
||||
|
||||
asQueryEngine(): BaseQueryEngine {
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
import { BaseChatModel, BaseMessage, ChatOpenAI } from "./LanguageModel";
|
||||
import { TextNode } from "./Node";
|
||||
import {
|
||||
SimplePrompt,
|
||||
contextSystemPrompt,
|
||||
defaultCondenseQuestionPrompt,
|
||||
messagesToHistoryStr,
|
||||
} from "./Prompt";
|
||||
import { BaseQueryEngine } from "./QueryEngine";
|
||||
import { Response } from "./Response";
|
||||
import { BaseRetriever } from "./Retriever";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { Event } from "./callbacks/CallbackManager";
|
||||
|
||||
/**
 * Common contract implemented by the chat engines in this file.
 */
interface ChatEngine {
  /** Start an interactive read-eval-print loop. */
  chatRepl(): void;

  /**
   * Send one user message and resolve with the engine's reply.
   * @param message The user's message.
   * @param chatHistory Optional history to use instead of the engine's own.
   */
  achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>;

  /** Clear the engine's stored conversation history. */
  reset(): void;
}
|
||||
|
||||
/**
 * Chat engine that sends the running conversation directly to a chat model,
 * with no retrieval step.
 */
export class SimpleChatEngine implements ChatEngine {
  chatHistory: BaseMessage[];
  llm: BaseChatModel;

  /**
   * @param init Optional overrides; defaults to an empty history and a
   *   ChatOpenAI model using "gpt-3.5-turbo".
   */
  constructor(init?: Partial<SimpleChatEngine>) {
    this.chatHistory = init?.chatHistory ?? [];
    this.llm = init?.llm ?? new ChatOpenAI({ model: "gpt-3.5-turbo" });
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  /**
   * Appends the user message to the history, asks the model for a reply,
   * appends the reply, and stores the resulting history on the engine.
   *
   * Note: a history array passed by the caller is mutated in place and becomes
   * the engine's history.
   */
  async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> {
    chatHistory = chatHistory ?? this.chatHistory;
    chatHistory.push({ content: message, type: "human" });

    const response = await this.llm.agenerate(chatHistory);
    const text = response.generations[0][0].text;

    chatHistory.push({ content: text, type: "ai" });
    this.chatHistory = chatHistory;
    return new Response(text);
  }

  /** Drops the stored conversation. */
  reset() {
    this.chatHistory = [];
  }
}
|
||||
|
||||
/**
 * Chat engine that first condenses the conversation plus the new message into
 * a single question (via the service context's LLM predictor) and then sends
 * that question to a query engine.
 */
export class CondenseQuestionChatEngine implements ChatEngine {
  queryEngine: BaseQueryEngine;
  chatHistory: BaseMessage[];
  serviceContext: ServiceContext;
  condenseMessagePrompt: SimplePrompt;

  /**
   * @param init.queryEngine Engine that answers the condensed question.
   * @param init.chatHistory Initial history; defaults to empty.
   * @param init.serviceContext Defaults to `serviceContextFromDefaults({})`.
   * @param init.condenseMessagePrompt Defaults to `defaultCondenseQuestionPrompt`.
   */
  constructor(init: {
    queryEngine: BaseQueryEngine;
    // Optional: the constructor already falls back to an empty history.
    chatHistory?: BaseMessage[];
    serviceContext?: ServiceContext;
    condenseMessagePrompt?: SimplePrompt;
  }) {
    this.queryEngine = init.queryEngine;
    this.chatHistory = init.chatHistory ?? [];
    this.serviceContext =
      init.serviceContext ?? serviceContextFromDefaults({});
    this.condenseMessagePrompt =
      init.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
  }

  /**
   * Renders the history to a string and asks the LLM predictor to rewrite the
   * question using the configured condense prompt.
   */
  private async acondenseQuestion(
    chatHistory: BaseMessage[],
    question: string
  ) {
    const chatHistoryStr = messagesToHistoryStr(chatHistory);

    // Use the instance's prompt so a caller-supplied condenseMessagePrompt is
    // honored (previously the default prompt was always used).
    return this.serviceContext.llmPredictor.apredict(
      this.condenseMessagePrompt,
      {
        question: question,
        chat_history: chatHistoryStr,
      }
    );
  }

  /**
   * Condenses the message against the history, queries the query engine with
   * the condensed question, and appends the exchange to the history in use.
   *
   * Note: when `chatHistory` is supplied, that array (not the engine's own
   * history) is the one that receives the new messages.
   */
  async achat(
    message: string,
    chatHistory?: BaseMessage[] | undefined
  ): Promise<Response> {
    chatHistory = chatHistory ?? this.chatHistory;

    const condensedQuestion = await this.acondenseQuestion(
      chatHistory,
      message
    );

    const response = await this.queryEngine.aquery(condensedQuestion);

    chatHistory.push({ content: message, type: "human" });
    chatHistory.push({ content: response.response, type: "ai" });

    return response;
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  /** Drops the stored conversation. */
  reset() {
    this.chatHistory = [];
  }
}
|
||||
|
||||
/**
 * Chat engine that retrieves nodes for each user message and supplies their
 * text to the chat model as a system message ahead of the conversation.
 */
export class ContextChatEngine implements ChatEngine {
  retriever: BaseRetriever;
  chatModel: BaseChatModel;
  chatHistory: BaseMessage[];

  /**
   * @param init.retriever Retriever used to fetch context for each message.
   * @param init.chatModel Defaults to ChatOpenAI with "gpt-3.5-turbo-16k".
   * @param init.chatHistory Initial history; defaults to empty.
   */
  constructor(init: {
    retriever: BaseRetriever;
    chatModel?: BaseChatModel;
    chatHistory?: BaseMessage[];
  }) {
    this.retriever = init.retriever;
    this.chatModel =
      init.chatModel ?? new ChatOpenAI({ model: "gpt-3.5-turbo-16k" });
    this.chatHistory = init.chatHistory ?? [];
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  /**
   * Retrieves context for `message`, sends system-context + history to the
   * chat model, records the exchange, and returns the reply together with the
   * retrieved source nodes.
   *
   * Note: a history array passed by the caller is mutated in place and becomes
   * the engine's history.
   */
  async achat(
    message: string,
    chatHistory?: BaseMessage[] | undefined
  ): Promise<Response> {
    chatHistory = chatHistory ?? this.chatHistory;

    // One parent event is shared by the retrieval and generation calls so
    // both are passed the same event.
    const parentEvent: Event = {
      id: uuidv4(),
      type: "wrapper",
      tags: ["final"],
    };
    const sourceNodesWithScore = await this.retriever.aretrieve(
      message,
      parentEvent
    );

    // NOTE(review): the cast assumes every retrieved node is a TextNode.
    const context = sourceNodesWithScore
      .map((r) => (r.node as TextNode).text)
      .join("\n\n");
    const systemMessage: BaseMessage = {
      content: contextSystemPrompt({ context }),
      type: "system",
    };

    chatHistory.push({ content: message, type: "human" });

    // The system message is prepended per call and is not stored in history.
    const response = await this.chatModel.agenerate(
      [systemMessage, ...chatHistory],
      parentEvent
    );
    const text = response.generations[0][0].text;

    chatHistory.push({ content: text, type: "ai" });
    this.chatHistory = chatHistory;

    return new Response(
      text,
      sourceNodesWithScore.map((r) => r.node)
    );
  }

  /** Drops the stored conversation. */
  reset() {
    this.chatHistory = [];
  }
}
|
||||
@@ -174,24 +174,12 @@ export function getTopKMMREmbeddings(
|
||||
}
|
||||
|
||||
export abstract class BaseEmbedding {
|
||||
static similarity(
|
||||
similarity(
|
||||
embedding1: number[],
|
||||
embedding2: number[],
|
||||
mode: SimilarityType = SimilarityType.DOT_PRODUCT
|
||||
mode: SimilarityType = SimilarityType.DEFAULT
|
||||
): number {
|
||||
if (embedding1.length !== embedding2.length) {
|
||||
throw new Error("Embedding length mismatch");
|
||||
}
|
||||
|
||||
if (mode === SimilarityType.DOT_PRODUCT) {
|
||||
let result = 0;
|
||||
for (let i = 0; i < embedding1.length; i++) {
|
||||
result += embedding1[i] * embedding2[i];
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
throw new Error("Not implemented yet");
|
||||
}
|
||||
return similarity(embedding1, embedding2, mode);
|
||||
}
|
||||
|
||||
abstract aGetTextEmbedding(text: string): Promise<number[]>;
|
||||
|
||||
@@ -1,11 +1,38 @@
|
||||
import { Event, EventTag, EventType } from "./callbacks/CallbackManager";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
|
||||
/**
 * Holder for process-wide helpers: a lazily created default tokenizer and a
 * factory for callback events.
 */
class GlobalsHelper {
  // Cached tokenizer; populated on the first call to tokenizer().
  defaultTokenizer: ((text: string) => string[]) | null = null;

  /**
   * Returns the default tokenizer, creating and caching it on first use so
   * the encoding is only loaded once.
   */
  tokenizer() {
    if (this.defaultTokenizer) {
      return this.defaultTokenizer;
    }

    // Loaded lazily so the dependency is only required when tokenizing.
    const tiktoken = require("tiktoken-node");
    // getEncoding is a factory function, so it is called without `new`.
    const enc = tiktoken.getEncoding("gpt2");
    this.defaultTokenizer = (text: string) => {
      return enc.encode(text);
    };
    return this.defaultTokenizer;
  }

  /**
   * Creates a new event with a fresh id.
   * @param parentEvent Optional parent; its id becomes `parentId`.
   * @param type Event type.
   * @param tags Tags for the event; when omitted, the parent's tags are used.
   */
  createEvent({
    parentEvent,
    type,
    tags,
  }: {
    parentEvent?: Event;
    type: EventType;
    tags?: EventTag[];
  }): Event {
    return {
      id: uuidv4(),
      type,
      // inherit parent tags if tags not set
      tags: tags ?? parentEvent?.tags,
      parentId: parentEvent?.id,
    };
  }
}
|
||||
|
||||
|
||||
@@ -1,28 +1,50 @@
|
||||
import { ChatOpenAI } from "./LanguageModel";
|
||||
import { SimplePrompt } from "./Prompt";
|
||||
import { CallbackManager, Event } from "./callbacks/CallbackManager";
|
||||
|
||||
// TODO change this to LLM class
|
||||
export interface BaseLLMPredictor {
|
||||
getLlmMetadata(): Promise<any>;
|
||||
apredict(
|
||||
prompt: string | SimplePrompt,
|
||||
input?: Record<string, string>
|
||||
input?: Record<string, string>,
|
||||
parentEvent?: Event
|
||||
): Promise<string>;
|
||||
// stream(prompt: string, options: any): Promise<any>;
|
||||
}
|
||||
|
||||
// TODO change this to LLM class
|
||||
export class ChatGPTLLMPredictor implements BaseLLMPredictor {
|
||||
llm: string;
|
||||
model: string;
|
||||
retryOnThrottling: boolean;
|
||||
languageModel: ChatOpenAI;
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
constructor(
|
||||
llm: string = "gpt-3.5-turbo",
|
||||
retryOnThrottling: boolean = true
|
||||
props:
|
||||
| {
|
||||
model?: string;
|
||||
retryOnThrottling?: boolean;
|
||||
callbackManager?: CallbackManager;
|
||||
languageModel?: ChatOpenAI;
|
||||
}
|
||||
| undefined = undefined
|
||||
) {
|
||||
this.llm = llm;
|
||||
const {
|
||||
model = "gpt-3.5-turbo",
|
||||
retryOnThrottling = true,
|
||||
callbackManager,
|
||||
languageModel,
|
||||
} = props || {};
|
||||
this.model = model;
|
||||
this.callbackManager = callbackManager;
|
||||
this.retryOnThrottling = retryOnThrottling;
|
||||
|
||||
this.languageModel = new ChatOpenAI(this.llm);
|
||||
this.languageModel =
|
||||
languageModel ??
|
||||
new ChatOpenAI({
|
||||
model: this.model,
|
||||
callbackManager: this.callbackManager,
|
||||
});
|
||||
}
|
||||
|
||||
async getLlmMetadata() {
|
||||
@@ -31,22 +53,22 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
|
||||
|
||||
/**
 * Runs a single-turn prediction against the configured chat model.
 *
 * @param prompt Either a ready-to-send string or a prompt template function.
 * @param input Values passed to the template when `prompt` is a function.
 * @param parentEvent Optional event forwarded to the language model call.
 * @returns The text of the first generation.
 */
async apredict(
  prompt: string | SimplePrompt,
  input?: Record<string, string>,
  parentEvent?: Event
): Promise<string> {
  if (typeof prompt !== "string") {
    // Render the template, then re-enter with the resulting string. The
    // parent event must be forwarded here, otherwise it is dropped.
    return this.apredict(prompt(input ?? {}), undefined, parentEvent);
  }

  const result = await this.languageModel.agenerate(
    [
      {
        content: prompt,
        type: "human",
      },
    ],
    parentEvent
  );
  return result.generations[0][0].text;
}
|
||||
|
||||
// async stream(prompt: string, options: any) {
|
||||
// console.log("stream");
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { CallbackManager, Event } from "./callbacks/CallbackManager";
|
||||
import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream";
|
||||
import {
|
||||
ChatCompletionRequestMessageRoleEnum,
|
||||
Configuration,
|
||||
CreateChatCompletionRequest,
|
||||
OpenAISession,
|
||||
OpenAIWrapper,
|
||||
getOpenAISession,
|
||||
} from "./openai";
|
||||
|
||||
@@ -10,7 +11,7 @@ export interface BaseLanguageModel {}
|
||||
|
||||
type MessageType = "human" | "ai" | "system" | "generic" | "function";
|
||||
|
||||
interface BaseMessage {
|
||||
export interface BaseMessage {
|
||||
content: string;
|
||||
type: MessageType;
|
||||
}
|
||||
@@ -24,9 +25,11 @@ export interface LLMResult {
|
||||
generations: Generation[][]; // Each input can have more than one generations
|
||||
}
|
||||
|
||||
export class BaseChatModel implements BaseLanguageModel {}
|
||||
export interface BaseChatModel extends BaseLanguageModel {
|
||||
agenerate(messages: BaseMessage[], parentEvent?: Event): Promise<LLMResult>;
|
||||
}
|
||||
|
||||
export class ChatOpenAI extends BaseChatModel {
|
||||
export class ChatOpenAI implements BaseChatModel {
|
||||
model: string;
|
||||
temperature: number = 0.7;
|
||||
openAIKey: string | null = null;
|
||||
@@ -34,12 +37,18 @@ export class ChatOpenAI extends BaseChatModel {
|
||||
maxRetries: number = 6;
|
||||
n: number = 1;
|
||||
maxTokens?: number;
|
||||
|
||||
session: OpenAISession;
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
constructor(model: string = "gpt-3.5-turbo") {
|
||||
super();
|
||||
constructor({
|
||||
model = "gpt-3.5-turbo",
|
||||
callbackManager,
|
||||
}: {
|
||||
model: string;
|
||||
callbackManager?: CallbackManager;
|
||||
}) {
|
||||
this.model = model;
|
||||
this.callbackManager = callbackManager;
|
||||
this.session = getOpenAISession();
|
||||
}
|
||||
|
||||
@@ -60,8 +69,11 @@ export class ChatOpenAI extends BaseChatModel {
|
||||
}
|
||||
}
|
||||
|
||||
async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
|
||||
const { data } = await this.session.openai.createChatCompletion({
|
||||
async agenerate(
|
||||
messages: BaseMessage[],
|
||||
parentEvent?: Event
|
||||
): Promise<LLMResult> {
|
||||
const baseRequestParams: CreateChatCompletionRequest = {
|
||||
model: this.model,
|
||||
temperature: this.temperature,
|
||||
max_tokens: this.maxTokens,
|
||||
@@ -70,8 +82,29 @@ export class ChatOpenAI extends BaseChatModel {
|
||||
role: ChatOpenAI.mapMessageType(message.type),
|
||||
content: message.content,
|
||||
})),
|
||||
});
|
||||
};
|
||||
|
||||
if (this.callbackManager?.onLLMStream) {
|
||||
const response = await this.session.openai.createChatCompletion(
|
||||
{
|
||||
...baseRequestParams,
|
||||
stream: true,
|
||||
},
|
||||
{ responseType: "stream" }
|
||||
);
|
||||
const fullResponse = await aHandleOpenAIStream({
|
||||
response,
|
||||
onLLMStream: this.callbackManager.onLLMStream,
|
||||
parentEvent,
|
||||
});
|
||||
return { generations: [[{ text: fullResponse }]] };
|
||||
}
|
||||
|
||||
const response = await this.session.openai.createChatCompletion(
|
||||
baseRequestParams
|
||||
);
|
||||
|
||||
const { data } = response;
|
||||
const content = data.choices[0].message?.content ?? "";
|
||||
return { generations: [[{ text: content }]] };
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ export class TextNode extends BaseNode {
|
||||
endCharIdx?: number;
|
||||
// textTemplate: NOTE write your own formatter if needed
|
||||
// metadataTemplate: NOTE write your own formatter if needed
|
||||
metadataSeperator: string = "\n";
|
||||
metadataSeparator: string = "\n";
|
||||
|
||||
constructor(init?: Partial<TextNode>) {
|
||||
super(init);
|
||||
@@ -174,7 +174,7 @@ export class TextNode extends BaseNode {
|
||||
|
||||
return [...usableMetadataKeys]
|
||||
.map((key) => `${key}: ${this.metadata[key]}`)
|
||||
.join(this.metadataSeperator);
|
||||
.join(this.metadataSeparator);
|
||||
}
|
||||
|
||||
setContent(value: string) {
|
||||
@@ -206,11 +206,6 @@ export class IndexNode extends TextNode {
|
||||
}
|
||||
}
|
||||
|
||||
export interface NodeWithScore {
|
||||
node: TextNode;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export class Document extends TextNode {
|
||||
constructor(init?: Partial<Document>) {
|
||||
super(init);
|
||||
@@ -229,3 +224,13 @@ export class Document extends TextNode {
|
||||
export class ImageDocument extends Document {
|
||||
image?: string;
|
||||
}
|
||||
|
||||
export interface NodeWithScore {
|
||||
node: BaseNode;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export interface NodeWithEmbedding {
|
||||
node: BaseNode;
|
||||
embedding: number[];
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Document, NodeRelationship, TextNode } from "./Node";
|
||||
import { SentenceSplitter } from "./TextSplitter";
|
||||
import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants";
|
||||
|
||||
export function getTextSplitsFromDocument(
|
||||
document: Document,
|
||||
@@ -13,18 +14,36 @@ export function getTextSplitsFromDocument(
|
||||
|
||||
export function getNodesFromDocument(
|
||||
document: Document,
|
||||
textSplitter: SentenceSplitter
|
||||
textSplitter: SentenceSplitter,
|
||||
includeMetadata: boolean = true,
|
||||
includePrevNextRel: boolean = true
|
||||
) {
|
||||
let nodes: TextNode[] = [];
|
||||
|
||||
const textSplits = getTextSplitsFromDocument(document, textSplitter);
|
||||
|
||||
textSplits.forEach((textSplit, index) => {
|
||||
const node = new TextNode({ text: textSplit });
|
||||
textSplits.forEach((textSplit) => {
|
||||
const node = new TextNode({
|
||||
text: textSplit,
|
||||
metadata: includeMetadata ? document.metadata : {},
|
||||
});
|
||||
node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo();
|
||||
nodes.push(node);
|
||||
});
|
||||
|
||||
if (includePrevNextRel) {
|
||||
nodes.forEach((node, index) => {
|
||||
if (index > 0) {
|
||||
node.relationships[NodeRelationship.PREVIOUS] =
|
||||
nodes[index - 1].asRelatedNodeInfo();
|
||||
}
|
||||
if (index < nodes.length - 1) {
|
||||
node.relationships[NodeRelationship.NEXT] =
|
||||
nodes[index + 1].asRelatedNodeInfo();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
@@ -33,17 +52,34 @@ export interface NodeParser {
|
||||
}
|
||||
export class SimpleNodeParser implements NodeParser {
|
||||
textSplitter: SentenceSplitter;
|
||||
includeMetadata: boolean;
|
||||
includePrevNextRel: boolean;
|
||||
|
||||
constructor(
|
||||
textSplitter: any = null,
|
||||
includeExtraInfo: boolean = true,
|
||||
includePrevNextRel: boolean = true
|
||||
) {
|
||||
this.textSplitter = textSplitter ?? new SentenceSplitter();
|
||||
constructor(init?: {
|
||||
textSplitter?: SentenceSplitter;
|
||||
includeMetadata?: boolean;
|
||||
includePrevNextRel?: boolean;
|
||||
|
||||
chunkSize?: number;
|
||||
chunkOverlap?: number;
|
||||
}) {
|
||||
this.textSplitter =
|
||||
init?.textSplitter ??
|
||||
new SentenceSplitter(
|
||||
init?.chunkSize ?? DEFAULT_CHUNK_SIZE,
|
||||
init?.chunkOverlap ?? DEFAULT_CHUNK_OVERLAP
|
||||
);
|
||||
this.includeMetadata = init?.includeMetadata ?? true;
|
||||
this.includePrevNextRel = init?.includePrevNextRel ?? true;
|
||||
}
|
||||
|
||||
static fromDefaults(): SimpleNodeParser {
|
||||
return new SimpleNodeParser();
|
||||
static fromDefaults(init?: {
|
||||
chunkSize?: number;
|
||||
chunkOverlap?: number;
|
||||
includeMetadata?: boolean;
|
||||
includePrevNextRel?: boolean;
|
||||
}): SimpleNodeParser {
|
||||
return new SimpleNodeParser(init);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
import { SubQuestion } from "./QuestionGenerator";
|
||||
|
||||
/**
 * Contract for objects that turn a raw output string into a value of type `T`.
 *
 * `parse` converts the raw string into `T`.
 * `format` maps a string to a string.
 * NOTE(review): `format` presumably decorates a prompt with formatting
 * instructions — confirm against the implementers.
 */
export interface BaseOutputParser<T> {
|
||||
parse(output: string): T;
|
||||
format(output: string): string;
|
||||
}
|
||||
|
||||
/**
 * Pairs a raw output string (`rawOutput`) with the value of type `T`
 * obtained from it (`parsedOutput`).
 */
export interface StructuredOutput<T> {
|
||||
rawOutput: string;
|
||||
parsedOutput: T;
|
||||
}
|
||||
|
||||
/**
 * Error raised when model output cannot be parsed.
 *
 * Carries the optional underlying `cause` and the offending `output` text so
 * callers can inspect what failed.
 */
class OutputParserError extends Error {
  cause: Error | undefined;
  output: string | undefined;

  /**
   * @param message Human-readable description of the failure.
   * @param options Optional underlying `cause` and the raw `output` that failed.
   */
  constructor(
    message: string,
    options: { cause?: Error; output?: string } = {}
  ) {
    // The options bag follows https://github.com/tc39/proposal-error-cause
    // @ts-ignore
    super(message, options);

    const { cause, output } = options;
    this.output = output;
    this.name = "OutputParserError";

    // Runtimes that implement the proposal have already populated `cause`;
    // only fill it in by hand when they have not.
    if (!this.cause) {
      this.cause = cause;
    }

    // Keeps the stack trace anchored at the throw site on V8
    // (https://v8.dev/docs/stack-trace-api).
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, OutputParserError);
    }
  }
}
|
||||
|
||||
/**
 * Parses the JSON payload of the first ```json fenced block found in `text`.
 *
 * @param text Text expected to contain a ```json ... ``` block.
 * @returns Whatever `JSON.parse` produces for the fenced payload.
 * @throws OutputParserError when either fence is missing ("Not a json markdown")
 *   or the payload is not valid JSON ("Not a valid json"); the trimmed input is
 *   attached as `output` in both cases.
 */
function parseJsonMarkdown(text: string) {
  const OPEN_FENCE = "```json";
  const CLOSE_FENCE = "```";
  const trimmed = text.trim();

  const fenceStart = trimmed.indexOf(OPEN_FENCE);
  const payloadStart = fenceStart + OPEN_FENCE.length;
  // Search for the closing fence only after the opening one.
  const payloadEnd = trimmed.indexOf(CLOSE_FENCE, payloadStart);

  const hasBothFences = fenceStart !== -1 && payloadEnd !== -1;
  if (!hasBothFences) {
    throw new OutputParserError("Not a json markdown", { output: trimmed });
  }

  const payload = trimmed.substring(payloadStart, payloadEnd);
  try {
    return JSON.parse(payload);
  } catch (err) {
    throw new OutputParserError("Not a valid json", {
      cause: err as Error,
      output: trimmed,
    });
  }
}
|
||||
|
||||
/**
 * Output parser that reads a list of sub-questions from a ```json block.
 */
export class SubQuestionOutputParser
  implements BaseOutputParser<StructuredOutput<SubQuestion[]>>
{
  /**
   * Parses `output` with `parseJsonMarkdown` and returns it next to the raw text.
   * The parsed value is not validated against the `SubQuestion` shape.
   */
  parse(output: string): StructuredOutput<SubQuestion[]> {
    // TODO add zod validation
    const parsedOutput = parseJsonMarkdown(output);
    return { rawOutput: output, parsedOutput };
  }

  /** Returns `output` unchanged. */
  format(output: string): string {
    return output;
  }
}
|
||||
@@ -1,3 +1,7 @@
|
||||
import { BaseMessage } from "./LanguageModel";
|
||||
import { SubQuestion } from "./QuestionGenerator";
|
||||
import { ToolMetadata } from "./Tool";
|
||||
|
||||
/**
|
||||
* A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
|
||||
* NOTE this is a different interface compared to LlamaIndex Python
|
||||
@@ -80,3 +84,236 @@ ${context}
|
||||
------------
|
||||
Given the new context, refine the original answer to better answer the question. If the context isn't useful, return the original answer.`;
|
||||
};
|
||||
|
||||
/**
 * Builds the choice-select prompt. It asks for the numbers of the documents
 * relevant to a question, each with a 1-10 relevance score, shows an example
 * of the expected answer format, and then appends the caller's `context` and
 * `query`. Both inputs fall back to an empty string when undefined.
 */
export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
|
||||
const { context = "", query = "" } = input;
|
||||
|
||||
return `A list of documents is shown below. Each document has a number next to it along
|
||||
with a summary of the document. A question is also provided.
|
||||
Respond with the numbers of the documents
|
||||
you should consult to answer the question, in order of relevance, as well
|
||||
as the relevance score. The relevance score is a number from 1-10 based on
|
||||
how relevant you think the document is to the question.
|
||||
Do not include any documents that are not relevant to the question.
|
||||
Example format:
|
||||
Document 1:
|
||||
<summary of document 1>
|
||||
|
||||
Document 2:
|
||||
<summary of document 2>
|
||||
|
||||
...
|
||||
|
||||
Document 10:\n<summary of document 10>
|
||||
|
||||
Question: <question>
|
||||
Answer:
|
||||
Doc: 9, Relevance: 7
|
||||
Doc: 3, Relevance: 4
|
||||
Doc: 7, Relevance: 3
|
||||
|
||||
Let's try this now:
|
||||
|
||||
${context}
|
||||
Question: ${query}
|
||||
Answer:`;
|
||||
};
|
||||
|
||||
/*
|
||||
PREFIX = """\
|
||||
Given a user question, and a list of tools, output a list of relevant sub-questions \
|
||||
that when composed can help answer the full user question:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
example_query_str = (
|
||||
"Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
|
||||
)
|
||||
example_tools = [
|
||||
ToolMetadata(
|
||||
name="uber_10k",
|
||||
description="Provides information about Uber financials for year 2021",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="lyft_10k",
|
||||
description="Provides information about Lyft financials for year 2021",
|
||||
),
|
||||
]
|
||||
example_tools_str = build_tools_text(example_tools)
|
||||
example_output = [
|
||||
SubQuestion(
|
||||
sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
|
||||
),
|
||||
SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
|
||||
SubQuestion(
|
||||
sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
|
||||
),
|
||||
SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
|
||||
]
|
||||
example_output_str = json.dumps([x.dict() for x in example_output], indent=4)
|
||||
|
||||
EXAMPLES = (
|
||||
"""\
|
||||
# Example 1
|
||||
<Tools>
|
||||
```json
|
||||
{tools_str}
|
||||
```
|
||||
|
||||
<User Question>
|
||||
{query_str}
|
||||
|
||||
|
||||
<Output>
|
||||
```json
|
||||
{output_str}
|
||||
```
|
||||
|
||||
""".format(
|
||||
query_str=example_query_str,
|
||||
tools_str=example_tools_str,
|
||||
output_str=example_output_str,
|
||||
)
|
||||
.replace("{", "{{")
|
||||
.replace("}", "}}")
|
||||
)
|
||||
|
||||
SUFFIX = """\
|
||||
# Example 2
|
||||
<Tools>
|
||||
```json
|
||||
{tools_str}
|
||||
```
|
||||
|
||||
<User Question>
|
||||
{query_str}
|
||||
|
||||
<Output>
|
||||
"""
|
||||
|
||||
DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
|
||||
*/
|
||||
|
||||
/**
 * Serialises tool metadata as a JSON object mapping each tool name to its
 * description, pretty-printed with a 4-space indent.
 *
 * @param tools Tool metadata entries; a later entry overwrites an earlier one
 *   that has the same name.
 * @returns The JSON text.
 */
export function buildToolsText(tools: ToolMetadata[]) {
  const descriptionsByName: Record<string, string> = {};
  for (const { name, description } of tools) {
    descriptionsByName[name] = description;
  }

  return JSON.stringify(descriptionsByName, null, 4);
}
|
||||
|
||||
// Worked example embedded as "Example 1" in the sub-question prompt:
// two tools, one user question, and the sub-questions expected for them.

/** Tools shown in the worked example. */
const exampleTools: ToolMetadata[] = [
  {
    name: "uber_10k",
    description: "Provides information about Uber financials for year 2021",
  },
  {
    name: "lyft_10k",
    description: "Provides information about Lyft financials for year 2021",
  },
];

/** User question shown in the worked example. */
const exampleQueryStr =
  "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021";

/** Sub-questions shown as the expected output of the worked example. */
const exampleOutput: SubQuestion[] = [
  { subQuestion: "What is the revenue growth of Uber", toolName: "uber_10k" },
  { subQuestion: "What is the EBITDA of Uber", toolName: "uber_10k" },
  { subQuestion: "What is the revenue growth of Lyft", toolName: "lyft_10k" },
  { subQuestion: "What is the EBITDA of Lyft", toolName: "lyft_10k" },
];
|
||||
|
||||
/**
 * Builds the sub-question generation prompt.
 *
 * The prompt shows one worked example (the example tools, question and
 * expected JSON output) and then repeats the same layout for the caller's
 * tools and question, ending at `<Output>` so the completion supplies the JSON.
 *
 * @param input `toolsStr` is the JSON description of the available tools and
 *   `queryStr` is the user question.
 * @returns The prompt text.
 */
export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
  const { toolsStr, queryStr } = input;

  // The tools block must contain exactly `toolsStr`: a stray "}" after the
  // interpolation would show malformed JSON inside the fenced block.
  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:

# Example 1
<Tools>
\`\`\`json
${buildToolsText(exampleTools)}
\`\`\`

<User Question>
${exampleQueryStr}

<Output>
\`\`\`json
${JSON.stringify(exampleOutput, null, 4)}
\`\`\`

# Example 2
<Tools>
\`\`\`json
${toolsStr}
\`\`\`

<User Question>
${queryStr}

<Output>
`;
};
|
||||
|
||||
// DEFAULT_TEMPLATE = """\
|
||||
// Given a conversation (between Human and Assistant) and a follow up message from Human, \
|
||||
// rewrite the message to be a standalone question that captures all relevant context \
|
||||
// from the conversation.
|
||||
|
||||
// <Chat History>
|
||||
// {chat_history}
|
||||
|
||||
// <Follow Up Message>
|
||||
// {question}
|
||||
|
||||
// <Standalone question>
|
||||
// """
|
||||
|
||||
/**
 * Builds the prompt that asks for a follow-up message to be rewritten as a
 * standalone question, given the preceding chat history.
 *
 * @param input `chatHistory` is the rendered conversation and `question` is
 *   the follow-up message.
 * @returns The prompt text, ending with a newline after `<Standalone question>`.
 */
export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
  const history = input.chatHistory;
  const followUp = input.question;

  const promptLines = [
    "Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.",
    "",
    "<Chat History>",
    `${history}`,
    "",
    "<Follow Up Message>",
    `${followUp}`,
    "",
    "<Standalone question>",
    "",
  ];

  return promptLines.join("\n");
};
|
||||
|
||||
/**
 * Renders chat messages as a transcript, one line per message.
 *
 * Messages whose `type` is "human" are prefixed with "Human: "; every other
 * message is prefixed with "Assistant: ". Lines are joined with "\n".
 *
 * @param messages Messages in conversation order.
 * @returns The transcript, or an empty string for an empty list.
 */
export function messagesToHistoryStr(messages: BaseMessage[]) {
  const transcriptLines = messages.map((message) => {
    const speaker = message.type === "human" ? "Human" : "Assistant";
    return `${speaker}: ${message.content}`;
  });

  return transcriptLines.join("\n");
}
|
||||
|
||||
/**
 * Wraps `input.context` in a "Context information is below." banner,
 * delimited above and below by a dashed rule.
 *
 * @param input Object whose `context` field is the text to embed.
 * @returns The prompt text.
 */
export const contextSystemPrompt: SimplePrompt = (input) => {
  const rule = "---------------------";
  const body = `${input.context}`;

  return ["Context information is below.", rule, body, rule].join("\n");
};
|
||||
|
||||
@@ -1,22 +1,136 @@
|
||||
import { NodeWithScore, TextNode } from "./Node";
|
||||
import {
|
||||
BaseQuestionGenerator,
|
||||
LLMQuestionGenerator,
|
||||
SubQuestion,
|
||||
} from "./QuestionGenerator";
|
||||
import { Response } from "./Response";
|
||||
import { ResponseSynthesizer } from "./ResponseSynthesizer";
|
||||
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
|
||||
import { BaseRetriever } from "./Retriever";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { Event } from "./callbacks/CallbackManager";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
|
||||
import { QueryEngineTool, ToolMetadata } from "./Tool";
|
||||
|
||||
export interface BaseQueryEngine {
|
||||
aquery(query: string): Promise<Response>;
|
||||
aquery(query: string, parentEvent?: Event): Promise<Response>;
|
||||
}
|
||||
|
||||
export class RetrieverQueryEngine {
|
||||
export class RetrieverQueryEngine implements BaseQueryEngine {
|
||||
retriever: BaseRetriever;
|
||||
responseSynthesizer: ResponseSynthesizer;
|
||||
|
||||
constructor(retriever: BaseRetriever) {
|
||||
this.retriever = retriever;
|
||||
this.responseSynthesizer = new ResponseSynthesizer();
|
||||
const serviceContext: ServiceContext | undefined =
|
||||
this.retriever.getServiceContext();
|
||||
this.responseSynthesizer = new ResponseSynthesizer({ serviceContext });
|
||||
}
|
||||
|
||||
async aquery(query: string) {
|
||||
const nodes = await this.retriever.aretrieve(query);
|
||||
return this.responseSynthesizer.asynthesize(query, nodes);
|
||||
async aquery(query: string, parentEvent?: Event) {
|
||||
const _parentEvent: Event = parentEvent || {
|
||||
id: uuidv4(),
|
||||
type: "wrapper",
|
||||
tags: ["final"],
|
||||
};
|
||||
const nodes = await this.retriever.aretrieve(query, _parentEvent);
|
||||
return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Query engine that answers a query by decomposing it into sub-questions.
 *
 * Each sub-question is sent to the query engine registered under the tool
 * name chosen by the question generator. The sub-answers are wrapped in text
 * nodes and handed to the response synthesizer together with the original
 * query.
 */
export class SubQuestionQueryEngine implements BaseQueryEngine {
  responseSynthesizer: ResponseSynthesizer;
  questionGen: BaseQuestionGenerator;
  /** Query engines keyed by the name in their tool metadata. */
  queryEngines: Record<string, BaseQueryEngine>;
  /** Metadata of every registered tool, passed to the question generator. */
  metadatas: ToolMetadata[];

  /**
   * @param init.questionGen Generator that produces the sub-questions.
   * @param init.responseSynthesizer Synthesizer for the final answer; a
   *   default `ResponseSynthesizer` is created when omitted.
   * @param init.queryEngineTools Tools whose engines answer the sub-questions.
   */
  constructor(init: {
    questionGen: BaseQuestionGenerator;
    responseSynthesizer?: ResponseSynthesizer;
    queryEngineTools: QueryEngineTool[];
  }) {
    this.questionGen = init.questionGen;
    this.responseSynthesizer =
      init.responseSynthesizer ?? new ResponseSynthesizer();
    this.queryEngines = init.queryEngineTools.reduce<
      Record<string, BaseQueryEngine>
    >((acc, tool) => {
      acc[tool.metadata.name] = tool.queryEngine;
      return acc;
    }, {});
    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
  }

  /**
   * Creates an engine, filling in whatever the caller leaves out: an
   * `LLMQuestionGenerator`, and a `ResponseSynthesizer` that uses
   * `CompactAndRefine` with the given (or a default) service context.
   */
  static fromDefaults(init: {
    queryEngineTools: QueryEngineTool[];
    questionGen?: BaseQuestionGenerator;
    responseSynthesizer?: ResponseSynthesizer;
    serviceContext?: ServiceContext;
  }) {
    const serviceContext =
      init.serviceContext ?? serviceContextFromDefaults({});

    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
    const responseSynthesizer =
      init.responseSynthesizer ??
      new ResponseSynthesizer({
        responseBuilder: new CompactAndRefine(serviceContext),
        serviceContext,
      });

    return new SubQuestionQueryEngine({
      questionGen,
      responseSynthesizer,
      queryEngineTools: init.queryEngineTools,
    });
  }

  /**
   * Generates sub-questions for `query`, answers them concurrently, and
   * synthesizes the final response from the sub-answers that succeeded.
   */
  async aquery(query: string): Promise<Response> {
    const subQuestions = await this.questionGen.agenerate(
      this.metadatas,
      query
    );

    // groups final retrieval+synthesis operation
    const parentEvent: Event = {
      id: uuidv4(),
      type: "wrapper",
      tags: ["final"],
    };

    // groups all sub-queries
    const subQueryParentEvent: Event = {
      id: uuidv4(),
      parentId: parentEvent.id,
      type: "wrapper",
      tags: ["intermediate"],
    };

    const subQNodes = await Promise.all(
      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
    );

    // Drop the sub-questions that could not be answered.
    const nodes = subQNodes.filter(
      (node): node is NodeWithScore => node !== null
    );
    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
  }

  /**
   * Answers one sub-question with the engine registered for its tool.
   *
   * @returns A zero-score node holding the question and its answer, or `null`
   *   when the tool name is unknown or the engine fails. Failures are
   *   deliberately best-effort so one bad sub-question does not abort the
   *   whole query.
   */
  private async aquerySubQ(
    subQ: SubQuestion,
    parentEvent?: Event
  ): Promise<NodeWithScore | null> {
    try {
      const question = subQ.subQuestion;
      const queryEngine = this.queryEngines[subQ.toolName];
      if (!queryEngine) {
        // The generator named a tool that was never registered.
        return null;
      }

      const response = await queryEngine.aquery(question, parentEvent);
      const responseText = response.response;
      const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
      const node = new TextNode({ text: nodeText });
      return { node, score: 0 };
    } catch (error) {
      return null;
    }
  }
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
|
||||
import {
|
||||
BaseOutputParser,
|
||||
StructuredOutput,
|
||||
SubQuestionOutputParser,
|
||||
} from "./OutputParser";
|
||||
import {
|
||||
SimplePrompt,
|
||||
buildToolsText,
|
||||
defaultSubQuestionPrompt,
|
||||
} from "./Prompt";
|
||||
import { ToolMetadata } from "./Tool";
|
||||
|
||||
/**
 * One generated sub-question (`subQuestion`) together with the name of the
 * tool that should answer it (`toolName`).
 */
export interface SubQuestion {
|
||||
subQuestion: string;
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
/**
 * Contract for components that asynchronously produce sub-questions from a
 * list of tool metadata and a query string.
 */
export interface BaseQuestionGenerator {
|
||||
agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
|
||||
}
|
||||
|
||||
/**
 * Question generator that asks an LLM predictor for sub-questions and parses
 * its reply with an output parser.
 */
export class LLMQuestionGenerator implements BaseQuestionGenerator {
  llmPredictor: BaseLLMPredictor;
  prompt: SimplePrompt;
  outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

  /**
   * @param init Optional overrides. Missing fields default to a
   *   `ChatGPTLLMPredictor`, `defaultSubQuestionPrompt` and a
   *   `SubQuestionOutputParser` respectively.
   */
  constructor(init?: Partial<LLMQuestionGenerator>) {
    const { llmPredictor, prompt, outputParser } = init ?? {};
    this.llmPredictor = llmPredictor ?? new ChatGPTLLMPredictor();
    this.prompt = prompt ?? defaultSubQuestionPrompt;
    this.outputParser = outputParser ?? new SubQuestionOutputParser();
  }

  /**
   * Runs the prompt with the serialised tools and the query, then returns the
   * parsed sub-questions.
   */
  async agenerate(
    tools: ToolMetadata[],
    query: string
  ): Promise<SubQuestion[]> {
    const promptInput = {
      toolsStr: buildToolsText(tools),
      queryStr: query,
    };
    const prediction = await this.llmPredictor.apredict(
      this.prompt,
      promptInput
    );

    const { parsedOutput } = this.outputParser.parse(prediction);
    return parsedOutput;
  }
}
|
||||
@@ -1,10 +1,10 @@
|
||||
import { TextNode } from "./Node";
|
||||
import { BaseNode } from "./Node";
|
||||
|
||||
export class Response {
|
||||
response?: string;
|
||||
sourceNodes: TextNode[];
|
||||
response: string;
|
||||
sourceNodes?: BaseNode[];
|
||||
|
||||
constructor(response?: string, sourceNodes?: TextNode[]) {
|
||||
constructor(response: string, sourceNodes?: BaseNode[]) {
|
||||
this.response = response;
|
||||
this.sourceNodes = sourceNodes || [];
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { ChatGPTLLMPredictor } from "./LLMPredictor";
|
||||
import { NodeWithScore } from "./Node";
|
||||
import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./LLMPredictor";
|
||||
import { MetadataMode, NodeWithScore } from "./Node";
|
||||
import {
|
||||
SimplePrompt,
|
||||
defaultRefinePrompt,
|
||||
@@ -8,28 +8,38 @@ import {
|
||||
import { getBiggestPrompt } from "./PromptHelper";
|
||||
import { Response } from "./Response";
|
||||
import { ServiceContext } from "./ServiceContext";
|
||||
import { Event } from "./callbacks/CallbackManager";
|
||||
|
||||
interface BaseResponseBuilder {
|
||||
agetResponse(query: string, textChunks: string[]): Promise<string>;
|
||||
agetResponse(
|
||||
query: string,
|
||||
textChunks: string[],
|
||||
parentEvent?: Event
|
||||
): Promise<string>;
|
||||
}
|
||||
|
||||
export class SimpleResponseBuilder implements BaseResponseBuilder {
|
||||
llmPredictor: ChatGPTLLMPredictor;
|
||||
llmPredictor: BaseLLMPredictor;
|
||||
textQATemplate: SimplePrompt;
|
||||
|
||||
constructor() {
|
||||
this.llmPredictor = new ChatGPTLLMPredictor();
|
||||
constructor(serviceContext?: ServiceContext) {
|
||||
this.llmPredictor =
|
||||
serviceContext?.llmPredictor ?? new ChatGPTLLMPredictor();
|
||||
this.textQATemplate = defaultTextQaPrompt;
|
||||
}
|
||||
|
||||
async agetResponse(query: string, textChunks: string[]): Promise<string> {
|
||||
async agetResponse(
|
||||
query: string,
|
||||
textChunks: string[],
|
||||
parentEvent?: Event
|
||||
): Promise<string> {
|
||||
const input = {
|
||||
query,
|
||||
context: textChunks.join("\n\n"),
|
||||
};
|
||||
|
||||
const prompt = this.textQATemplate(input);
|
||||
return this.llmPredictor.apredict(prompt, {});
|
||||
return this.llmPredictor.apredict(prompt, {}, parentEvent);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,20 +188,42 @@ export class TreeSummarize implements BaseResponseBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
export function getResponseBuilder(): BaseResponseBuilder {
|
||||
return new SimpleResponseBuilder();
|
||||
export function getResponseBuilder(
|
||||
serviceContext?: ServiceContext
|
||||
): SimpleResponseBuilder {
|
||||
return new SimpleResponseBuilder(serviceContext);
|
||||
}
|
||||
|
||||
// TODO replace with Logan's new response_sythesizers/factory.py
|
||||
export class ResponseSynthesizer {
|
||||
responseBuilder: BaseResponseBuilder;
|
||||
serviceContext?: ServiceContext;
|
||||
|
||||
constructor() {
|
||||
this.responseBuilder = getResponseBuilder();
|
||||
constructor({
|
||||
responseBuilder,
|
||||
serviceContext,
|
||||
}: {
|
||||
responseBuilder?: BaseResponseBuilder;
|
||||
serviceContext?: ServiceContext;
|
||||
} = {}) {
|
||||
this.serviceContext = serviceContext;
|
||||
this.responseBuilder =
|
||||
responseBuilder ?? getResponseBuilder(this.serviceContext);
|
||||
}
|
||||
|
||||
async asynthesize(query: string, nodes: NodeWithScore[]) {
|
||||
let textChunks: string[] = nodes.map((node) => node.node.text);
|
||||
const response = await this.responseBuilder.agetResponse(query, textChunks);
|
||||
async asynthesize(
|
||||
query: string,
|
||||
nodes: NodeWithScore[],
|
||||
parentEvent?: Event
|
||||
) {
|
||||
let textChunks: string[] = nodes.map((node) =>
|
||||
node.node.getContent(MetadataMode.NONE)
|
||||
);
|
||||
const response = await this.responseBuilder.agetResponse(
|
||||
query,
|
||||
textChunks,
|
||||
parentEvent
|
||||
);
|
||||
return new Response(
|
||||
response,
|
||||
nodes.map((node) => node.node)
|
||||
|
||||
@@ -1,43 +1,67 @@
|
||||
import { VectorStoreIndex } from "./BaseIndex";
|
||||
import { BaseEmbedding, getTopKEmbeddings } from "./Embedding";
|
||||
import { globalsHelper } from "./GlobalsHelper";
|
||||
import { NodeWithScore } from "./Node";
|
||||
import { ServiceContext } from "./ServiceContext";
|
||||
import { Event } from "./callbacks/CallbackManager";
|
||||
import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
|
||||
import {
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryMode,
|
||||
} from "./storage/vectorStore/types";
|
||||
|
||||
export interface BaseRetriever {
|
||||
aretrieve(query: string): Promise<any>;
|
||||
aretrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
|
||||
getServiceContext(): ServiceContext;
|
||||
}
|
||||
|
||||
export class VectorIndexRetriever implements BaseRetriever {
|
||||
index: VectorStoreIndex;
|
||||
similarityTopK = DEFAULT_SIMILARITY_TOP_K;
|
||||
embeddingService: BaseEmbedding;
|
||||
private serviceContext: ServiceContext;
|
||||
|
||||
constructor(index: VectorStoreIndex, embeddingService: BaseEmbedding) {
|
||||
constructor(index: VectorStoreIndex) {
|
||||
this.index = index;
|
||||
this.embeddingService = embeddingService;
|
||||
this.serviceContext = this.index.serviceContext;
|
||||
}
|
||||
|
||||
async aretrieve(query: string): Promise<NodeWithScore[]> {
|
||||
const queryEmbedding = await this.embeddingService.aGetQueryEmbedding(
|
||||
query
|
||||
);
|
||||
const [similarities, ids] = getTopKEmbeddings(
|
||||
queryEmbedding,
|
||||
this.index.nodes.map((node) => node.getEmbedding()),
|
||||
undefined,
|
||||
this.index.nodes.map((node) => node.id_)
|
||||
);
|
||||
async aretrieve(
|
||||
query: string,
|
||||
parentEvent?: Event
|
||||
): Promise<NodeWithScore[]> {
|
||||
const queryEmbedding =
|
||||
await this.serviceContext.embedModel.aGetQueryEmbedding(query);
|
||||
|
||||
const q: VectorStoreQuery = {
|
||||
queryEmbedding: queryEmbedding,
|
||||
mode: VectorStoreQueryMode.DEFAULT,
|
||||
similarityTopK: this.similarityTopK,
|
||||
};
|
||||
const result = this.index.vectorStore.query(q);
|
||||
|
||||
let nodesWithScores: NodeWithScore[] = [];
|
||||
|
||||
for (let i = 0; i < ids.length; i++) {
|
||||
const node = this.index.indexStruct.nodesDict[ids[i]];
|
||||
for (let i = 0; i < result.ids.length; i++) {
|
||||
const node = this.index.indexStruct.nodesDict[result.ids[i]];
|
||||
nodesWithScores.push({
|
||||
node: node,
|
||||
score: similarities[i],
|
||||
score: result.similarities[i],
|
||||
});
|
||||
}
|
||||
|
||||
if (this.serviceContext.callbackManager.onRetrieve) {
|
||||
this.serviceContext.callbackManager.onRetrieve({
|
||||
query,
|
||||
nodes: nodesWithScores,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
return nodesWithScores;
|
||||
}
|
||||
|
||||
getServiceContext(): ServiceContext {
|
||||
return this.serviceContext;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,35 +1,46 @@
|
||||
import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
|
||||
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
|
||||
import { BaseLanguageModel } from "./LanguageModel";
|
||||
import { ChatOpenAI } from "./LanguageModel";
|
||||
import { NodeParser, SimpleNodeParser } from "./NodeParser";
|
||||
import { PromptHelper } from "./PromptHelper";
|
||||
import { CallbackManager } from "./callbacks/CallbackManager";
|
||||
|
||||
export interface ServiceContext {
|
||||
llmPredictor: BaseLLMPredictor;
|
||||
promptHelper: PromptHelper;
|
||||
embedModel: BaseEmbedding;
|
||||
nodeParser: NodeParser;
|
||||
callbackManager: CallbackManager;
|
||||
// llamaLogger: any;
|
||||
// callbackManager: any;
|
||||
}
|
||||
|
||||
export interface ServiceContextOptions {
|
||||
llmPredictor?: BaseLLMPredictor;
|
||||
llm?: BaseLanguageModel;
|
||||
llm?: ChatOpenAI;
|
||||
promptHelper?: PromptHelper;
|
||||
embedModel?: BaseEmbedding;
|
||||
nodeParser?: NodeParser;
|
||||
callbackManager?: CallbackManager;
|
||||
// NodeParser arguments
|
||||
chunkSize?: number;
|
||||
chunkOverlap: number;
|
||||
chunkOverlap?: number;
|
||||
}
|
||||
|
||||
export function serviceContextFromDefaults(options: ServiceContextOptions) {
|
||||
export function serviceContextFromDefaults(options?: ServiceContextOptions) {
|
||||
const callbackManager = options?.callbackManager ?? new CallbackManager();
|
||||
const serviceContext: ServiceContext = {
|
||||
llmPredictor: options.llmPredictor ?? new ChatGPTLLMPredictor(),
|
||||
embedModel: options.embedModel ?? new OpenAIEmbedding(),
|
||||
nodeParser: options.nodeParser ?? new SimpleNodeParser(),
|
||||
promptHelper: options.promptHelper ?? new PromptHelper(),
|
||||
llmPredictor:
|
||||
options?.llmPredictor ??
|
||||
new ChatGPTLLMPredictor({ callbackManager, languageModel: options?.llm }),
|
||||
embedModel: options?.embedModel ?? new OpenAIEmbedding(),
|
||||
nodeParser:
|
||||
options?.nodeParser ??
|
||||
new SimpleNodeParser({
|
||||
chunkSize: options?.chunkSize,
|
||||
chunkOverlap: options?.chunkOverlap,
|
||||
}),
|
||||
promptHelper: options?.promptHelper ?? new PromptHelper(),
|
||||
callbackManager,
|
||||
};
|
||||
|
||||
return serviceContext;
|
||||
@@ -52,5 +63,8 @@ export function serviceContextFromServiceContext(
|
||||
if (options.nodeParser) {
|
||||
newServiceContext.nodeParser = options.nodeParser;
|
||||
}
|
||||
if (options.callbackManager) {
|
||||
newServiceContext.callbackManager = options.callbackManager;
|
||||
}
|
||||
return newServiceContext;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { BaseQueryEngine } from "./QueryEngine";
|
||||
|
||||
/** Descriptive metadata of a tool: its `name` and a `description` string. */
export interface ToolMetadata {
|
||||
description: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
/** Base shape shared by tools: each one carries its `ToolMetadata`. */
export interface BaseTool {
|
||||
metadata: ToolMetadata;
|
||||
}
|
||||
|
||||
/** A tool that additionally carries a `BaseQueryEngine` in `queryEngine`. */
export interface QueryEngineTool extends BaseTool {
|
||||
queryEngine: BaseQueryEngine;
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
import { ChatCompletionResponseMessageRoleEnum } from "openai";
|
||||
import { NodeWithScore } from "../Node";
|
||||
|
||||
/*
|
||||
An event is a wrapper that groups related operations.
|
||||
For example, during retrieve and synthesize,
|
||||
a parent event wraps both operations, and each operation has its own
|
||||
event. In this case, both sub-events will share a parentId.
|
||||
*/
|
||||
|
||||
/** Tags that can be attached to an event: "intermediate" or "final". */
export type EventTag = "intermediate" | "final";
|
||||
/** Kinds of event: "retrieve", "llmPredict" or "wrapper". */
export type EventType = "retrieve" | "llmPredict" | "wrapper";
|
||||
/**
 * An event record: an `id`, its `type`, optional `tags`, and an optional
 * `parentId` that refers to the id of a parent event.
 */
export interface Event {
|
||||
id: string;
|
||||
type: EventType;
|
||||
tags?: EventTag[];
|
||||
parentId?: string;
|
||||
}
|
||||
|
||||
/** Fields common to every callback payload: the `event` it belongs to. */
interface BaseCallbackResponse {
|
||||
event: Event;
|
||||
}
|
||||
|
||||
/**
 * Shape of one streamed token object. Each entry in `choices` carries a
 * `delta` with optional `content` and `role`, plus a `finish_reason`.
 * NOTE(review): presumably mirrors an OpenAI chat-completion stream chunk —
 * confirm against the API version in use.
 */
export interface StreamToken {
|
||||
id: string;
|
||||
object: string;
|
||||
created: number;
|
||||
model: string;
|
||||
choices: {
|
||||
index: number;
|
||||
delta: {
|
||||
content?: string;
|
||||
role?: ChatCompletionResponseMessageRoleEnum;
|
||||
};
|
||||
finish_reason: string | null;
|
||||
}[];
|
||||
}
|
||||
|
||||
/**
 * Payload of the LLM stream callback: an `index`, an optional `isDone` flag
 * and the optional streamed `token`, plus the base `event`.
 */
export interface StreamCallbackResponse extends BaseCallbackResponse {
|
||||
index: number;
|
||||
isDone?: boolean;
|
||||
token?: StreamToken;
|
||||
}
|
||||
|
||||
/**
 * Payload of the retrieval callback: the `query` string and the scored
 * `nodes`, plus the base `event`.
 */
export interface RetrievalCallbackResponse extends BaseCallbackResponse {
|
||||
query: string;
|
||||
nodes: NodeWithScore[];
|
||||
}
|
||||
|
||||
/**
 * Optional callbacks a callback manager can hold. Each one may return
 * either nothing or a promise.
 */
interface CallbackManagerMethods {
|
||||
/*
|
||||
onLLMStream is called when a token is streamed from the LLM. Defining this
|
||||
callback auto sets the stream = True flag on the openAI createChatCompletion request.
|
||||
*/
|
||||
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
|
||||
/*
|
||||
onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
|
||||
This callback allows you to handle the retrieved nodes even if the synthesizer
|
||||
is still running.
|
||||
*/
|
||||
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
|
||||
}
|
||||
|
||||
/**
 * Holder for the optional `onLLMStream` and `onRetrieve` callbacks.
 */
export class CallbackManager implements CallbackManagerMethods {
  onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
  onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;

  /**
   * @param handlers Callbacks to register; any that are omitted stay undefined.
   */
  constructor(handlers?: CallbackManagerMethods) {
    const { onLLMStream, onRetrieve } = handlers ?? {};
    this.onLLMStream = onLLMStream;
    this.onRetrieve = onRetrieve;
  }
}
|
||||
@@ -0,0 +1,64 @@
|
||||
import { globalsHelper } from "../../GlobalsHelper";
|
||||
import { StreamCallbackResponse, Event } from "../CallbackManager";
|
||||
import { StreamToken } from "../CallbackManager";
|
||||
|
||||
/**
 * Consumes a streamed completion response, forwarding every token to
 * `onLLMStream` and returning the concatenated text.
 *
 * @param response Response whose `data` is passed to the stream reader.
 * @param onLLMStream Callback invoked once per forwarded token and once more
 *   with `isDone: true` after the stream ends.
 * @param parentEvent Optional parent for the "llmPredict" event created here.
 * @returns All streamed `content` pieces joined in arrival order.
 */
export async function aHandleOpenAIStream({
  response,
  onLLMStream,
  parentEvent,
}: {
  response: any;
  onLLMStream: (data: StreamCallbackResponse) => void;
  parentEvent?: Event;
}): Promise<string> {
  const event = globalsHelper.createEvent({
    parentEvent,
    type: "llmPredict",
  });

  let cumulativeText = "";
  let index = 0;

  for await (const rawMessage of __astreamCompletion(response.data as any)) {
    const token: StreamToken = JSON.parse(rawMessage);
    const { content = "", role = "assistant" } = token?.choices[0]?.delta ?? {};

    // ignore the first token
    const isLeadingEmptyAssistantToken =
      index === 0 && role === "assistant" && !content;
    if (isLeadingEmptyAssistantToken) {
      continue;
    }

    cumulativeText += content;
    onLLMStream?.({ event, index, token });
    index++;
  }

  // Signal completion; `index` is the number of tokens forwarded.
  onLLMStream?.({ event, index, isDone: true });
  return cumulativeText;
}
|
||||
|
||||
/*
|
||||
sources:
|
||||
- https://github.com/openai/openai-node/issues/18#issuecomment-1372047643
|
||||
- https://github.com/openai/openai-node/issues/18#issuecomment-1595805163
|
||||
*/
|
||||
/**
 * Turns the raw response chunks into a stream of message payload strings by
 * splitting them into lines and stripping each line's prefix.
 */
async function* __astreamCompletion(data: string[]) {
  const lines = __achunksToLines(data);
  for await (const text of __alinesToText(lines)) {
    yield text;
  }
}
|
||||
|
||||
/**
 * Strips the leading "data: " prefix from every line of an async line stream.
 *
 * @param linesAsync Async iterable of lines that each start with "data: ".
 * @yields The remainder of each line after the prefix.
 */
async function* __alinesToText(linesAsync: string | void | any) {
  // Same literal the line splitter tests with `startsWith`.
  const prefix = "data: ";
  for await (const line of linesAsync) {
    yield line.substring(prefix.length);
  }
}
|
||||
|
||||
/**
 * Re-assembles an async stream of raw chunks into complete "data: " lines.
 *
 * Chunks may split lines, and multi-byte UTF-8 characters, at arbitrary
 * positions. Binary chunks are therefore decoded with a streaming decoder
 * and buffered until a newline arrives. Lines that do not start with
 * "data: " are dropped. The generator finishes as soon as the
 * "data: [DONE]" sentinel line is seen.
 *
 * @param chunksAsync Async iterable of string or binary chunks.
 * @yields Each complete line starting with "data: ", trailing whitespace removed.
 */
async function* __achunksToLines(chunksAsync: string[]) {
  // A streaming decoder holds back an incomplete multi-byte sequence until
  // the next chunk completes it, instead of emitting replacement characters.
  const decoder = new TextDecoder("utf-8");
  let previous = "";
  for await (const chunk of chunksAsync) {
    const piece: unknown = chunk;
    previous +=
      typeof piece === "string"
        ? piece
        : decoder.decode(piece as Uint8Array, { stream: true });

    let eolIndex;
    while ((eolIndex = previous.indexOf("\n")) >= 0) {
      const line = previous.slice(0, eolIndex + 1).trimEnd();
      // The sentinel ends the whole stream, not just the current chunk.
      if (line === "data: [DONE]") return;
      if (line.startsWith("data: ")) yield line;
      previous = previous.slice(eolIndex + 1);
    }
  }
}
|
||||
@@ -1,35 +0,0 @@
|
||||
import _ from "lodash";
|
||||
import { VectorStoreQueryMode } from "./storage/vectorStore/types";
|
||||
|
||||
/**
 * Unimplemented stub: always throws `Error("Not implemented")`.
 * The declared return type is a pair of number arrays.
 */
export function getTopKEmbeddings(
|
||||
queryEmbedding: number[],
|
||||
embeddings: number[][],
|
||||
similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
|
||||
similarityTopK?: number,
|
||||
embeddingIds?: number[],
|
||||
similarityCutoff?: number
|
||||
): [number[], number[]] {
|
||||
throw new Error("Not implemented");
|
||||
}
|
||||
|
||||
/**
 * Placeholder for learner-based top-k embedding selection. Not implemented:
 * every call throws.
 *
 * `queryMode` defaults to `VectorStoreQueryMode.SVM`.
 *
 * @throws Error always, with the message "Not implemented".
 */
export function getTopKEmbeddingsLearner(
  queryEmbedding: number[],
  embeddings: number[][],
  similarityTopK?: number,
  embeddingIds?: number[],
  queryMode: VectorStoreQueryMode = VectorStoreQueryMode.SVM
): [number[], number[]] {
  throw new Error("Not implemented");
}
|
||||
|
||||
/**
 * Placeholder for MMR-based top-k embedding selection. Not implemented: every
 * call throws.
 *
 * NOTE(review): "MMR" presumably stands for maximal marginal relevance —
 * confirm once an implementation exists.
 *
 * @throws Error always, with the message "Not implemented".
 */
export function getTopKMMREmbeddings(
  queryEmbedding: number[],
  embeddings: number[][],
  similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
  similarityTopK?: number,
  embeddingIds?: number[],
  similarityCutoff?: number,
  mmrThreshold?: number
): [number[], number[]] {
  throw new Error("Not implemented");
}
|
||||
@@ -0,0 +1,166 @@
|
||||
import { BaseNode, Document } from "../../Node";
|
||||
import { BaseIndex, BaseIndexInit, IndexList } from "../../BaseIndex";
|
||||
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
|
||||
import {
|
||||
StorageContext,
|
||||
storageContextFromDefaults,
|
||||
} from "../../storage/StorageContext";
|
||||
import { BaseRetriever } from "../../Retriever";
|
||||
import { ListIndexRetriever } from "./ListIndexRetriever";
|
||||
import {
|
||||
ServiceContext,
|
||||
serviceContextFromDefaults,
|
||||
} from "../../ServiceContext";
|
||||
import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
|
||||
import _ from "lodash";
|
||||
|
||||
/**
 * Retrieval strategies that can be requested from a ListIndex.
 * The embedding mode is present only as a commented-out entry.
 */
export enum ListRetrieverMode {
  DEFAULT = "default",
  // EMBEDDING = "embedding",
  LLM = "llm",
}
|
||||
|
||||
/**
 * Options accepted by `ListIndex.init`.
 */
export interface ListIndexOptions {
  /** Nodes to build a new index struct from; cannot be combined with `indexStruct`. */
  nodes?: BaseNode[];
  /** Existing index struct to reuse; cannot be combined with `nodes`. */
  indexStruct?: IndexList;
  /** Service context to use; a default one is created when omitted. */
  serviceContext?: ServiceContext;
  /** Storage context to use; a default one is created when omitted. */
  storageContext?: StorageContext;
}
|
||||
|
||||
/**
 * An index whose struct is an `IndexList`, i.e. a flat list of node ids.
 */
export class ListIndex extends BaseIndex<IndexList> {
  constructor(init: BaseIndexInit<IndexList>) {
    super(init);
  }

  /**
   * Creates a ListIndex from either a set of nodes or an existing index
   * struct.
   *
   * Missing storage/service contexts are replaced by defaults.
   *
   * @param options Exactly one of `nodes` and `indexStruct` must be provided.
   * @throws Error when both or neither of `nodes` / `indexStruct` are given.
   */
  static async init(options: ListIndexOptions): Promise<ListIndex> {
    const storageContext =
      options.storageContext ?? (await storageContextFromDefaults({}));
    const serviceContext =
      options.serviceContext ?? serviceContextFromDefaults({});
    const { docStore, indexStore } = storageContext;

    let indexStruct: IndexList;
    if (options.indexStruct) {
      if (options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex with both nodes and indexStruct"
        );
      }
      indexStruct = options.indexStruct;
    } else {
      if (!options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex without nodes or indexStruct"
        );
      }
      indexStruct = ListIndex._buildIndexFromNodes(options.nodes, docStore);
    }

    return new ListIndex({
      storageContext,
      serviceContext,
      docStore,
      indexStore,
      indexStruct,
    });
  }

  /**
   * Builds a ListIndex from raw documents.
   *
   * The documents are added to the doc store together with their hashes, split
   * into nodes by the service context's node parser, and the nodes are then
   * indexed through `ListIndex.init`.
   *
   * @param documents Documents to index.
   * @param storageContext Storage to use; defaults are created when omitted.
   * @param serviceContext Services to use; defaults are created when omitted.
   */
  static async fromDocuments(
    documents: Document[],
    storageContext?: StorageContext,
    serviceContext?: ServiceContext
  ): Promise<ListIndex> {
    storageContext = storageContext ?? (await storageContextFromDefaults({}));
    serviceContext = serviceContext ?? serviceContextFromDefaults({});
    const docStore = storageContext.docStore;

    // NOTE(review): these doc store calls are not awaited; if the store
    // implementation is asynchronous they may still be in flight when the
    // index is built — confirm against BaseDocumentStore.
    docStore.addDocuments(documents, true);
    for (const doc of documents) {
      docStore.setDocumentHash(doc.id_, doc.hash);
    }

    const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
    return ListIndex.init({
      nodes,
      storageContext,
      serviceContext,
    });
  }

  /**
   * Returns a retriever for the requested mode.
   *
   * @throws Error for `ListRetrieverMode.LLM` (not implemented here) and for
   *   any unknown mode.
   */
  asRetriever(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseRetriever {
    switch (mode) {
      case ListRetrieverMode.DEFAULT:
        return new ListIndexRetriever(this);
      case ListRetrieverMode.LLM:
        throw new Error(`Support for LLM retriever mode is not implemented`);
      default:
        throw new Error(`Unknown retriever mode: ${mode}`);
    }
  }

  /**
   * Returns a query engine backed by the retriever for the requested mode.
   *
   * The mode is forwarded to `asRetriever`, so it throws for the same modes.
   */
  asQueryEngine(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseQueryEngine {
    // Previously the mode was silently dropped and the default retriever used.
    return new RetrieverQueryEngine(this.asRetriever(mode));
  }

  /**
   * Adds the nodes to the doc store and appends each of them to the index
   * struct.
   *
   * @param nodes Nodes to index.
   * @param docStore Store that receives the nodes.
   * @param indexStruct Struct to extend; a new `IndexList` is created when
   *   omitted.
   * @returns The (new or extended) index struct.
   */
  static _buildIndexFromNodes(
    nodes: BaseNode[],
    docStore: BaseDocumentStore,
    indexStruct?: IndexList
  ): IndexList {
    indexStruct = indexStruct || new IndexList();

    docStore.addDocuments(nodes, true);
    for (const node of nodes) {
      indexStruct.addNode(node);
    }

    return indexStruct;
  }

  /** Appends the given nodes to the index struct. */
  protected _insert(nodes: BaseNode[]): void {
    for (const node of nodes) {
      this.indexStruct.addNode(node);
    }
  }

  /** Removes the given node id from the index struct's node list. */
  protected _deleteNode(nodeId: string): void {
    this.indexStruct.nodes = this.indexStruct.nodes.filter(
      (existingNodeId: string) => existingNodeId !== nodeId
    );
  }

  /**
   * Collects the ref doc info of every source node referenced by the nodes of
   * this index, keyed by the source node's id.
   *
   * Nodes without a source node, and source nodes for which the doc store has
   * no ref doc info, are skipped.
   */
  async getRefDocInfo(): Promise<Record<string, RefDocInfo>> {
    const nodeDocIds = this.indexStruct.nodes;
    const nodes = await this.docStore.getNodes(nodeDocIds);

    const refDocInfoMap: Record<string, RefDocInfo> = {};

    for (const node of nodes) {
      const refNode = node.sourceNode;
      if (_.isNil(refNode)) {
        continue;
      }

      const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId);
      if (_.isNil(refDocInfo)) {
        continue;
      }

      refDocInfoMap[refNode.nodeId] = refDocInfo;
    }

    return refDocInfoMap;
  }
}
|
||||
|
||||
// Legacy: the older GPTListIndex name, kept as a type alias of ListIndex
export type GPTListIndex = ListIndex;
|
||||
@@ -0,0 +1,137 @@
|
||||
import { BaseRetriever } from "../../Retriever";
|
||||
import { NodeWithScore } from "../../Node";
|
||||
import { ListIndex } from "./ListIndex";
|
||||
import { ServiceContext } from "../../ServiceContext";
|
||||
import {
|
||||
NodeFormatterFunction,
|
||||
ChoiceSelectParserFunction,
|
||||
defaultFormatNodeBatchFn,
|
||||
defaultParseChoiceSelectAnswerFn,
|
||||
} from "./utils";
|
||||
import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
|
||||
import _ from "lodash";
|
||||
import { globalsHelper } from "../../GlobalsHelper";
|
||||
import { Event } from "../../callbacks/CallbackManager";
|
||||
|
||||
/**
 * Simple retriever for ListIndex that returns all nodes
 */
export class ListIndexRetriever implements BaseRetriever {
  index: ListIndex;

  constructor(index: ListIndex) {
    this.index = index;
  }

  /**
   * Loads every node listed in the index struct from the doc store and returns
   * each of them with a score of 1. The query does not influence the result.
   *
   * When the callback manager has an `onRetrieve` handler, it is invoked with
   * the query, the result, and a new "retrieve" event.
   *
   * @param query Query text; only forwarded to the callback.
   * @param parentEvent Optional parent of the "retrieve" event.
   */
  async aretrieve(
    query: string,
    parentEvent?: Event
  ): Promise<NodeWithScore[]> {
    const storedNodes = await this.index.docStore.getNodes(
      this.index.indexStruct.nodes
    );

    const scored: NodeWithScore[] = [];
    for (const storedNode of storedNodes) {
      scored.push({ node: storedNode, score: 1 });
    }

    const { callbackManager } = this.index.serviceContext;
    if (callbackManager.onRetrieve) {
      const event = globalsHelper.createEvent({
        parentEvent,
        type: "retrieve",
      });
      callbackManager.onRetrieve({ query, nodes: scored, event });
    }

    return scored;
  }

  /** Returns the service context of the underlying index. */
  getServiceContext(): ServiceContext {
    return this.index.serviceContext;
  }
}
|
||||
|
||||
/**
 * LLM retriever for ListIndex.
 *
 * The index's nodes are processed in batches. Each batch is formatted into a
 * numbered list of documents, the LLM is asked (via the choice-select prompt)
 * which of them are relevant to the query, and the nodes it picks are returned
 * with the relevance scores parsed from its answer.
 */
export class ListIndexLLMRetriever implements BaseRetriever {
  index: ListIndex;
  choiceSelectPrompt: SimplePrompt;
  choiceBatchSize: number;
  formatNodeBatchFn: NodeFormatterFunction;
  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
  serviceContext: ServiceContext;

  /**
   * @param index Index to retrieve from.
   * @param choiceSelectPrompt Prompt used to ask the LLM for its choices;
   *   defaults to `defaultChoiceSelectPrompt`.
   * @param choiceBatchSize Number of nodes shown to the LLM per call.
   * @param formatNodeBatchFn Formats a batch of nodes into the prompt context;
   *   defaults to `defaultFormatNodeBatchFn`.
   * @param parseChoiceSelectAnswerFn Parses the LLM answer into a map from
   *   document number to relevance; defaults to
   *   `defaultParseChoiceSelectAnswerFn`.
   * @param serviceContext Defaults to the index's service context.
   * @throws Error when `choiceBatchSize` is not greater than zero.
   */
  constructor(
    index: ListIndex,
    choiceSelectPrompt?: SimplePrompt,
    choiceBatchSize: number = 10,
    formatNodeBatchFn?: NodeFormatterFunction,
    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
    serviceContext?: ServiceContext
  ) {
    // A non-positive batch size would make the batching loop in aretrieve
    // never advance.
    if (!(choiceBatchSize > 0)) {
      throw new Error(
        `choiceBatchSize must be greater than 0, got ${choiceBatchSize}`
      );
    }

    this.index = index;
    this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt;
    this.choiceBatchSize = choiceBatchSize;
    this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn;
    this.parseChoiceSelectAnswerFn =
      parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn;
    this.serviceContext = serviceContext || index.serviceContext;
  }

  /**
   * Asks the LLM, batch by batch, which nodes are relevant to the query.
   *
   * When the callback manager has an `onRetrieve` handler, it is invoked with
   * the query, the result, and a new "retrieve" event.
   *
   * @param query Query passed to the LLM together with each batch.
   * @param parentEvent Optional parent of the "retrieve" event.
   * @returns The chosen nodes, in index order, with their parsed relevance
   *   (1 when the parser reported the document without a score).
   */
  async aretrieve(
    query: string,
    parentEvent?: Event
  ): Promise<NodeWithScore[]> {
    const nodeIds = this.index.indexStruct.nodes;
    const results: NodeWithScore[] = [];

    for (let idx = 0; idx < nodeIds.length; idx += this.choiceBatchSize) {
      const nodeIdsBatch = nodeIds.slice(idx, idx + this.choiceBatchSize);
      const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);

      const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
      const input = { context: fmtBatchStr, query: query };
      const rawResponse = await this.serviceContext.llmPredictor.apredict(
        this.choiceSelectPrompt,
        input
      );

      // parseResult is a map from doc number to relevance score
      const parseResult = this.parseChoiceSelectAnswerFn(
        rawResponse,
        nodesBatch.length
      );

      // Document numbers are 1-based positions within the formatted batch, so
      // the node at batch position i is document i + 1. The score must be
      // looked up under that same number, not under the node's position among
      // the chosen nodes.
      nodesBatch.forEach((node, batchPosition) => {
        const docNumber = batchPosition + 1;
        if (docNumber in parseResult) {
          results.push({ node, score: parseResult[docNumber] ?? 1 });
        }
      });
    }

    if (this.serviceContext.callbackManager.onRetrieve) {
      this.serviceContext.callbackManager.onRetrieve({
        query,
        nodes: results,
        event: globalsHelper.createEvent({
          parentEvent,
          type: "retrieve",
        }),
      });
    }

    return results;
  }

  /** Returns the service context this retriever was configured with. */
  getServiceContext(): ServiceContext {
    return this.serviceContext;
  }
}
|
||||
@@ -0,0 +1,5 @@
|
||||
export { ListIndex, ListRetrieverMode } from "./ListIndex";
|
||||
export {
|
||||
ListIndexRetriever,
|
||||
ListIndexLLMRetriever,
|
||||
} from "./ListIndexRetriever";
|
||||
@@ -0,0 +1,73 @@
|
||||
import { BaseNode, MetadataMode } from "../../Node";
|
||||
import _ from "lodash";
|
||||
|
||||
/** Turns a batch of nodes into the context text shown to the LLM. */
export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;

/**
 * Default batch formatter: renders each node as a `Document N:` heading
 * (N starting at 1) followed by the node's content for `MetadataMode.LLM`,
 * and separates the entries with a blank line.
 */
export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
  summaryNodes: BaseNode[]
): string => {
  const sections: string[] = [];

  summaryNodes.forEach((node, position) => {
    const section = `
Document ${position + 1}:
${node.getContent(MetadataMode.LLM)}
`;
    sections.push(section.trim());
  });

  return sections.join("\n\n");
};
|
||||
|
||||
// map from document number to its relevance score
export type ChoiceSelectParseResult = { [docNumber: number]: number };

/**
 * Parses an LLM choice-select answer.
 *
 * `numChoices` is the number of documents that were offered; `raiseErr`
 * selects between throwing on malformed input and skipping it.
 */
export type ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr?: boolean
) => ChoiceSelectParseResult;

/**
 * Default parser for choice-select answers.
 *
 * Each line of the answer is expected to consist of two comma-separated
 * `label: value` fields: the document number (an integer between 1 and
 * `numChoices`) and its relevance (a number), for example
 * `Doc: 2, Relevance: 7`.
 *
 * @param answer Raw answer text, one choice per line.
 * @param numChoices Highest valid document number.
 * @param raiseErr When true, the first malformed line or out-of-range document
 *   number throws; when false (default) such lines are skipped.
 * @returns Map from document number to relevance for every valid line. When a
 *   document number occurs more than once, the last occurrence wins.
 * @throws Error when `raiseErr` is true and a line is malformed or its
 *   document number is out of range.
 */
export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr: boolean = false
): ChoiceSelectParseResult => {
  const parseResult: ChoiceSelectParseResult = {};

  for (const line of answer.split("\n")) {
    // split the line into the answer number and relevance score portions
    const lineTokens = line.split(",");
    const docNumText =
      lineTokens.length === 2 ? lineTokens[0].split(":")[1]?.trim() : undefined;
    const relevanceText =
      lineTokens.length === 2 ? lineTokens[1].split(":")[1]?.trim() : undefined;

    // parse the answer number and relevance score
    const docNum = parseInt(docNumText ?? "", 10);
    const answerRelevance = parseFloat(relevanceText ?? "");

    if (Number.isNaN(docNum) || Number.isNaN(answerRelevance)) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
        );
      }
      continue;
    }

    if (docNum < 1 || docNum > numChoices) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
        );
      }
      // Out-of-range choices are ignored; previously these were the ONLY
      // entries that got stored, and valid ones were dropped.
      continue;
    }

    parseResult[docNum] = answerRelevance;
  }

  return parseResult;
};
|
||||
@@ -12,9 +12,9 @@ import {
|
||||
} from "./constants";
|
||||
|
||||
export interface StorageContext {
|
||||
docStore?: BaseDocumentStore;
|
||||
indexStore?: BaseIndexStore;
|
||||
vectorStore?: VectorStore;
|
||||
docStore: BaseDocumentStore;
|
||||
indexStore: BaseIndexStore;
|
||||
vectorStore: VectorStore;
|
||||
}
|
||||
|
||||
type BuilderParams = {
|
||||
@@ -32,21 +32,24 @@ export async function storageContextFromDefaults({
|
||||
persistDir,
|
||||
fs,
|
||||
}: BuilderParams): Promise<StorageContext> {
|
||||
persistDir = persistDir || DEFAULT_PERSIST_DIR;
|
||||
|
||||
fs = fs || DEFAULT_FS;
|
||||
|
||||
docStore =
|
||||
docStore ||
|
||||
(await SimpleDocumentStore.fromPersistDir(
|
||||
persistDir,
|
||||
DEFAULT_NAMESPACE,
|
||||
fs
|
||||
));
|
||||
indexStore =
|
||||
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
|
||||
vectorStore =
|
||||
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
|
||||
if (!persistDir) {
|
||||
docStore = docStore || new SimpleDocumentStore();
|
||||
indexStore = indexStore || new SimpleIndexStore();
|
||||
vectorStore = vectorStore || new SimpleVectorStore();
|
||||
} else {
|
||||
fs = fs || DEFAULT_FS;
|
||||
docStore =
|
||||
docStore ||
|
||||
(await SimpleDocumentStore.fromPersistDir(
|
||||
persistDir,
|
||||
DEFAULT_NAMESPACE,
|
||||
fs
|
||||
));
|
||||
indexStore =
|
||||
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
|
||||
vectorStore =
|
||||
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
|
||||
}
|
||||
|
||||
return {
|
||||
docStore,
|
||||
|
||||
@@ -77,7 +77,7 @@ export class KVDocumentStore extends BaseDocumentStore {
|
||||
let json = await this.kvstore.get(docId, this.nodeCollection);
|
||||
if (_.isNil(json)) {
|
||||
if (raiseError) {
|
||||
throw new Error(`doc_id ${docId} not found.`);
|
||||
throw new Error(`docId ${docId} not found.`);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -23,12 +23,10 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode {
|
||||
hash: dataDict.hash,
|
||||
});
|
||||
} else if (docType === ObjectType.TEXT) {
|
||||
const relationships = dataDict.relationships;
|
||||
doc = new TextNode({
|
||||
text: relationships.text,
|
||||
id_: relationships.id_,
|
||||
embedding: relationships.embedding,
|
||||
hash: relationships.hash,
|
||||
text: dataDict.text,
|
||||
id_: dataDict.id_,
|
||||
hash: dataDict.hash,
|
||||
});
|
||||
} else {
|
||||
throw new Error(`Unknown doc type: ${docType}`);
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
getTopKMMREmbeddings,
|
||||
} from "../../Embedding";
|
||||
import { DEFAULT_PERSIST_DIR, DEFAULT_FS } from "../constants";
|
||||
import { TextNode } from "../../Node";
|
||||
import { NodeWithEmbedding } from "../../Node";
|
||||
|
||||
const LEARNER_MODES = new Set<VectorStoreQueryMode>([
|
||||
VectorStoreQueryMode.SVM,
|
||||
@@ -53,18 +53,19 @@ export class SimpleVectorStore implements VectorStore {
|
||||
return this.data.embeddingDict[textId];
|
||||
}
|
||||
|
||||
add(embeddingResults: TextNode[]): string[] {
|
||||
add(embeddingResults: NodeWithEmbedding[]): string[] {
|
||||
for (let result of embeddingResults) {
|
||||
this.data.embeddingDict[result.id_] = result.getEmbedding();
|
||||
this.data.embeddingDict[result.node.id_] = result.embedding;
|
||||
|
||||
if (!result.sourceNode) {
|
||||
if (!result.node.sourceNode) {
|
||||
console.error("Missing source node from TextNode.");
|
||||
continue;
|
||||
}
|
||||
|
||||
this.data.textIdToRefDocId[result.id_] = result.sourceNode?.nodeId;
|
||||
this.data.textIdToRefDocId[result.node.id_] =
|
||||
result.node.sourceNode?.nodeId;
|
||||
}
|
||||
return embeddingResults.map((result) => result.id_);
|
||||
return embeddingResults.map((result) => result.node.id_);
|
||||
}
|
||||
|
||||
delete(refDocId: string): void {
|
||||
|
||||
@@ -1,18 +1,11 @@
|
||||
import { TextNode } from "../../Node";
|
||||
import { BaseNode } from "../../Node";
|
||||
import { GenericFileSystem } from "../FileSystem";
|
||||
|
||||
export interface NodeWithEmbedding {
|
||||
node: TextNode;
|
||||
embedding: number[];
|
||||
|
||||
id(): string;
|
||||
refDocId(): string;
|
||||
}
|
||||
import { NodeWithEmbedding } from "../../Node";
|
||||
|
||||
export interface VectorStoreQueryResult {
|
||||
nodes?: TextNode[];
|
||||
similarities?: number[];
|
||||
ids?: string[];
|
||||
nodes?: BaseNode[];
|
||||
similarities: number[];
|
||||
ids: string[];
|
||||
}
|
||||
|
||||
export enum VectorStoreQueryMode {
|
||||
@@ -68,7 +61,7 @@ export interface VectorStore {
|
||||
storesText: boolean;
|
||||
isEmbeddingQuery?: boolean;
|
||||
client(): any;
|
||||
add(embeddingResults: TextNode[]): string[];
|
||||
add(embeddingResults: NodeWithEmbedding[]): string[];
|
||||
delete(refDocId: string, deleteKwargs?: any): void;
|
||||
query(query: VectorStoreQuery, kwargs?: any): VectorStoreQueryResult;
|
||||
persist(persistPath: string, fs?: GenericFileSystem): void;
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
import { VectorStoreIndex } from "../BaseIndex";
|
||||
import { OpenAIEmbedding } from "../Embedding";
|
||||
import { ChatOpenAI } from "../LanguageModel";
|
||||
import { Document } from "../Node";
|
||||
import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
|
||||
import {
|
||||
CallbackManager,
|
||||
RetrievalCallbackResponse,
|
||||
StreamCallbackResponse,
|
||||
} from "../callbacks/CallbackManager";
|
||||
import { ListIndex } from "../index/list";
|
||||
import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI";
|
||||
|
||||
// While testing, swap the OpenAI module for one whose getOpenAISession
// returns null
jest.mock("../openai", () => ({
  getOpenAISession: jest.fn().mockImplementation(() => null),
}));
|
||||
|
||||
describe("CallbackManager: onLLMStream and onRetrieve", () => {
  let serviceContext: ServiceContext;
  let streamCallbackData: StreamCallbackResponse[] = [];
  let retrieveCallbackData: RetrievalCallbackResponse[] = [];
  let document: Document;

  // Matcher for the event attached to every callback payload.
  const eventOfType = (type: string) => ({
    id: expect.any(String),
    parentId: expect.any(String),
    type,
    tags: ["final"],
  });

  // Matcher for the stream payload carrying the token at the given position.
  const streamedToken = (index: number) => ({
    event: eventOfType("llmPredict"),
    index,
    token: {
      id: "id",
      object: "object",
      created: 1,
      model: "model",
      choices: expect.any(Array),
    },
  });

  // Checks the payloads recorded by both callbacks for a single query.
  const expectRecordedCallbacks = (query: string) => {
    expect(streamCallbackData).toEqual([
      streamedToken(0),
      streamedToken(1),
      {
        event: eventOfType("llmPredict"),
        index: 2,
        isDone: true,
      },
    ]);
    expect(retrieveCallbackData).toEqual([
      {
        query: query,
        nodes: expect.any(Array),
        event: eventOfType("retrieve"),
      },
    ]);
    // both retrieval and streaming should have
    // the same parent event
    expect(streamCallbackData[0].event.parentId).toBe(
      retrieveCallbackData[0].event.parentId
    );
  };

  beforeAll(async () => {
    document = new Document({ text: "Author: My name is Paul Graham" });

    // Record everything the two callbacks receive.
    const callbackManager = new CallbackManager({
      onLLMStream: (data) => {
        streamCallbackData.push(data);
      },
      onRetrieve: (data) => {
        retrieveCallbackData.push(data);
      },
    });

    const languageModel = new ChatOpenAI({
      model: "gpt-3.5-turbo",
      callbackManager,
    });
    mockLlmGeneration({ languageModel, callbackManager });

    const embedModel = new OpenAIEmbedding();
    mockEmbeddingModel(embedModel);

    serviceContext = serviceContextFromDefaults({
      callbackManager,
      llm: languageModel,
      embedModel,
    });
  });

  beforeEach(() => {
    streamCallbackData = [];
    retrieveCallbackData = [];
  });

  afterAll(() => {
    jest.clearAllMocks();
  });

  test("For VectorStoreIndex w/ a SimpleResponseBuilder", async () => {
    const vectorStoreIndex = await VectorStoreIndex.fromDocuments(
      [document],
      undefined,
      serviceContext
    );
    const queryEngine = vectorStoreIndex.asQueryEngine();
    const query = "What is the author's name?";
    const response = await queryEngine.aquery(query);
    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
    expectRecordedCallbacks(query);
  });

  test("For ListIndex w/ a ListIndexRetriever", async () => {
    const listIndex = await ListIndex.fromDocuments(
      [document],
      undefined,
      serviceContext
    );
    const queryEngine = listIndex.asQueryEngine();
    const query = "What is the author's name?";
    const response = await queryEngine.aquery(query);
    expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
    expectRecordedCallbacks(query);
  });
});
|
||||
@@ -0,0 +1,72 @@
|
||||
import { OpenAIEmbedding } from "../../Embedding";
|
||||
import { globalsHelper } from "../../GlobalsHelper";
|
||||
import { BaseMessage, ChatOpenAI } from "../../LanguageModel";
|
||||
import { CallbackManager, Event } from "../../callbacks/CallbackManager";
|
||||
|
||||
/**
 * Replaces `languageModel.agenerate` with a jest spy that never calls the real
 * model.
 *
 * The mock always produces the text "MOCK_TOKEN_1-MOCK_TOKEN_2". When the
 * callback manager has an `onLLMStream` handler, the text is split on "-" and
 * reported as one stream callback per part (indices 0, 1, ...), followed by a
 * final callback with `isDone: true` whose index is the number of parts. All
 * callbacks share one new "llmPredict" event derived from `parentEvent`.
 *
 * @param languageModel Model whose `agenerate` is mocked.
 * @param callbackManager Receives the simulated stream callbacks.
 */
export function mockLlmGeneration({
  languageModel,
  callbackManager,
}: {
  languageModel: ChatOpenAI;
  callbackManager: CallbackManager;
}) {
  jest
    .spyOn(languageModel, "agenerate")
    .mockImplementation(
      async (_messages: BaseMessage[], parentEvent?: Event) => {
        const text = "MOCK_TOKEN_1-MOCK_TOKEN_2";
        const event = globalsHelper.createEvent({
          parentEvent,
          type: "llmPredict",
        });

        if (callbackManager?.onLLMStream) {
          const chunks = text.split("-");
          chunks.forEach((chunk, index) => {
            callbackManager?.onLLMStream?.({
              event,
              index,
              token: {
                id: "id",
                object: "object",
                created: 1,
                model: "model",
                choices: [
                  {
                    index: 0,
                    delta: {
                      content: chunk,
                    },
                    finish_reason: null,
                  },
                ],
              },
            });
          });
          callbackManager?.onLLMStream({
            event,
            index: chunks.length,
            isDone: true,
          });
        }

        // An async function already returns a promise; no wrapper needed.
        return {
          generations: [[{ text }]],
        };
      }
    );
}
|
||||
|
||||
/**
 * Replaces the embedding calls of `embedModel` with jest spies that return
 * fixed vectors: `[1, 0, 0, 0, 0, 0]` for text embeddings and
 * `[0, 1, 0, 0, 0, 0]` for query embeddings. The input is ignored.
 *
 * @param embedModel Embedding model whose methods are mocked.
 */
export function mockEmbeddingModel(embedModel: OpenAIEmbedding) {
  jest
    .spyOn(embedModel, "aGetTextEmbedding")
    .mockImplementation(async () => [1, 0, 0, 0, 0, 0]);
  jest
    .spyOn(embedModel, "aGetQueryEmbedding")
    .mockImplementation(async () => [0, 1, 0, 0, 0, 0]);
}
|
||||
Generated
+22
-22
@@ -36,8 +36,8 @@ importers:
|
||||
specifier: ^29.1.0
|
||||
version: 29.1.0(@babel/core@7.22.5)(jest@29.5.0)(typescript@4.9.5)
|
||||
turbo:
|
||||
specifier: ^1.10.5
|
||||
version: 1.10.5
|
||||
specifier: ^1.10.6
|
||||
version: 1.10.7
|
||||
|
||||
apps/docs:
|
||||
dependencies:
|
||||
@@ -5444,65 +5444,65 @@ packages:
|
||||
resolution: {integrity: sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==}
|
||||
dev: true
|
||||
|
||||
/turbo-darwin-64@1.10.5:
|
||||
resolution: {integrity: sha512-fIHu+fcW7upaZEfeneoRbZjdrcsj/NxUg7IjZZmlCjgbS9Ofl8RhRid5A1L31AUK3kkqRxzagHc4WZ5x4quBgg==}
|
||||
/turbo-darwin-64@1.10.7:
|
||||
resolution: {integrity: sha512-N2MNuhwrl6g7vGuz4y3fFG2aR1oCs0UZ5HKl8KSTn/VC2y2YIuLGedQ3OVbo0TfEvygAlF3QGAAKKtOCmGPNKA==}
|
||||
cpu: [x64]
|
||||
os: [darwin]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-darwin-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-uv0sDWizuxVvdSjaKvWdPdX4aZ8IZeYJwTJRZwLNRxZV56/1LZD65gyQIqsSNVRHuXI199yahmB+7PMJNpZFdw==}
|
||||
/turbo-darwin-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-WbJkvjU+6qkngp7K4EsswOriO3xrNQag7YEGRtfLoDdMTk4O4QTeU6sfg2dKfDsBpTidTvEDwgIYJhYVGzrz9Q==}
|
||||
cpu: [arm64]
|
||||
os: [darwin]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-linux-64@1.10.5:
|
||||
resolution: {integrity: sha512-hI0rErgwxNmuBCNGldhJkjSbb+mT+vjfmBVKcMI/bnBmu/KU7irCrKMe5Vas280teqBrC33GgVfXndJo2cJ1DA==}
|
||||
/turbo-linux-64@1.10.7:
|
||||
resolution: {integrity: sha512-x1CF2CDP1pDz/J8/B2T0hnmmOQI2+y11JGIzNP0KtwxDM7rmeg3DDTtDM/9PwGqfPotN9iVGgMiMvBuMFbsLhg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-linux-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-JAygWZjTuD6e7w0KSGzy7UxYqeLIpGfZDne+4MGRc8I5VeWZ6i0HWTqhhIu2/A8AuklYcoj8LkOZxCnMOF3odQ==}
|
||||
/turbo-linux-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-JtnBmaBSYbs7peJPkXzXxsRGSGBmBEIb6/kC8RRmyvPAMyqF8wIex0pttsI+9plghREiGPtRWv/lfQEPRlXnNQ==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-windows-64@1.10.5:
|
||||
resolution: {integrity: sha512-6w2GOKmlWEAl6QkC4c2j2ZLTwB+RK6oIDRT2KqF1m07KkY6pebEzbPZLHuP08QV+SE0t+prAn+kn7hkHYkwM+Q==}
|
||||
/turbo-windows-64@1.10.7:
|
||||
resolution: {integrity: sha512-7A/4CByoHdolWS8dg3DPm99owfu1aY/W0V0+KxFd0o2JQMTQtoBgIMSvZesXaWM57z3OLsietFivDLQPuzE75w==}
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo-windows-arm64@1.10.5:
|
||||
resolution: {integrity: sha512-3eeHRJPU+5zWa/iiikoBoPlNd74Y+L9lrG6ZsDZdzUYxNRTMrZbto1Bu1UF77t10TXeT9BsZRXjquKqrA7R7tg==}
|
||||
/turbo-windows-arm64@1.10.7:
|
||||
resolution: {integrity: sha512-D36K/3b6+hqm9IBAymnuVgyePktwQ+F0lSXr2B9JfAdFPBktSqGmp50JNC7pahxhnuCLj0Vdpe9RqfnJw5zATA==}
|
||||
cpu: [arm64]
|
||||
os: [win32]
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/turbo@1.10.5:
|
||||
resolution: {integrity: sha512-4yxHTrlugJhef4eXuyrPJtrgUZWlbcwmSb8iZL/5UzNjCmx+anOm1nfW2XFrZFKy4v0+/fUlqw8LkTgGVsOKaQ==}
|
||||
/turbo@1.10.7:
|
||||
resolution: {integrity: sha512-xm0MPM28TWx1e6TNC3wokfE5eaDqlfi0G24kmeHupDUZt5Wd0OzHFENEHMPqEaNKJ0I+AMObL6nbSZonZBV2HA==}
|
||||
hasBin: true
|
||||
requiresBuild: true
|
||||
optionalDependencies:
|
||||
turbo-darwin-64: 1.10.5
|
||||
turbo-darwin-arm64: 1.10.5
|
||||
turbo-linux-64: 1.10.5
|
||||
turbo-linux-arm64: 1.10.5
|
||||
turbo-windows-64: 1.10.5
|
||||
turbo-windows-arm64: 1.10.5
|
||||
turbo-darwin-64: 1.10.7
|
||||
turbo-darwin-arm64: 1.10.7
|
||||
turbo-linux-64: 1.10.7
|
||||
turbo-linux-arm64: 1.10.7
|
||||
turbo-windows-64: 1.10.7
|
||||
turbo-windows-arm64: 1.10.7
|
||||
dev: true
|
||||
|
||||
/type-check@0.4.0:
|
||||
|
||||
Reference in New Issue
Block a user