Compare commits

...

30 Commits

Author SHA1 Message Date
Abdul Jamjoom 0bec460937 clean-up 2023-07-09 16:46:34 -07:00
Abdul Jamjoom 5d8d344e1f updated tests 2023-07-09 16:42:58 -07:00
Abdul Jamjoom 81e22587eb add callback and event tracking 2023-07-09 16:31:04 -07:00
Yi Ding 539ec0fe3d try smaller chunks 2023-07-06 21:41:13 -07:00
Yi Ding ebf3bc19fd ContextChatEngine v1 2023-07-06 20:59:10 -07:00
Yi Ding 204b8f5316 chatengine 2023-07-06 08:53:09 -07:00
yisding 71dd461a47 Merge pull request #9 from run-llama/subquestion
SubQuestionQueryEngine
2023-07-05 19:39:11 -07:00
Yi Ding 0c881c8fde finished subquestion demo 2023-07-05 08:37:41 -07:00
Abdul Jamjoom 35a6795559 polish 2023-07-04 22:34:49 -07:00
Abdul Jamjoom 968109455d Create the CallbackManager with onLLMStream and onRetrieve 2023-07-04 21:53:45 -07:00
Yi Ding 815a3416f2 more work 2023-07-04 20:53:36 -07:00
Yi Ding 407069ca27 prompt work for question generator 2023-07-04 14:08:04 -07:00
Yi Ding 29d042175e initial work 2023-07-04 14:06:49 -07:00
yisding bbf936e9b4 Merge pull request #7 from run-llama/listindex
List Index
2023-07-04 10:18:35 -07:00
Yi Ding 2212793420 more housekeeping 2023-07-04 09:47:44 -07:00
Abdul Jamjoom 5487de8c37 Merge branch 'main' of github.com:run-llama/llamascript into stream_responses 2023-07-04 09:34:46 -07:00
Abdul Jamjoom 2a038c00ec unstable checkpoint 2023-07-04 09:27:54 -07:00
Yi Ding d6c6aefd0d some housekeeping 2023-07-04 09:22:16 -07:00
Yi Ding 4516363097 make persistence optional 2023-07-04 08:56:31 -07:00
Yi Ding 69dd6d4efa make persistence optional 2023-07-04 08:54:22 -07:00
Sourabh Desai 9ea840142b changes to get test script running 2023-07-04 07:40:41 +00:00
Sourabh Desai a1c45294b3 updates for compilation errors 2023-07-04 05:55:51 +00:00
Sourabh Desai c2ef5057b3 add init functions for list index. Still needs some refactoring + testing 2023-07-03 22:59:28 +00:00
Sourabh Desai b87e6d9ced finish implementation for llm list index retriever 2023-07-03 18:31:32 +00:00
Sourabh Desai 8d618a6bc3 better structuring + adding missing functionality 2023-07-03 16:14:53 +00:00
Sourabh Desai 8d8bee5263 start implementing list index retrievers 2023-07-03 06:42:00 +00:00
Sourabh Desai ed924641ca start implemetation of list index 2023-07-03 05:40:05 +00:00
Yi Ding ce61f9660b turbo update 2023-06-29 21:01:49 -07:00
Yi Ding 072b13cff0 use servicecontext/storagecontext 2023-06-29 21:00:36 -07:00
yisding ff274dde1d Merge pull request #5 from run-llama/nodev3
initial nodev3 work
2023-06-29 17:47:10 -07:00
41 changed files with 2111 additions and 272 deletions
+2
View File
@@ -34,3 +34,5 @@ yarn-error.log*
# vercel
.vercel
storage/
+31
View File
@@ -0,0 +1,31 @@
// @ts-ignore
import * as readline from "node:readline/promises";
// @ts-ignore
import { stdin as input, stdout as output } from "node:process";
import { Document } from "@llamaindex/core/src/Node";
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine";
import essay from "./essay";
import { serviceContextFromDefaults } from "@llamaindex/core/src/ServiceContext";
/**
 * Interactive demo: indexes the essay into a vector store using 512-token
 * chunks, then answers questions typed on stdin through a ContextChatEngine
 * backed by a top-5 retriever. Loops until the process is terminated.
 */
async function main() {
  const essayDocument = new Document({ text: essay });
  const serviceContext = serviceContextFromDefaults({ chunkSize: 512 });

  // No custom storage context is supplied, hence the `undefined` placeholder.
  const index = await VectorStoreIndex.fromDocuments(
    [essayDocument],
    undefined,
    serviceContext
  );

  const retriever = index.asRetriever();
  retriever.similarityTopK = 5;

  const chatEngine = new ContextChatEngine({ retriever });
  const rl = readline.createInterface({ input, output });

  for (;;) {
    const answer = await chatEngine.achat(await rl.question("Query: "));
    console.log(answer);
  }
}
main().catch(console.error);
+17
View File
@@ -0,0 +1,17 @@
import { Document } from "@llamaindex/core/src/Node";
import { ListIndex } from "@llamaindex/core/src/index/list";
import essay from "./essay";
/**
 * Builds a ListIndex over the essay and prints the answer to one fixed
 * question.
 */
async function main() {
  const index = await ListIndex.fromDocuments([new Document({ text: essay })]);
  const engine = index.asQueryEngine();
  const answer = await engine.aquery("What did the author do growing up?");
  console.log(answer.toString());
}
main().catch((e: Error) => {
console.error(e, e.stack);
});
+1
View File
@@ -1,3 +1,4 @@
// @ts-ignore
import process from "node:process";
import { Configuration, OpenAIWrapper } from "@llamaindex/core/src/openai";
-9
View File
@@ -1,9 +0,0 @@
Simple flow:
Get document list, in this case one document.
Split each document into nodes, in this case sentences or lines.
Embed each of the nodes and get vectors. Store them in memory for now.
Embed query.
Compare query with nodes and get the top n
Put the top n nodes into the prompt.
Execute prompt, get result.
+60
View File
@@ -0,0 +1,60 @@
// from llama_index import SimpleDirectoryReader, VectorStoreIndex
// from llama_index.query_engine import SubQuestionQueryEngine
// from llama_index.tools import QueryEngineTool, ToolMetadata
// # load data
// pg_essay = SimpleDirectoryReader(
// input_dir="docs/examples/data/paul_graham/"
// ).load_data()
// # build index and query engine
// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine()
// # setup base query engine as tool
// query_engine_tools = [
// QueryEngineTool(
// query_engine=query_engine,
// metadata=ToolMetadata(
// name="pg_essay", description="Paul Graham essay on What I Worked On"
// ),
// )
// ]
// query_engine = SubQuestionQueryEngine.from_defaults(
// query_engine_tools=query_engine_tools
// )
// response = query_engine.query(
// "How was Paul Grahams life different before and after YC?"
// )
// print(response)
import { Document } from "@llamaindex/core/src/Node";
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine";
import essay from "./essay";
/**
 * Sub-question demo: wraps a vector-index query engine as a single tool and
 * lets SubQuestionQueryEngine decompose and answer one comparative question.
 */
async function main() {
  const document = new Document({ text: essay });
  const index = await VectorStoreIndex.fromDocuments([document]);

  const queryEngine = SubQuestionQueryEngine.fromDefaults({
    queryEngineTools: [
      {
        queryEngine: index.asQueryEngine(),
        metadata: {
          name: "pg_essay",
          description: "Paul Graham essay on What I Worked On",
        },
      },
    ],
  });

  const response = await queryEngine.aquery(
    "How was Paul Grahams life different before and after YC?"
  );
  console.log(response);
}

// Report failures instead of leaving the promise rejection unhandled, matching
// the other example scripts.
main().catch(console.error);
@@ -2,7 +2,7 @@ import { Document } from "@llamaindex/core/src/Node";
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
import essay from "./essay";
(async () => {
async function main() {
const document = new Document({ text: essay });
const index = await VectorStoreIndex.fromDocuments([document]);
const queryEngine = index.asQueryEngine();
@@ -10,4 +10,6 @@ import essay from "./essay";
"What did the author do growing up?"
);
console.log(response.toString());
})();
}
main().catch(console.error);
+1 -1
View File
@@ -18,7 +18,7 @@
"prettier": "^2.8.8",
"prettier-plugin-tailwindcss": "^0.3.0",
"ts-jest": "^29.1.0",
"turbo": "^1.10.5"
"turbo": "^1.10.6"
},
"packageManager": "pnpm@7.15.0",
"name": "llamascript"
+3
View File
@@ -10,6 +10,9 @@
"uuid": "^9.0.0",
"wink-nlp": "^1.14.1"
},
"engines": {
"node": ">=18.0.0"
},
"main": "src/index.ts",
"types": "src/index.ts",
"scripts": {
+165 -42
View File
@@ -1,22 +1,19 @@
import { Document, TextNode } from "./Node";
import { SimpleNodeParser } from "./NodeParser";
import { Document, BaseNode, MetadataMode, NodeWithEmbedding } from "./Node";
import { BaseQueryEngine, RetrieverQueryEngine } from "./QueryEngine";
import { v4 as uuidv4 } from "uuid";
import { VectorIndexRetriever } from "./Retriever";
import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
export class BaseIndex {
nodes: TextNode[] = [];
import { BaseRetriever, VectorIndexRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import {
StorageContext,
storageContextFromDefaults,
} from "./storage/StorageContext";
import { BaseDocumentStore } from "./storage/docStore/types";
import { VectorStore } from "./storage/vectorStore/types";
import { BaseIndexStore } from "./storage/indexStore/types";
constructor(nodes?: TextNode[]) {
this.nodes = nodes ?? [];
}
}
export class IndexDict {
export abstract class IndexStruct {
indexId: string;
summary?: string;
nodesDict: Record<string, TextNode> = {};
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
constructor(indexId = uuidv4(), summary = undefined) {
this.indexId = indexId;
@@ -29,57 +26,183 @@ export class IndexDict {
}
return this.summary;
}
}
addNode(node: TextNode, textId?: string) {
export class IndexDict extends IndexStruct {
nodesDict: Record<string, BaseNode> = {};
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
getSummary(): string {
if (this.summary === undefined) {
throw new Error("summary field of the index dict is not set");
}
return this.summary;
}
addNode(node: BaseNode, textId?: string) {
const vectorId = textId ?? node.id_;
this.nodesDict[vectorId] = node;
}
}
export class VectorStoreIndex extends BaseIndex {
indexStruct: IndexDict;
embeddingService: BaseEmbedding; // FIXME replace with service context
export class IndexList extends IndexStruct {
nodes: string[] = [];
constructor(nodes: TextNode[]) {
super(nodes);
this.indexStruct = new IndexDict();
addNode(node: BaseNode) {
this.nodes.push(node.id_);
}
}
if (nodes !== undefined) {
this.buildIndexFromNodes();
}
export interface BaseIndexInit<T> {
serviceContext: ServiceContext;
storageContext: StorageContext;
docStore: BaseDocumentStore;
vectorStore?: VectorStore;
indexStore?: BaseIndexStore;
indexStruct: T;
}
export abstract class BaseIndex<T> {
serviceContext: ServiceContext;
storageContext: StorageContext;
docStore: BaseDocumentStore;
vectorStore?: VectorStore;
indexStore?: BaseIndexStore;
indexStruct: T;
this.embeddingService = new OpenAIEmbedding();
constructor(init: BaseIndexInit<T>) {
this.serviceContext = init.serviceContext;
this.storageContext = init.storageContext;
this.docStore = init.docStore;
this.vectorStore = init.vectorStore;
this.indexStore = init.indexStore;
this.indexStruct = init.indexStruct;
}
async getNodeEmbeddingResults(logProgress = false) {
for (let i = 0; i < this.nodes.length; ++i) {
const node = this.nodes[i];
if (logProgress) {
console.log(`getting embedding for node ${i}/${this.nodes.length}`);
abstract asRetriever(): BaseRetriever;
}
export interface VectorIndexOptions {
nodes?: BaseNode[];
indexStruct?: IndexDict;
serviceContext?: ServiceContext;
storageContext?: StorageContext;
}
interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
vectorStore: VectorStore;
}
export class VectorStoreIndex extends BaseIndex<IndexDict> {
vectorStore: VectorStore;
private constructor(init: VectorIndexConstructorProps) {
super(init);
this.vectorStore = init.vectorStore;
}
static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> {
const storageContext =
options.storageContext ?? (await storageContextFromDefaults({}));
const serviceContext =
options.serviceContext ?? serviceContextFromDefaults({});
const docStore = storageContext.docStore;
const vectorStore = storageContext.vectorStore;
let indexStruct: IndexDict;
if (options.indexStruct) {
if (options.nodes) {
throw new Error(
"Cannot initialize VectorStoreIndex with both nodes and indexStruct"
);
}
const embedding = await this.embeddingService.aGetTextEmbedding(
node.getText()
indexStruct = options.indexStruct;
} else {
if (!options.nodes) {
throw new Error(
"Cannot initialize VectorStoreIndex without nodes or indexStruct"
);
}
indexStruct = await VectorStoreIndex.buildIndexFromNodes(
options.nodes,
serviceContext,
vectorStore
);
node.embedding = embedding;
}
return new VectorStoreIndex({
storageContext,
serviceContext,
docStore,
vectorStore,
indexStruct,
});
}
buildIndexFromNodes() {
for (const node of this.nodes) {
this.indexStruct.addNode(node);
static async agetNodeEmbeddingResults(
nodes: BaseNode[],
serviceContext: ServiceContext,
logProgress = false
) {
const nodesWithEmbeddings: NodeWithEmbedding[] = [];
for (let i = 0; i < nodes.length; ++i) {
const node = nodes[i];
if (logProgress) {
console.log(`getting embedding for node ${i}/${nodes.length}`);
}
const embedding = await serviceContext.embedModel.aGetTextEmbedding(
node.getContent(MetadataMode.EMBED)
);
nodesWithEmbeddings.push({ node, embedding });
}
return nodesWithEmbeddings;
}
static async fromDocuments(documents: Document[]): Promise<VectorStoreIndex> {
const nodeParser = new SimpleNodeParser(); // FIXME use service context
const nodes = nodeParser.getNodesFromDocuments(documents);
const index = new VectorStoreIndex(nodes);
await index.getNodeEmbeddingResults();
static async buildIndexFromNodes(
nodes: BaseNode[],
serviceContext: ServiceContext,
vectorStore: VectorStore
): Promise<IndexDict> {
const embeddingResults = await this.agetNodeEmbeddingResults(
nodes,
serviceContext
);
vectorStore.add(embeddingResults);
const indexDict = new IndexDict();
for (const { node } of embeddingResults) {
indexDict.addNode(node);
}
return indexDict;
}
static async fromDocuments(
documents: Document[],
storageContext?: StorageContext,
serviceContext?: ServiceContext
): Promise<VectorStoreIndex> {
storageContext = storageContext ?? (await storageContextFromDefaults({}));
serviceContext = serviceContext ?? serviceContextFromDefaults({});
const docStore = storageContext.docStore;
for (const doc of documents) {
docStore.setDocumentHash(doc.id_, doc.hash);
}
const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
const index = await VectorStoreIndex.init({
nodes,
storageContext,
serviceContext,
});
return index;
}
asRetriever(): VectorIndexRetriever {
return new VectorIndexRetriever(this, this.embeddingService);
return new VectorIndexRetriever(this);
}
asQueryEngine(): BaseQueryEngine {
+177
View File
@@ -0,0 +1,177 @@
import { BaseChatModel, BaseMessage, ChatOpenAI } from "./LanguageModel";
import { TextNode } from "./Node";
import {
SimplePrompt,
contextSystemPrompt,
defaultCondenseQuestionPrompt,
messagesToHistoryStr,
} from "./Prompt";
import { BaseQueryEngine } from "./QueryEngine";
import { Response } from "./Response";
import { BaseRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
/**
 * Common contract for the chat engines in this module.
 */
interface ChatEngine {
// Interactive REPL entry point; every implementation in this file currently
// throws "Method not implemented.".
chatRepl(): void;
// Sends one user message and resolves with the engine's Response. When
// chatHistory is omitted, implementations fall back to their stored history.
achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>;
// Clears the stored chat history.
reset(): void;
}
/**
 * Minimal chat engine: forwards the running conversation straight to a chat
 * model, with no retrieval step.
 */
export class SimpleChatEngine implements ChatEngine {
  chatHistory: BaseMessage[];
  llm: BaseChatModel;

  /**
   * @param init Optional overrides; the history defaults to empty and the
   *   model to a gpt-3.5-turbo ChatOpenAI instance.
   */
  constructor(init?: Partial<SimpleChatEngine>) {
    this.chatHistory = init?.chatHistory ?? [];
    this.llm = init?.llm ?? new ChatOpenAI({ model: "gpt-3.5-turbo" });
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  /**
   * Appends the user message to the history, asks the model for a reply,
   * records the reply, and returns it wrapped in a Response.
   *
   * @param message The user's message.
   * @param chatHistory History to use (and mutate) instead of the stored one.
   */
  async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> {
    const history = chatHistory ?? this.chatHistory;
    history.push({ content: message, type: "human" });

    const result = await this.llm.agenerate(history);
    const reply = result.generations[0][0].text;

    history.push({ content: reply, type: "ai" });
    this.chatHistory = history;
    return new Response(reply);
  }

  reset() {
    this.chatHistory = [];
  }
}
/**
 * Chat engine that first condenses the chat history plus the new message into
 * a standalone question, then answers that question with a query engine.
 */
export class CondenseQuestionChatEngine implements ChatEngine {
  queryEngine: BaseQueryEngine;
  chatHistory: BaseMessage[];
  serviceContext: ServiceContext;
  condenseMessagePrompt: SimplePrompt;

  /**
   * @param init.queryEngine Engine used to answer the condensed question.
   * @param init.chatHistory Initial history; defaults to empty.
   * @param init.serviceContext Supplies the LLM predictor; defaults to
   *   serviceContextFromDefaults({}).
   * @param init.condenseMessagePrompt Prompt used to build the standalone
   *   question; defaults to defaultCondenseQuestionPrompt.
   */
  constructor(init: {
    queryEngine: BaseQueryEngine;
    // Optional: the body has always fallen back to an empty history.
    chatHistory?: BaseMessage[];
    serviceContext?: ServiceContext;
    condenseMessagePrompt?: SimplePrompt;
  }) {
    this.queryEngine = init.queryEngine;
    this.chatHistory = init.chatHistory ?? [];
    this.serviceContext =
      init.serviceContext ?? serviceContextFromDefaults({});
    this.condenseMessagePrompt =
      init.condenseMessagePrompt ?? defaultCondenseQuestionPrompt;
  }

  /**
   * Asks the LLM predictor to rewrite `question` as a standalone question
   * given the rendered chat history.
   */
  private async acondenseQuestion(
    chatHistory: BaseMessage[],
    question: string
  ) {
    const chatHistoryStr = messagesToHistoryStr(chatHistory);

    // Use the prompt configured on this instance rather than always the
    // default one.
    return this.serviceContext.llmPredictor.apredict(
      this.condenseMessagePrompt,
      {
        question,
        // The default prompt reads `chatHistory`; `chat_history` is kept so
        // custom prompts written against the previous key keep working.
        chatHistory: chatHistoryStr,
        chat_history: chatHistoryStr,
      }
    );
  }

  /**
   * Condenses the message, queries the engine with the condensed question,
   * records the exchange in the history, and returns the engine's response.
   *
   * @param message The user's follow-up message.
   * @param chatHistory History to use (and mutate) instead of the stored one.
   */
  async achat(
    message: string,
    chatHistory?: BaseMessage[] | undefined
  ): Promise<Response> {
    chatHistory = chatHistory ?? this.chatHistory;

    const condensedQuestion = await this.acondenseQuestion(
      chatHistory,
      message
    );
    const response = await this.queryEngine.aquery(condensedQuestion);

    chatHistory.push({ content: message, type: "human" });
    chatHistory.push({ content: response.response, type: "ai" });
    // Persist the history, as the other chat engines in this module do.
    this.chatHistory = chatHistory;

    return response;
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  reset() {
    this.chatHistory = [];
  }
}
/**
 * Chat engine that retrieves nodes relevant to each message and supplies
 * their text to the chat model as a system message ahead of the conversation.
 */
export class ContextChatEngine implements ChatEngine {
  retriever: BaseRetriever;
  chatModel: BaseChatModel;
  chatHistory: BaseMessage[];

  /**
   * @param init.retriever Source of context nodes for each message.
   * @param init.chatModel Defaults to a gpt-3.5-turbo-16k ChatOpenAI instance.
   * @param init.chatHistory Initial history; defaults to empty.
   */
  constructor(init: {
    retriever: BaseRetriever;
    chatModel?: BaseChatModel;
    chatHistory?: BaseMessage[];
  }) {
    this.retriever = init.retriever;
    this.chatModel =
      init.chatModel ?? new ChatOpenAI({ model: "gpt-3.5-turbo-16k" });
    this.chatHistory = init?.chatHistory ?? [];
  }

  chatRepl() {
    throw new Error("Method not implemented.");
  }

  /**
   * Retrieves context for `message`, generates a reply from the system
   * message plus the conversation, records both turns, and returns the reply
   * together with the source nodes.
   *
   * @param message The user's message.
   * @param chatHistory History to use (and mutate) instead of the stored one.
   */
  async achat(message: string, chatHistory?: BaseMessage[] | undefined) {
    const history = chatHistory ?? this.chatHistory;

    // Retrieval and generation share one parent event.
    const parentEvent: Event = {
      id: uuidv4(),
      type: "wrapper",
      tags: ["final"],
    };

    const retrieved = await this.retriever.aretrieve(message, parentEvent);

    const contextText = retrieved
      .map((result) => (result.node as TextNode).text)
      .join("\n\n");
    const systemMessage: BaseMessage = {
      content: contextSystemPrompt({ context: contextText }),
      type: "system",
    };

    history.push({ content: message, type: "human" });

    const generation = await this.chatModel.agenerate(
      [systemMessage, ...history],
      parentEvent
    );
    const reply = generation.generations[0][0].text;

    history.push({ content: reply, type: "ai" });
    this.chatHistory = history;

    const sourceNodes = retrieved.map((result) => result.node);
    return new Response(reply, sourceNodes);
  }

  reset() {
    this.chatHistory = [];
  }
}
+3 -15
View File
@@ -174,24 +174,12 @@ export function getTopKMMREmbeddings(
}
export abstract class BaseEmbedding {
static similarity(
similarity(
embedding1: number[],
embedding2: number[],
mode: SimilarityType = SimilarityType.DOT_PRODUCT
mode: SimilarityType = SimilarityType.DEFAULT
): number {
if (embedding1.length !== embedding2.length) {
throw new Error("Embedding length mismatch");
}
if (mode === SimilarityType.DOT_PRODUCT) {
let result = 0;
for (let i = 0; i < embedding1.length; i++) {
result += embedding1[i] * embedding2[i];
}
return result;
} else {
throw new Error("Not implemented yet");
}
return similarity(embedding1, embedding2, mode);
}
abstract aGetTextEmbedding(text: string): Promise<number[]>;
+29 -2
View File
@@ -1,11 +1,38 @@
import { Event, EventTag, EventType } from "./callbacks/CallbackManager";
import { v4 as uuidv4 } from "uuid";
class GlobalsHelper {
defaultTokenizer: ((text: string) => string[]) | null = null;
tokenizer() {
if (this.defaultTokenizer) {
return this.defaultTokenizer;
}
const tiktoken = require("tiktoken-node");
let enc = new tiktoken.getEncoding("gpt2");
const defaultTokenizer = (text: string) => {
this.defaultTokenizer = (text: string) => {
return enc.encode(text);
};
return defaultTokenizer;
return this.defaultTokenizer;
}
createEvent({
parentEvent,
type,
tags,
}: {
parentEvent?: Event;
type: EventType;
tags?: EventTag[];
}): Event {
return {
id: uuidv4(),
type,
// inherit parent tags if tags not set
tags: tags || parentEvent?.tags,
parentId: parentEvent?.id,
};
}
}
+40 -18
View File
@@ -1,28 +1,50 @@
import { ChatOpenAI } from "./LanguageModel";
import { SimplePrompt } from "./Prompt";
import { CallbackManager, Event } from "./callbacks/CallbackManager";
// TODO change this to LLM class
export interface BaseLLMPredictor {
getLlmMetadata(): Promise<any>;
apredict(
prompt: string | SimplePrompt,
input?: Record<string, string>
input?: Record<string, string>,
parentEvent?: Event
): Promise<string>;
// stream(prompt: string, options: any): Promise<any>;
}
// TODO change this to LLM class
export class ChatGPTLLMPredictor implements BaseLLMPredictor {
llm: string;
model: string;
retryOnThrottling: boolean;
languageModel: ChatOpenAI;
callbackManager?: CallbackManager;
constructor(
llm: string = "gpt-3.5-turbo",
retryOnThrottling: boolean = true
props:
| {
model?: string;
retryOnThrottling?: boolean;
callbackManager?: CallbackManager;
languageModel?: ChatOpenAI;
}
| undefined = undefined
) {
this.llm = llm;
const {
model = "gpt-3.5-turbo",
retryOnThrottling = true,
callbackManager,
languageModel,
} = props || {};
this.model = model;
this.callbackManager = callbackManager;
this.retryOnThrottling = retryOnThrottling;
this.languageModel = new ChatOpenAI(this.llm);
this.languageModel =
languageModel ??
new ChatOpenAI({
model: this.model,
callbackManager: this.callbackManager,
});
}
async getLlmMetadata() {
@@ -31,22 +53,22 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
async apredict(
prompt: string | SimplePrompt,
input?: Record<string, string>
input?: Record<string, string>,
parentEvent?: Event
): Promise<string> {
if (typeof prompt === "string") {
const result = await this.languageModel.agenerate([
{
content: prompt,
type: "human",
},
]);
const result = await this.languageModel.agenerate(
[
{
content: prompt,
type: "human",
},
],
parentEvent
);
return result.generations[0][0].text;
} else {
return this.apredict(prompt(input ?? {}));
}
}
// async stream(prompt: string, options: any) {
// console.log("stream");
// }
}
+44 -11
View File
@@ -1,8 +1,9 @@
import { CallbackManager, Event } from "./callbacks/CallbackManager";
import { aHandleOpenAIStream } from "./callbacks/utility/aHandleOpenAIStream";
import {
ChatCompletionRequestMessageRoleEnum,
Configuration,
CreateChatCompletionRequest,
OpenAISession,
OpenAIWrapper,
getOpenAISession,
} from "./openai";
@@ -10,7 +11,7 @@ export interface BaseLanguageModel {}
type MessageType = "human" | "ai" | "system" | "generic" | "function";
interface BaseMessage {
export interface BaseMessage {
content: string;
type: MessageType;
}
@@ -24,9 +25,11 @@ export interface LLMResult {
generations: Generation[][]; // Each input can have more than one generations
}
export class BaseChatModel implements BaseLanguageModel {}
export interface BaseChatModel extends BaseLanguageModel {
agenerate(messages: BaseMessage[], parentEvent?: Event): Promise<LLMResult>;
}
export class ChatOpenAI extends BaseChatModel {
export class ChatOpenAI implements BaseChatModel {
model: string;
temperature: number = 0.7;
openAIKey: string | null = null;
@@ -34,12 +37,18 @@ export class ChatOpenAI extends BaseChatModel {
maxRetries: number = 6;
n: number = 1;
maxTokens?: number;
session: OpenAISession;
callbackManager?: CallbackManager;
constructor(model: string = "gpt-3.5-turbo") {
super();
constructor({
model = "gpt-3.5-turbo",
callbackManager,
}: {
model: string;
callbackManager?: CallbackManager;
}) {
this.model = model;
this.callbackManager = callbackManager;
this.session = getOpenAISession();
}
@@ -60,8 +69,11 @@ export class ChatOpenAI extends BaseChatModel {
}
}
async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
const { data } = await this.session.openai.createChatCompletion({
async agenerate(
messages: BaseMessage[],
parentEvent?: Event
): Promise<LLMResult> {
const baseRequestParams: CreateChatCompletionRequest = {
model: this.model,
temperature: this.temperature,
max_tokens: this.maxTokens,
@@ -70,8 +82,29 @@ export class ChatOpenAI extends BaseChatModel {
role: ChatOpenAI.mapMessageType(message.type),
content: message.content,
})),
});
};
if (this.callbackManager?.onLLMStream) {
const response = await this.session.openai.createChatCompletion(
{
...baseRequestParams,
stream: true,
},
{ responseType: "stream" }
);
const fullResponse = await aHandleOpenAIStream({
response,
onLLMStream: this.callbackManager.onLLMStream,
parentEvent,
});
return { generations: [[{ text: fullResponse }]] };
}
const response = await this.session.openai.createChatCompletion(
baseRequestParams
);
const { data } = response;
const content = data.choices[0].message?.content ?? "";
return { generations: [[{ text: content }]] };
}
+12 -7
View File
@@ -136,7 +136,7 @@ export class TextNode extends BaseNode {
endCharIdx?: number;
// textTemplate: NOTE write your own formatter if needed
// metadataTemplate: NOTE write your own formatter if needed
metadataSeperator: string = "\n";
metadataSeparator: string = "\n";
constructor(init?: Partial<TextNode>) {
super(init);
@@ -174,7 +174,7 @@ export class TextNode extends BaseNode {
return [...usableMetadataKeys]
.map((key) => `${key}: ${this.metadata[key]}`)
.join(this.metadataSeperator);
.join(this.metadataSeparator);
}
setContent(value: string) {
@@ -206,11 +206,6 @@ export class IndexNode extends TextNode {
}
}
export interface NodeWithScore {
node: TextNode;
score: number;
}
export class Document extends TextNode {
constructor(init?: Partial<Document>) {
super(init);
@@ -229,3 +224,13 @@ export class Document extends TextNode {
export class ImageDocument extends Document {
image?: string;
}
export interface NodeWithScore {
node: BaseNode;
score: number;
}
export interface NodeWithEmbedding {
node: BaseNode;
embedding: number[];
}
+47 -11
View File
@@ -1,5 +1,6 @@
import { Document, NodeRelationship, TextNode } from "./Node";
import { SentenceSplitter } from "./TextSplitter";
import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants";
export function getTextSplitsFromDocument(
document: Document,
@@ -13,18 +14,36 @@ export function getTextSplitsFromDocument(
export function getNodesFromDocument(
document: Document,
textSplitter: SentenceSplitter
textSplitter: SentenceSplitter,
includeMetadata: boolean = true,
includePrevNextRel: boolean = true
) {
let nodes: TextNode[] = [];
const textSplits = getTextSplitsFromDocument(document, textSplitter);
textSplits.forEach((textSplit, index) => {
const node = new TextNode({ text: textSplit });
textSplits.forEach((textSplit) => {
const node = new TextNode({
text: textSplit,
metadata: includeMetadata ? document.metadata : {},
});
node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo();
nodes.push(node);
});
if (includePrevNextRel) {
nodes.forEach((node, index) => {
if (index > 0) {
node.relationships[NodeRelationship.PREVIOUS] =
nodes[index - 1].asRelatedNodeInfo();
}
if (index < nodes.length - 1) {
node.relationships[NodeRelationship.NEXT] =
nodes[index + 1].asRelatedNodeInfo();
}
});
}
return nodes;
}
@@ -33,17 +52,34 @@ export interface NodeParser {
}
export class SimpleNodeParser implements NodeParser {
textSplitter: SentenceSplitter;
includeMetadata: boolean;
includePrevNextRel: boolean;
constructor(
textSplitter: any = null,
includeExtraInfo: boolean = true,
includePrevNextRel: boolean = true
) {
this.textSplitter = textSplitter ?? new SentenceSplitter();
constructor(init?: {
textSplitter?: SentenceSplitter;
includeMetadata?: boolean;
includePrevNextRel?: boolean;
chunkSize?: number;
chunkOverlap?: number;
}) {
this.textSplitter =
init?.textSplitter ??
new SentenceSplitter(
init?.chunkSize ?? DEFAULT_CHUNK_SIZE,
init?.chunkOverlap ?? DEFAULT_CHUNK_OVERLAP
);
this.includeMetadata = init?.includeMetadata ?? true;
this.includePrevNextRel = init?.includePrevNextRel ?? true;
}
static fromDefaults(): SimpleNodeParser {
return new SimpleNodeParser();
static fromDefaults(init?: {
chunkSize?: number;
chunkOverlap?: number;
includeMetadata?: boolean;
includePrevNextRel?: boolean;
}): SimpleNodeParser {
return new SimpleNodeParser(init);
}
/**
+80
View File
@@ -0,0 +1,80 @@
import { SubQuestion } from "./QuestionGenerator";
/**
 * Contract for turning raw LLM output text into a value of type T.
 */
export interface BaseOutputParser<T> {
// Converts the raw output string into a T.
parse(output: string): T;
// Takes a string and returns a string; the one implementation visible in
// this file returns its argument unchanged.
format(output: string): string;
}
/**
 * Pairs the original output text with the value parsed from it.
 */
export interface StructuredOutput<T> {
// The unmodified text that was handed to the parser.
rawOutput: string;
// The value extracted from rawOutput.
parsedOutput: T;
}
/**
 * Error raised when LLM output cannot be parsed. Carries the offending output
 * text and, when available, the underlying error that caused the failure.
 */
class OutputParserError extends Error {
  cause: Error | undefined;
  output: string | undefined;

  /**
   * @param message Human-readable description of the failure.
   * @param options.cause Underlying error, if any.
   * @param options.output The text that failed to parse, if any.
   */
  constructor(
    message: string,
    options: { cause?: Error; output?: string } = {}
  ) {
    // @ts-ignore
    super(message, options); // https://github.com/tc39/proposal-error-cause
    this.name = "OutputParserError";

    const { cause, output } = options;
    // Runtimes that implement the error-cause proposal already set `cause` in
    // super(); fill it in by hand everywhere else.
    if (!this.cause) {
      this.cause = cause;
    }
    this.output = output;

    // Keeps the constructor frame out of the stack trace on V8
    // (https://v8.dev/docs/stack-trace-api).
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, OutputParserError);
    }
  }
}
/**
 * Extracts and parses the first ```json fenced block found in `text`.
 *
 * @param text Text expected to contain a ```json ... ``` block.
 * @returns The value produced by JSON.parse on the block's contents.
 * @throws OutputParserError if no complete fenced block is present, or if its
 *   contents are not valid JSON.
 */
function parseJsonMarkdown(text: string) {
  const beginDelimiter = "```json";
  const endDelimiter = "```";
  const trimmed = text.trim();

  const contentStart = trimmed.indexOf(beginDelimiter);
  const contentEnd =
    contentStart === -1
      ? -1
      : trimmed.indexOf(endDelimiter, contentStart + beginDelimiter.length);

  if (contentStart === -1 || contentEnd === -1) {
    throw new OutputParserError("Not a json markdown", { output: trimmed });
  }

  const jsonText = trimmed.slice(
    contentStart + beginDelimiter.length,
    contentEnd
  );

  try {
    return JSON.parse(jsonText);
  } catch (e) {
    throw new OutputParserError("Not a valid json", {
      cause: e as Error,
      output: trimmed,
    });
  }
}
/**
 * Parses the sub-question list an LLM returns inside a ```json fenced block.
 */
export class SubQuestionOutputParser
  implements BaseOutputParser<StructuredOutput<SubQuestion[]>>
{
  /**
   * @param output Raw LLM output containing a ```json block.
   * @returns The raw output alongside the parsed JSON value.
   */
  parse(output: string): StructuredOutput<SubQuestion[]> {
    // TODO add zod validation
    const parsedOutput = parseJsonMarkdown(output);
    return { rawOutput: output, parsedOutput };
  }

  /** Returns `output` unchanged. */
  format(output: string): string {
    return output;
  }
}
+237
View File
@@ -1,3 +1,7 @@
import { BaseMessage } from "./LanguageModel";
import { SubQuestion } from "./QuestionGenerator";
import { ToolMetadata } from "./Tool";
/**
* A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
* NOTE this is a different interface compared to LlamaIndex Python
@@ -80,3 +84,236 @@ ${context}
------------
Given the new context, refine the original answer to better answer the question. If the context isn't useful, return the original answer.`;
};
/**
 * Builds the prompt that asks the model to pick the documents relevant to a
 * question and give each a 1-10 relevance score, answering in the
 * "Doc: N, Relevance: M" line format shown in the embedded example.
 *
 * Reads `context` (the numbered document summaries) and `query` from the
 * input; each defaults to the empty string when absent.
 */
export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
// Missing keys fall back to "" rather than rendering as "undefined".
const { context = "", query = "" } = input;
return `A list of documents is shown below. Each document has a number next to it along
with a summary of the document. A question is also provided.
Respond with the numbers of the documents
you should consult to answer the question, in order of relevance, as well
as the relevance score. The relevance score is a number from 1-10 based on
how relevant you think the document is to the question.
Do not include any documents that are not relevant to the question.
Example format:
Document 1:
<summary of document 1>
Document 2:
<summary of document 2>
...
Document 10:\n<summary of document 10>
Question: <question>
Answer:
Doc: 9, Relevance: 7
Doc: 3, Relevance: 4
Doc: 7, Relevance: 3
Let's try this now:
${context}
Question: ${query}
Answer:`;
};
/*
PREFIX = """\
Given a user question, and a list of tools, output a list of relevant sub-questions \
that when composed can help answer the full user question:
"""
example_query_str = (
"Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
)
example_tools = [
ToolMetadata(
name="uber_10k",
description="Provides information about Uber financials for year 2021",
),
ToolMetadata(
name="lyft_10k",
description="Provides information about Lyft financials for year 2021",
),
]
example_tools_str = build_tools_text(example_tools)
example_output = [
SubQuestion(
sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
),
SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
SubQuestion(
sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
),
SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
]
example_output_str = json.dumps([x.dict() for x in example_output], indent=4)
EXAMPLES = (
"""\
# Example 1
<Tools>
```json
{tools_str}
```
<User Question>
{query_str}
<Output>
```json
{output_str}
```
""".format(
query_str=example_query_str,
tools_str=example_tools_str,
output_str=example_output_str,
)
.replace("{", "{{")
.replace("}", "}}")
)
SUFFIX = """\
# Example 2
<Tools>
```json
{tools_str}
```
<User Question>
{query_str}
<Output>
"""
DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
*/
/**
 * Serializes tool metadata into the JSON text that is embedded in the
 * sub-question prompt.
 *
 * @param tools - Tools to describe.
 * @returns A JSON object (4-space indent) mapping each tool name to its
 * description. If a name occurs more than once, the last description is kept.
 */
export function buildToolsText(tools: ToolMetadata[]) {
  const descriptionsByName: Record<string, string> = {};
  for (const { name, description } of tools) {
    descriptionsByName[name] = description;
  }
  return JSON.stringify(descriptionsByName, null, 4);
}
// Worked example (tools, user question, expected sub-questions) that is
// rendered as "Example 1" inside the default sub-question prompt.
const exampleTools: ToolMetadata[] = [
  { name: "uber_10k", description: "Provides information about Uber financials for year 2021" },
  { name: "lyft_10k", description: "Provides information about Lyft financials for year 2021" },
];

const exampleQueryStr =
  "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021";

const exampleOutput: SubQuestion[] = [
  { subQuestion: "What is the revenue growth of Uber", toolName: "uber_10k" },
  { subQuestion: "What is the EBITDA of Uber", toolName: "uber_10k" },
  { subQuestion: "What is the revenue growth of Lyft", toolName: "lyft_10k" },
  { subQuestion: "What is the EBITDA of Lyft", toolName: "lyft_10k" },
];
/**
 * Default prompt asking the LLM to break a user question into sub-questions,
 * each assigned to one of the available tools.
 *
 * The prompt shows one worked example (built from exampleTools,
 * exampleQueryStr and exampleOutput) followed by the caller's tools and
 * question, and ends right after "<Output>" so the model completes the answer.
 *
 * @param input - Must provide `toolsStr` (tool description text, e.g. the
 * result of buildToolsText) and `queryStr` (the user question).
 * @returns The fully rendered prompt text.
 */
export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
  const { toolsStr, queryStr } = input;
  // The tools block must contain exactly toolsStr: a stray closing brace after
  // the interpolation would be sent to the model as part of the JSON.
  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
# Example 1
<Tools>
\`\`\`json
${buildToolsText(exampleTools)}
\`\`\`
<User Question>
${exampleQueryStr}
<Output>
\`\`\`json
${JSON.stringify(exampleOutput, null, 4)}
\`\`\`
# Example 2
<Tools>
\`\`\`json
${toolsStr}
\`\`\`
<User Question>
${queryStr}
<Output>
`;
};
// DEFAULT_TEMPLATE = """\
// Given a conversation (between Human and Assistant) and a follow up message from Human, \
// rewrite the message to be a standalone question that captures all relevant context \
// from the conversation.
// <Chat History>
// {chat_history}
// <Follow Up Message>
// {question}
// <Standalone question>
// """
/**
 * Default prompt for condensing a chat history plus a follow-up message into
 * a single standalone question.
 *
 * @param input - Must provide `chatHistory` (rendered conversation text) and
 * `question` (the follow-up message).
 * @returns The rendered prompt, ending after "<Standalone question>" and a newline.
 */
export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
  const { chatHistory, question } = input;
  const instruction =
    "Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.";
  // Values are stringified first so that join() does not turn null/undefined
  // into empty strings.
  const sections = [
    instruction,
    "<Chat History>",
    `${chatHistory}`,
    "<Follow Up Message>",
    `${question}`,
    "<Standalone question>",
    "",
  ];
  return sections.join("\n");
};
/**
 * Renders chat messages as plain text, one message per line.
 *
 * Messages whose `type` is "human" are prefixed with "Human: "; every other
 * message is prefixed with "Assistant: ".
 *
 * @param messages - Messages in conversation order.
 * @returns The newline-joined transcript, or "" for an empty list.
 */
export function messagesToHistoryStr(messages: BaseMessage[]) {
  const lines = messages.map((message) => {
    const speaker = message.type === "human" ? "Human" : "Assistant";
    return `${speaker}: ${message.content}`;
  });
  return lines.join("\n");
}
/**
 * System prompt that presents retrieved context between two separator rules.
 *
 * @param input - Must provide `context`, the text to embed.
 * @returns Four lines: a heading, a rule, the context, and the rule again
 * (no trailing newline).
 */
export const contextSystemPrompt: SimplePrompt = (input) => {
  const { context } = input;
  const rule = "---------------------";
  return ["Context information is below.", rule, `${context}`, rule].join("\n");
};
+121 -7
View File
@@ -1,22 +1,136 @@
import { NodeWithScore, TextNode } from "./Node";
import {
BaseQuestionGenerator,
LLMQuestionGenerator,
SubQuestion,
} from "./QuestionGenerator";
import { Response } from "./Response";
import { ResponseSynthesizer } from "./ResponseSynthesizer";
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
import { BaseRetriever } from "./Retriever";
import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import { QueryEngineTool, ToolMetadata } from "./Tool";
export interface BaseQueryEngine {
aquery(query: string): Promise<Response>;
aquery(query: string, parentEvent?: Event): Promise<Response>;
}
export class RetrieverQueryEngine {
export class RetrieverQueryEngine implements BaseQueryEngine {
retriever: BaseRetriever;
responseSynthesizer: ResponseSynthesizer;
constructor(retriever: BaseRetriever) {
this.retriever = retriever;
this.responseSynthesizer = new ResponseSynthesizer();
const serviceContext: ServiceContext | undefined =
this.retriever.getServiceContext();
this.responseSynthesizer = new ResponseSynthesizer({ serviceContext });
}
async aquery(query: string) {
const nodes = await this.retriever.aretrieve(query);
return this.responseSynthesizer.asynthesize(query, nodes);
async aquery(query: string, parentEvent?: Event) {
const _parentEvent: Event = parentEvent || {
id: uuidv4(),
type: "wrapper",
tags: ["final"],
};
const nodes = await this.retriever.aretrieve(query, _parentEvent);
return this.responseSynthesizer.asynthesize(query, nodes, _parentEvent);
}
}
/**
 * Query engine that answers a question by splitting it into sub-questions,
 * running each sub-question against the tool (query engine) chosen for it,
 * and synthesizing a final answer from the sub-answers.
 */
export class SubQuestionQueryEngine implements BaseQueryEngine {
  responseSynthesizer: ResponseSynthesizer;
  questionGen: BaseQuestionGenerator;
  /** Query engines keyed by tool name. */
  queryEngines: Record<string, BaseQueryEngine>;
  /** Metadata of every tool, handed to the question generator. */
  metadatas: ToolMetadata[];

  /**
   * @param init.questionGen - Generates the sub-questions.
   * @param init.responseSynthesizer - Builds the final answer; a default
   * ResponseSynthesizer is created when it is nullish.
   * @param init.queryEngineTools - Tools that sub-questions can be routed to.
   */
  constructor(init: {
    questionGen: BaseQuestionGenerator;
    responseSynthesizer: ResponseSynthesizer;
    queryEngineTools: QueryEngineTool[];
  }) {
    this.questionGen = init.questionGen;
    this.responseSynthesizer =
      init.responseSynthesizer ?? new ResponseSynthesizer();
    this.queryEngines = init.queryEngineTools.reduce<
      Record<string, BaseQueryEngine>
    >((acc, tool) => {
      acc[tool.metadata.name] = tool.queryEngine;
      return acc;
    }, {});
    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
  }

  /**
   * Creates an engine, filling in defaults for everything except the tools:
   * an LLMQuestionGenerator, and a ResponseSynthesizer that uses
   * CompactAndRefine with the given (or default) service context.
   */
  static fromDefaults(init: {
    queryEngineTools: QueryEngineTool[];
    questionGen?: BaseQuestionGenerator;
    responseSynthesizer?: ResponseSynthesizer;
    serviceContext?: ServiceContext;
  }) {
    const serviceContext =
      init.serviceContext ?? serviceContextFromDefaults({});
    // NOTE(review): the default question generator is built without the
    // service context, so it uses its own default predictor - confirm intended.
    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
    const responseSynthesizer =
      init.responseSynthesizer ??
      new ResponseSynthesizer({
        responseBuilder: new CompactAndRefine(serviceContext),
        serviceContext,
      });
    return new SubQuestionQueryEngine({
      questionGen,
      responseSynthesizer,
      queryEngineTools: init.queryEngineTools,
    });
  }

  /**
   * Answers `query` by generating sub-questions, answering them in parallel,
   * and synthesizing a response from the sub-answers that succeeded.
   *
   * @param query - The user question.
   * @returns The synthesized response.
   */
  async aquery(query: string): Promise<Response> {
    const subQuestions = await this.questionGen.agenerate(
      this.metadatas,
      query
    );

    // groups final retrieval+synthesis operation
    const parentEvent: Event = {
      id: uuidv4(),
      type: "wrapper",
      tags: ["final"],
    };

    // groups all sub-queries
    const subQueryParentEvent: Event = {
      id: uuidv4(),
      parentId: parentEvent.id,
      type: "wrapper",
      tags: ["intermediate"],
    };

    const subQNodes = await Promise.all(
      subQuestions.map((subQ) => this.aquerySubQ(subQ, subQueryParentEvent))
    );

    // Failed sub-questions come back as null and are left out of synthesis.
    const nodes = subQNodes.filter(
      (node): node is NodeWithScore => node !== null
    );
    return this.responseSynthesizer.asynthesize(query, nodes, parentEvent);
  }

  /**
   * Runs one sub-question against its tool and wraps question and answer in
   * a text node with score 0.
   *
   * @returns The node, or null when the tool is unknown or the query fails;
   * the failure is logged instead of thrown so that one bad sub-question does
   * not abort the whole query.
   */
  private async aquerySubQ(
    subQ: SubQuestion,
    parentEvent?: Event
  ): Promise<NodeWithScore | null> {
    try {
      const question = subQ.subQuestion;
      const queryEngine = this.queryEngines[subQ.toolName];
      if (!queryEngine) {
        throw new Error(`No query engine registered for tool "${subQ.toolName}"`);
      }

      const response = await queryEngine.aquery(question, parentEvent);
      const responseText = response.response;
      const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
      const node = new TextNode({ text: nodeText });
      return { node, score: 0 };
    } catch (error) {
      console.warn(
        `Sub question "${subQ.subQuestion}" (tool "${subQ.toolName}") failed and was skipped:`,
        error
      );
      return null;
    }
  }
}
+49
View File
@@ -0,0 +1,49 @@
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
import {
BaseOutputParser,
StructuredOutput,
SubQuestionOutputParser,
} from "./OutputParser";
import {
SimplePrompt,
buildToolsText,
defaultSubQuestionPrompt,
} from "./Prompt";
import { ToolMetadata } from "./Tool";
/** One sub-question together with the name of the tool that should answer it. */
export interface SubQuestion {
/** Text of the sub-question. */
subQuestion: string;
/** Name of the tool (see ToolMetadata.name) chosen for this sub-question. */
toolName: string;
}
/** Contract for components that break a query into tool-targeted sub-questions. */
export interface BaseQuestionGenerator {
/**
 * @param tools - Metadata of the tools that may be targeted.
 * @param query - The question to decompose.
 * @returns The generated sub-questions.
 */
agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
}
/**
 * Question generator that asks an LLM predictor for sub-questions and parses
 * its answer with an output parser.
 */
export class LLMQuestionGenerator implements BaseQuestionGenerator {
  llmPredictor: BaseLLMPredictor;
  prompt: SimplePrompt;
  outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

  /**
   * @param init - Optional overrides. Anything left out falls back to
   * ChatGPTLLMPredictor, defaultSubQuestionPrompt and SubQuestionOutputParser.
   */
  constructor(init?: Partial<LLMQuestionGenerator>) {
    const { llmPredictor, prompt, outputParser } = init ?? {};
    this.llmPredictor = llmPredictor ?? new ChatGPTLLMPredictor();
    this.prompt = prompt ?? defaultSubQuestionPrompt;
    this.outputParser = outputParser ?? new SubQuestionOutputParser();
  }

  /**
   * Generates sub-questions for `query` given the available tools.
   *
   * @param tools - Tools described to the LLM via buildToolsText.
   * @param query - The question to decompose.
   * @returns The parsed output of the LLM prediction.
   */
  async agenerate(
    tools: ToolMetadata[],
    query: string
  ): Promise<SubQuestion[]> {
    const promptInput = {
      toolsStr: buildToolsText(tools),
      queryStr: query,
    };
    const rawPrediction = await this.llmPredictor.apredict(
      this.prompt,
      promptInput
    );
    const { parsedOutput } = this.outputParser.parse(rawPrediction);
    return parsedOutput;
  }
}
+4 -4
View File
@@ -1,10 +1,10 @@
import { TextNode } from "./Node";
import { BaseNode } from "./Node";
export class Response {
response?: string;
sourceNodes: TextNode[];
response: string;
sourceNodes?: BaseNode[];
constructor(response?: string, sourceNodes?: TextNode[]) {
constructor(response: string, sourceNodes?: BaseNode[]) {
this.response = response;
this.sourceNodes = sourceNodes || [];
}
+47 -15
View File
@@ -1,5 +1,5 @@
import { ChatGPTLLMPredictor } from "./LLMPredictor";
import { NodeWithScore } from "./Node";
import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./LLMPredictor";
import { MetadataMode, NodeWithScore } from "./Node";
import {
SimplePrompt,
defaultRefinePrompt,
@@ -8,28 +8,38 @@ import {
import { getBiggestPrompt } from "./PromptHelper";
import { Response } from "./Response";
import { ServiceContext } from "./ServiceContext";
import { Event } from "./callbacks/CallbackManager";
interface BaseResponseBuilder {
agetResponse(query: string, textChunks: string[]): Promise<string>;
agetResponse(
query: string,
textChunks: string[],
parentEvent?: Event
): Promise<string>;
}
export class SimpleResponseBuilder implements BaseResponseBuilder {
llmPredictor: ChatGPTLLMPredictor;
llmPredictor: BaseLLMPredictor;
textQATemplate: SimplePrompt;
constructor() {
this.llmPredictor = new ChatGPTLLMPredictor();
constructor(serviceContext?: ServiceContext) {
this.llmPredictor =
serviceContext?.llmPredictor ?? new ChatGPTLLMPredictor();
this.textQATemplate = defaultTextQaPrompt;
}
async agetResponse(query: string, textChunks: string[]): Promise<string> {
async agetResponse(
query: string,
textChunks: string[],
parentEvent?: Event
): Promise<string> {
const input = {
query,
context: textChunks.join("\n\n"),
};
const prompt = this.textQATemplate(input);
return this.llmPredictor.apredict(prompt, {});
return this.llmPredictor.apredict(prompt, {}, parentEvent);
}
}
@@ -178,20 +188,42 @@ export class TreeSummarize implements BaseResponseBuilder {
}
}
export function getResponseBuilder(): BaseResponseBuilder {
return new SimpleResponseBuilder();
export function getResponseBuilder(
serviceContext?: ServiceContext
): SimpleResponseBuilder {
return new SimpleResponseBuilder(serviceContext);
}
// TODO replace with Logan's new response_synthesizers/factory.py
export class ResponseSynthesizer {
responseBuilder: BaseResponseBuilder;
serviceContext?: ServiceContext;
constructor() {
this.responseBuilder = getResponseBuilder();
constructor({
responseBuilder,
serviceContext,
}: {
responseBuilder?: BaseResponseBuilder;
serviceContext?: ServiceContext;
} = {}) {
this.serviceContext = serviceContext;
this.responseBuilder =
responseBuilder ?? getResponseBuilder(this.serviceContext);
}
async asynthesize(query: string, nodes: NodeWithScore[]) {
let textChunks: string[] = nodes.map((node) => node.node.text);
const response = await this.responseBuilder.agetResponse(query, textChunks);
async asynthesize(
query: string,
nodes: NodeWithScore[],
parentEvent?: Event
) {
let textChunks: string[] = nodes.map((node) =>
node.node.getContent(MetadataMode.NONE)
);
const response = await this.responseBuilder.agetResponse(
query,
textChunks,
parentEvent
);
return new Response(
response,
nodes.map((node) => node.node)
+43 -19
View File
@@ -1,43 +1,67 @@
import { VectorStoreIndex } from "./BaseIndex";
import { BaseEmbedding, getTopKEmbeddings } from "./Embedding";
import { globalsHelper } from "./GlobalsHelper";
import { NodeWithScore } from "./Node";
import { ServiceContext } from "./ServiceContext";
import { Event } from "./callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
import {
VectorStoreQuery,
VectorStoreQueryMode,
} from "./storage/vectorStore/types";
export interface BaseRetriever {
aretrieve(query: string): Promise<any>;
aretrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
getServiceContext(): ServiceContext;
}
export class VectorIndexRetriever implements BaseRetriever {
index: VectorStoreIndex;
similarityTopK = DEFAULT_SIMILARITY_TOP_K;
embeddingService: BaseEmbedding;
private serviceContext: ServiceContext;
constructor(index: VectorStoreIndex, embeddingService: BaseEmbedding) {
constructor(index: VectorStoreIndex) {
this.index = index;
this.embeddingService = embeddingService;
this.serviceContext = this.index.serviceContext;
}
async aretrieve(query: string): Promise<NodeWithScore[]> {
const queryEmbedding = await this.embeddingService.aGetQueryEmbedding(
query
);
const [similarities, ids] = getTopKEmbeddings(
queryEmbedding,
this.index.nodes.map((node) => node.getEmbedding()),
undefined,
this.index.nodes.map((node) => node.id_)
);
async aretrieve(
query: string,
parentEvent?: Event
): Promise<NodeWithScore[]> {
const queryEmbedding =
await this.serviceContext.embedModel.aGetQueryEmbedding(query);
const q: VectorStoreQuery = {
queryEmbedding: queryEmbedding,
mode: VectorStoreQueryMode.DEFAULT,
similarityTopK: this.similarityTopK,
};
const result = this.index.vectorStore.query(q);
let nodesWithScores: NodeWithScore[] = [];
for (let i = 0; i < ids.length; i++) {
const node = this.index.indexStruct.nodesDict[ids[i]];
for (let i = 0; i < result.ids.length; i++) {
const node = this.index.indexStruct.nodesDict[result.ids[i]];
nodesWithScores.push({
node: node,
score: similarities[i],
score: result.similarities[i],
});
}
if (this.serviceContext.callbackManager.onRetrieve) {
this.serviceContext.callbackManager.onRetrieve({
query,
nodes: nodesWithScores,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
return nodesWithScores;
}
getServiceContext(): ServiceContext {
return this.serviceContext;
}
}
+23 -9
View File
@@ -1,35 +1,46 @@
import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
import { BaseLanguageModel } from "./LanguageModel";
import { ChatOpenAI } from "./LanguageModel";
import { NodeParser, SimpleNodeParser } from "./NodeParser";
import { PromptHelper } from "./PromptHelper";
import { CallbackManager } from "./callbacks/CallbackManager";
export interface ServiceContext {
llmPredictor: BaseLLMPredictor;
promptHelper: PromptHelper;
embedModel: BaseEmbedding;
nodeParser: NodeParser;
callbackManager: CallbackManager;
// llamaLogger: any;
// callbackManager: any;
}
export interface ServiceContextOptions {
llmPredictor?: BaseLLMPredictor;
llm?: BaseLanguageModel;
llm?: ChatOpenAI;
promptHelper?: PromptHelper;
embedModel?: BaseEmbedding;
nodeParser?: NodeParser;
callbackManager?: CallbackManager;
// NodeParser arguments
chunkSize?: number;
chunkOverlap: number;
chunkOverlap?: number;
}
export function serviceContextFromDefaults(options: ServiceContextOptions) {
export function serviceContextFromDefaults(options?: ServiceContextOptions) {
const callbackManager = options?.callbackManager ?? new CallbackManager();
const serviceContext: ServiceContext = {
llmPredictor: options.llmPredictor ?? new ChatGPTLLMPredictor(),
embedModel: options.embedModel ?? new OpenAIEmbedding(),
nodeParser: options.nodeParser ?? new SimpleNodeParser(),
promptHelper: options.promptHelper ?? new PromptHelper(),
llmPredictor:
options?.llmPredictor ??
new ChatGPTLLMPredictor({ callbackManager, languageModel: options?.llm }),
embedModel: options?.embedModel ?? new OpenAIEmbedding(),
nodeParser:
options?.nodeParser ??
new SimpleNodeParser({
chunkSize: options?.chunkSize,
chunkOverlap: options?.chunkOverlap,
}),
promptHelper: options?.promptHelper ?? new PromptHelper(),
callbackManager,
};
return serviceContext;
@@ -52,5 +63,8 @@ export function serviceContextFromServiceContext(
if (options.nodeParser) {
newServiceContext.nodeParser = options.nodeParser;
}
if (options.callbackManager) {
newServiceContext.callbackManager = options.callbackManager;
}
return newServiceContext;
}
+14
View File
@@ -0,0 +1,14 @@
import { BaseQueryEngine } from "./QueryEngine";
/** Describes a tool: its name and what it can be used for. */
export interface ToolMetadata {
/** Free-text description of what the tool provides. */
description: string;
/** Name that identifies the tool. */
name: string;
}
/** Anything that can be offered as a tool; only metadata is required. */
export interface BaseTool {
metadata: ToolMetadata;
}
/** A tool backed by a query engine. */
export interface QueryEngineTool extends BaseTool {
queryEngine: BaseQueryEngine;
}
@@ -0,0 +1,72 @@
import { ChatCompletionResponseMessageRoleEnum } from "openai";
import { NodeWithScore } from "../Node";
/*
An event is a wrapper that groups related operations.
For example, during retrieve and synthesize,
a parent event wraps both operations, and each operation has its own
event. In this case, both sub-events will share a parentId.
*/
/** Marks an event as belonging to an intermediate step or to the final step. */
export type EventTag = "intermediate" | "final";
/** Kind of operation an event describes; "wrapper" events group other events. */
export type EventType = "retrieve" | "llmPredict" | "wrapper";
/** Identifies one tracked operation and, via parentId, the event that groups it. */
export interface Event {
/** Identifier of this event. */
id: string;
type: EventType;
tags?: EventTag[];
/** Id of the event this one is grouped under, if any. */
parentId?: string;
}
/** Fields shared by every callback payload. */
interface BaseCallbackResponse {
/** The event the callback invocation belongs to. */
event: Event;
}
/**
 * One parsed chunk of a streamed chat completion.
 * NOTE(review): the field layout looks like OpenAI's streaming chunk format - confirm against the API reference.
 */
export interface StreamToken {
id: string;
object: string;
created: number;
model: string;
choices: {
index: number;
delta: {
content?: string;
role?: ChatCompletionResponseMessageRoleEnum;
};
finish_reason: string | null;
}[];
}
/** Payload passed to onLLMStream. */
export interface StreamCallbackResponse extends BaseCallbackResponse {
/** Position of this callback invocation within the stream. */
index: number;
/** Set on the last invocation, which carries no token. */
isDone?: boolean;
token?: StreamToken;
}
/** Payload passed to onRetrieve. */
export interface RetrievalCallbackResponse extends BaseCallbackResponse {
/** The query that was retrieved for. */
query: string;
/** The retrieved nodes with their scores. */
nodes: NodeWithScore[];
}
/** The optional callbacks a CallbackManager can carry. */
interface CallbackManagerMethods {
/*
onLLMStream is called when a token is streamed from the LLM. Defining this
callback automatically sets the stream flag on the OpenAI createChatCompletion request.
*/
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
/*
onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
This callback allows you to handle the retrieved nodes even if the synthesizer
is still running.
*/
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
}
/**
 * Holds the optional callbacks that are invoked during retrieval and LLM
 * streaming. Callbacks that are not supplied stay undefined.
 */
export class CallbackManager implements CallbackManagerMethods {
  onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
  onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;

  /**
   * @param handlers - Callbacks to register; may be omitted entirely.
   */
  constructor(handlers?: CallbackManagerMethods) {
    const { onLLMStream, onRetrieve } = handlers ?? {};
    this.onLLMStream = onLLMStream;
    this.onRetrieve = onRetrieve;
  }
}
@@ -0,0 +1,64 @@
import { globalsHelper } from "../../GlobalsHelper";
import { StreamCallbackResponse, Event } from "../CallbackManager";
import { StreamToken } from "../CallbackManager";
/**
 * Consumes a streamed chat completion, reporting every token through
 * `onLLMStream` and returning the concatenated text.
 *
 * All callback invocations share one "llmPredict" event created under
 * `parentEvent`. After the stream ends, a final invocation with
 * `isDone: true` and no token is made.
 *
 * @param response - Response whose `data` is the stream to read.
 * @param onLLMStream - Receives each token and the final done notification.
 * @param parentEvent - Optional event to group the created event under.
 * @returns The text of all streamed content deltas, in order.
 */
export async function aHandleOpenAIStream({
  response,
  onLLMStream,
  parentEvent,
}: {
  response: any;
  onLLMStream: (data: StreamCallbackResponse) => void;
  parentEvent?: Event;
}): Promise<string> {
  const event = globalsHelper.createEvent({ parentEvent, type: "llmPredict" });

  let emitted = 0;
  let fullText = "";
  for await (const rawMessage of __astreamCompletion(response.data as any)) {
    const token: StreamToken = JSON.parse(rawMessage);
    const { content = "", role = "assistant" } = token?.choices[0]?.delta ?? {};

    // Leading tokens that only announce the assistant role carry no text and
    // are not reported.
    const isLeadingRoleOnlyToken =
      emitted === 0 && role === "assistant" && !content;
    if (isLeadingRoleOnlyToken) {
      continue;
    }

    fullText += content;
    onLLMStream?.({ event, index: emitted, token });
    emitted++;
  }

  onLLMStream?.({ event, index: emitted, isDone: true });
  return fullText;
}
/*
Turns the raw chunks of a streamed completion into the text that follows the
"data: " prefix of each line, by chaining __achunksToLines and __alinesToText.
sources:
- https://github.com/openai/openai-node/issues/18#issuecomment-1372047643
- https://github.com/openai/openai-node/issues/18#issuecomment-1595805163
*/
async function* __astreamCompletion(data: string[]) {
yield* __alinesToText(__achunksToLines(data));
}
/**
 * Strips the leading "data: " marker from each incoming line.
 *
 * Callers are expected to pass only lines that start with "data: "; the
 * first six characters are removed without being checked.
 *
 * @param linesAsync - Async iterable of "data: ..." lines.
 * @returns Async generator yielding the text after the marker.
 */
async function* __alinesToText(linesAsync: string | void | any) {
  // Spelled exactly like the marker the line splitter matches on, so the two
  // cannot drift apart.
  const prefix = "data: ";
  for await (const line of linesAsync) {
    yield line.substring(prefix.length);
  }
}
/**
 * Reassembles arbitrarily split stream chunks into complete lines and yields
 * those that start with "data: ".
 *
 * Text is buffered until a newline is seen; other lines (for example the
 * blank separators) are dropped. Iteration ends at the first
 * "data: [DONE]" line.
 *
 * @param chunksAsync - (Async) iterable of chunks; each chunk is a Buffer or
 * a value accepted by Buffer.from.
 * @returns Async generator yielding each "data: ..." line, right-trimmed.
 */
async function* __achunksToLines(chunksAsync: string[]) {
  // A streaming decoder keeps the bytes of a multi-byte character that is
  // split across two chunks, instead of decoding each half into U+FFFD.
  const decoder = new TextDecoder("utf-8");
  let previous = "";
  for await (const chunk of chunksAsync) {
    const bufferChunk = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
    previous += decoder.decode(bufferChunk, { stream: true });
    let eolIndex;
    while ((eolIndex = previous.indexOf("\n")) >= 0) {
      const line = previous.slice(0, eolIndex + 1).trimEnd();
      // Stop the whole generator: leaving only the inner loop would keep
      // buffering everything that arrives after the end marker.
      if (line === "data: [DONE]") return;
      if (line.startsWith("data: ")) yield line;
      previous = previous.slice(eolIndex + 1);
    }
  }
}
-35
View File
@@ -1,35 +0,0 @@
import _ from "lodash";
import { VectorStoreQueryMode } from "./storage/vectorStore/types";
/**
 * Placeholder for top-k embedding search.
 *
 * @throws Error - Always; the function is not implemented and ignores all arguments.
 */
export function getTopKEmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
similarityTopK?: number,
embeddingIds?: number[],
similarityCutoff?: number
): [number[], number[]] {
throw new Error("Not implemented");
}
/**
 * Placeholder for learner-based top-k embedding search; `queryMode` defaults
 * to VectorStoreQueryMode.SVM.
 *
 * @throws Error - Always; the function is not implemented and ignores all arguments.
 */
export function getTopKEmbeddingsLearner(
queryEmbedding: number[],
embeddings: number[][],
similarityTopK?: number,
embeddingIds?: number[],
queryMode: VectorStoreQueryMode = VectorStoreQueryMode.SVM
): [number[], number[]] {
throw new Error("Not implemented");
}
/**
 * Placeholder for MMR-based top-k embedding search.
 *
 * @throws Error - Always; the function is not implemented and ignores all arguments.
 */
export function getTopKMMREmbeddings(
queryEmbedding: number[],
embeddings: number[][],
similarityFn?: (queryEmbedding: number[], emb: number[]) => number,
similarityTopK?: number,
embeddingIds?: number[],
similarityCutoff?: number,
mmrThreshold?: number
): [number[], number[]] {
throw new Error("Not implemented");
}
+166
View File
@@ -0,0 +1,166 @@
import { BaseNode, Document } from "../../Node";
import { BaseIndex, BaseIndexInit, IndexList } from "../../BaseIndex";
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
import {
StorageContext,
storageContextFromDefaults,
} from "../../storage/StorageContext";
import { BaseRetriever } from "../../Retriever";
import { ListIndexRetriever } from "./ListIndexRetriever";
import {
ServiceContext,
serviceContextFromDefaults,
} from "../../ServiceContext";
import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
import _ from "lodash";
/** Selects which retriever ListIndex.asRetriever creates. */
export enum ListRetrieverMode {
// Plain retriever (ListIndexRetriever).
DEFAULT = "default",
// EMBEDDING = "embedding",
// LLM-based selection; asRetriever currently rejects this mode as not implemented.
LLM = "llm",
}
/**
 * Options for ListIndex.init. Exactly one of `nodes` and `indexStruct` must
 * be given; the contexts fall back to defaults when omitted.
 */
export interface ListIndexOptions {
/** Nodes to build a new index structure from. */
nodes?: BaseNode[];
/** An existing index structure to reuse. */
indexStruct?: IndexList;
serviceContext?: ServiceContext;
storageContext?: StorageContext;
}
/**
 * Index whose structure is a plain list of node ids (IndexList).
 */
export class ListIndex extends BaseIndex<IndexList> {
  constructor(init: BaseIndexInit<IndexList>) {
    super(init);
  }

  /**
   * Creates a ListIndex from either a list of nodes or an existing index
   * structure.
   *
   * @param options - See ListIndexOptions; missing contexts are created from
   * defaults.
   * @throws Error - If both, or neither, of `nodes` and `indexStruct` are given.
   */
  static async init(options: ListIndexOptions): Promise<ListIndex> {
    const storageContext =
      options.storageContext ?? (await storageContextFromDefaults({}));
    const serviceContext =
      options.serviceContext ?? serviceContextFromDefaults({});
    const { docStore, indexStore } = storageContext;

    let indexStruct: IndexList;
    if (options.indexStruct) {
      if (options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex with both nodes and indexStruct"
        );
      }
      indexStruct = options.indexStruct;
    } else {
      if (!options.nodes) {
        throw new Error(
          "Cannot initialize ListIndex without nodes or indexStruct"
        );
      }
      indexStruct = ListIndex._buildIndexFromNodes(options.nodes, docStore);
    }

    return new ListIndex({
      storageContext,
      serviceContext,
      docStore,
      indexStore,
      indexStruct,
    });
  }

  /**
   * Stores the documents, parses them into nodes with the service context's
   * node parser, and builds an index over those nodes.
   *
   * @param documents - Documents to index.
   * @param storageContext - Defaults to storageContextFromDefaults({}).
   * @param serviceContext - Defaults to serviceContextFromDefaults({}).
   */
  static async fromDocuments(
    documents: Document[],
    storageContext?: StorageContext,
    serviceContext?: ServiceContext
  ): Promise<ListIndex> {
    storageContext = storageContext ?? (await storageContextFromDefaults({}));
    serviceContext = serviceContext ?? serviceContextFromDefaults({});
    const docStore = storageContext.docStore;

    // NOTE(review): if addDocuments/setDocumentHash are asynchronous their
    // results are not awaited here - confirm against the doc store API.
    docStore.addDocuments(documents, true);
    for (const doc of documents) {
      docStore.setDocumentHash(doc.id_, doc.hash);
    }

    const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents);
    return ListIndex.init({
      nodes,
      storageContext,
      serviceContext,
    });
  }

  /**
   * Creates a retriever over this index.
   *
   * @param mode - Retriever flavor; only DEFAULT is supported here.
   * @throws Error - For the LLM mode (not implemented) and unknown modes.
   */
  asRetriever(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseRetriever {
    switch (mode) {
      case ListRetrieverMode.DEFAULT:
        return new ListIndexRetriever(this);
      case ListRetrieverMode.LLM:
        throw new Error(`Support for LLM retriever mode is not implemented`);
      default:
        throw new Error(`Unknown retriever mode: ${mode}`);
    }
  }

  /**
   * Creates a query engine backed by a retriever of the requested mode.
   *
   * @param mode - Passed on to asRetriever, so unsupported modes throw
   * instead of silently falling back to the default retriever.
   */
  asQueryEngine(
    mode: ListRetrieverMode = ListRetrieverMode.DEFAULT
  ): BaseQueryEngine {
    return new RetrieverQueryEngine(this.asRetriever(mode));
  }

  /**
   * Adds the nodes to the doc store and appends them to an index structure.
   *
   * @param nodes - Nodes to add.
   * @param docStore - Store that receives the nodes.
   * @param indexStruct - Structure to extend; a new IndexList is created when
   * omitted.
   * @returns The extended (or new) index structure.
   */
  static _buildIndexFromNodes(
    nodes: BaseNode[],
    docStore: BaseDocumentStore,
    indexStruct?: IndexList
  ): IndexList {
    indexStruct = indexStruct || new IndexList();

    docStore.addDocuments(nodes, true);
    for (const node of nodes) {
      indexStruct.addNode(node);
    }

    return indexStruct;
  }

  /** Appends the nodes to the index structure. */
  protected _insert(nodes: BaseNode[]): void {
    for (const node of nodes) {
      this.indexStruct.addNode(node);
    }
  }

  /** Removes every occurrence of `nodeId` from the index structure. */
  protected _deleteNode(nodeId: string): void {
    this.indexStruct.nodes = this.indexStruct.nodes.filter(
      (existingNodeId: string) => existingNodeId !== nodeId
    );
  }

  /**
   * Collects the RefDocInfo of the source document of every indexed node.
   *
   * @returns A map keyed by source node id. Nodes without a source node, or
   * whose source has no stored info, are skipped.
   */
  async getRefDocInfo(): Promise<Record<string, RefDocInfo>> {
    const nodeDocIds = this.indexStruct.nodes;
    const nodes = await this.docStore.getNodes(nodeDocIds);

    const refDocInfoMap: Record<string, RefDocInfo> = {};

    for (const node of nodes) {
      const refNode = node.sourceNode;
      if (_.isNil(refNode)) {
        continue;
      }

      const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId);

      if (_.isNil(refDocInfo)) {
        continue;
      }

      refDocInfoMap[refNode.nodeId] = refDocInfo;
    }

    return refDocInfoMap;
  }
}
// Legacy alias: the same type as ListIndex under its former name.
export type GPTListIndex = ListIndex;
@@ -0,0 +1,137 @@
import { BaseRetriever } from "../../Retriever";
import { NodeWithScore } from "../../Node";
import { ListIndex } from "./ListIndex";
import { ServiceContext } from "../../ServiceContext";
import {
NodeFormatterFunction,
ChoiceSelectParserFunction,
defaultFormatNodeBatchFn,
defaultParseChoiceSelectAnswerFn,
} from "./utils";
import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
import _ from "lodash";
import { globalsHelper } from "../../GlobalsHelper";
import { Event } from "../../callbacks/CallbackManager";
/**
 * Retriever for ListIndex that ignores the query content and returns every
 * node of the index, each with score 1.
 */
export class ListIndexRetriever implements BaseRetriever {
  index: ListIndex;

  constructor(index: ListIndex) {
    this.index = index;
  }

  /**
   * Fetches all nodes of the index and, when an onRetrieve callback is
   * registered on the service context, reports them under a "retrieve" event.
   *
   * @param query - Only forwarded to the callback.
   * @param parentEvent - Optional event to group the retrieve event under.
   * @returns All nodes of the index with score 1.
   */
  async aretrieve(
    query: string,
    parentEvent?: Event
  ): Promise<NodeWithScore[]> {
    const allNodes = await this.index.docStore.getNodes(
      this.index.indexStruct.nodes
    );
    const result = allNodes.map((node) => ({ node, score: 1 }));

    const { callbackManager } = this.index.serviceContext;
    if (callbackManager.onRetrieve) {
      const event = globalsHelper.createEvent({ parentEvent, type: "retrieve" });
      callbackManager.onRetrieve({ query, nodes: result, event });
    }

    return result;
  }

  /** Returns the service context of the underlying index. */
  getServiceContext(): ServiceContext {
    return this.index.serviceContext;
  }
}
/**
 * LLM retriever for ListIndex.
 *
 * Nodes are shown to the LLM in batches; for each batch the LLM names the
 * relevant documents and a relevance score, and only those nodes are
 * returned.
 */
export class ListIndexLLMRetriever implements BaseRetriever {
  index: ListIndex;
  choiceSelectPrompt: SimplePrompt;
  choiceBatchSize: number;
  formatNodeBatchFn: NodeFormatterFunction;
  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
  serviceContext: ServiceContext;

  /**
   * @param index - Index to retrieve from.
   * @param choiceSelectPrompt - Prompt used to ask for relevant documents;
   * defaults to defaultChoiceSelectPrompt.
   * @param choiceBatchSize - Number of nodes shown to the LLM per call.
   * @param formatNodeBatchFn - Renders a batch as numbered documents;
   * defaults to defaultFormatNodeBatchFn.
   * @param parseChoiceSelectAnswerFn - Parses the LLM answer into a map from
   * document number to score; defaults to defaultParseChoiceSelectAnswerFn.
   * @param serviceContext - Defaults to the index's service context.
   */
  constructor(
    index: ListIndex,
    choiceSelectPrompt?: SimplePrompt,
    choiceBatchSize: number = 10,
    formatNodeBatchFn?: NodeFormatterFunction,
    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
    serviceContext?: ServiceContext
  ) {
    this.index = index;
    this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt;
    this.choiceBatchSize = choiceBatchSize;
    this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn;
    this.parseChoiceSelectAnswerFn =
      parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn;
    this.serviceContext = serviceContext || index.serviceContext;
  }

  /**
   * Asks the LLM, batch by batch, which nodes are relevant to `query`.
   *
   * @param query - The question to retrieve for.
   * @param parentEvent - Optional event that the LLM calls and the retrieve
   * callback are grouped under.
   * @returns The selected nodes in index order, scored with the relevance the
   * LLM reported for them.
   */
  async aretrieve(
    query: string,
    parentEvent?: Event
  ): Promise<NodeWithScore[]> {
    const nodeIds = this.index.indexStruct.nodes;
    const results: NodeWithScore[] = [];

    for (
      let batchStart = 0;
      batchStart < nodeIds.length;
      batchStart += this.choiceBatchSize
    ) {
      const nodeIdsBatch = nodeIds.slice(
        batchStart,
        batchStart + this.choiceBatchSize
      );
      const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);

      const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
      const input = { context: fmtBatchStr, query: query };
      const rawResponse = await this.serviceContext.llmPredictor.apredict(
        this.choiceSelectPrompt,
        input,
        parentEvent
      );

      // parseResult is a map from doc number to relevance score
      const parseResult = this.parseChoiceSelectAnswerFn(
        rawResponse,
        nodesBatch.length
      );

      // Document numbers are 1-based positions within the batch, so node i of
      // the batch is "Document i + 1"; its score is looked up under that same
      // number.
      nodesBatch.forEach((node, position) => {
        const docNumber = position + 1;
        if (docNumber in parseResult) {
          results.push({ node, score: parseResult[docNumber] });
        }
      });
    }

    if (this.serviceContext.callbackManager.onRetrieve) {
      this.serviceContext.callbackManager.onRetrieve({
        query,
        nodes: results,
        event: globalsHelper.createEvent({
          parentEvent,
          type: "retrieve",
        }),
      });
    }

    return results;
  }

  /** Returns the service context this retriever uses. */
  getServiceContext(): ServiceContext {
    return this.serviceContext;
  }
}
+5
View File
@@ -0,0 +1,5 @@
export { ListIndex, ListRetrieverMode } from "./ListIndex";
export {
ListIndexRetriever,
ListIndexLLMRetriever,
} from "./ListIndexRetriever";
+73
View File
@@ -0,0 +1,73 @@
import { BaseNode, MetadataMode } from "../../Node";
import _ from "lodash";
/** Renders a batch of nodes as the text shown to the LLM. */
export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;

/**
 * Default batch formatter: each node becomes "Document N:" (N counted from 1
 * within the batch) followed by the node content on the next line; entries
 * are trimmed and separated by a blank line.
 *
 * @param summaryNodes - Nodes of one batch, in display order.
 * @returns The formatted batch, or "" for an empty batch.
 */
export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
  summaryNodes: BaseNode[]
): string => {
  const sections: string[] = [];
  summaryNodes.forEach((node, idx) => {
    const section = `\nDocument ${idx + 1}:\n${node.getContent(MetadataMode.LLM)}\n`;
    sections.push(section.trim());
  });
  return sections.join("\n\n");
};
// map from document number to its relevance score
export type ChoiceSelectParseResult = { [docNumber: number]: number };

/**
 * Parses an LLM choice-select answer.
 *
 * @param answer - Raw LLM answer text.
 * @param numChoices - Number of documents that were offered.
 * @param raiseErr - Throw on malformed input instead of skipping it.
 */
export type ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr?: boolean
) => ChoiceSelectParseResult;

/**
 * Default parser for answers made of lines like "Doc: 9, Relevance: 7".
 *
 * A line is accepted when it has exactly two comma-separated parts, each with
 * a value after a colon, the document number is an integer between 1 and
 * `numChoices`, and the relevance is a number. Accepted lines are stored as
 * document number -> relevance; if a document is listed twice, the last line
 * wins.
 *
 * @param answer - Raw LLM answer text, one choice per line.
 * @param numChoices - Number of documents that were offered to the LLM.
 * @param raiseErr - When true, the first unacceptable line throws; when false
 * (default) such lines are skipped.
 * @returns Map from document number to relevance score.
 * @throws Error - Only when `raiseErr` is true and a line is unacceptable.
 */
export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
  answer: string,
  numChoices: number,
  raiseErr: boolean = false
): ChoiceSelectParseResult => {
  const parseResult: ChoiceSelectParseResult = {};

  for (const line of answer.split("\n")) {
    // split the line into the answer number and relevance score portions
    const lineTokens = line.split(",");
    if (lineTokens.length !== 2) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
        );
      }
      continue;
    }

    // parse the answer number and relevance score (the text after the colon)
    const docNum = parseInt(lineTokens[0].split(":")[1]?.trim() ?? "", 10);
    const answerRelevance = parseFloat(
      lineTokens[1].split(":")[1]?.trim() ?? ""
    );
    if (Number.isNaN(docNum) || Number.isNaN(answerRelevance)) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
        );
      }
      continue;
    }

    // Only numbers that refer to an offered document are kept.
    if (docNum < 1 || docNum > numChoices) {
      if (raiseErr) {
        throw new Error(
          `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
        );
      }
      continue;
    }

    parseResult[docNum] = answerRelevance;
  }

  return parseResult;
};
+21 -18
View File
@@ -12,9 +12,9 @@ import {
} from "./constants";
export interface StorageContext {
docStore?: BaseDocumentStore;
indexStore?: BaseIndexStore;
vectorStore?: VectorStore;
docStore: BaseDocumentStore;
indexStore: BaseIndexStore;
vectorStore: VectorStore;
}
type BuilderParams = {
@@ -32,21 +32,24 @@ export async function storageContextFromDefaults({
persistDir,
fs,
}: BuilderParams): Promise<StorageContext> {
persistDir = persistDir || DEFAULT_PERSIST_DIR;
fs = fs || DEFAULT_FS;
docStore =
docStore ||
(await SimpleDocumentStore.fromPersistDir(
persistDir,
DEFAULT_NAMESPACE,
fs
));
indexStore =
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
vectorStore =
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
if (!persistDir) {
docStore = docStore || new SimpleDocumentStore();
indexStore = indexStore || new SimpleIndexStore();
vectorStore = vectorStore || new SimpleVectorStore();
} else {
fs = fs || DEFAULT_FS;
docStore =
docStore ||
(await SimpleDocumentStore.fromPersistDir(
persistDir,
DEFAULT_NAMESPACE,
fs
));
indexStore =
indexStore || (await SimpleIndexStore.fromPersistDir(persistDir, fs));
vectorStore =
vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir, fs));
}
return {
docStore,
@@ -77,7 +77,7 @@ export class KVDocumentStore extends BaseDocumentStore {
let json = await this.kvstore.get(docId, this.nodeCollection);
if (_.isNil(json)) {
if (raiseError) {
throw new Error(`doc_id ${docId} not found.`);
throw new Error(`docId ${docId} not found.`);
} else {
return;
}
+3 -5
View File
@@ -23,12 +23,10 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode {
hash: dataDict.hash,
});
} else if (docType === ObjectType.TEXT) {
const relationships = dataDict.relationships;
doc = new TextNode({
text: relationships.text,
id_: relationships.id_,
embedding: relationships.embedding,
hash: relationships.hash,
text: dataDict.text,
id_: dataDict.id_,
hash: dataDict.hash,
});
} else {
throw new Error(`Unknown doc type: ${docType}`);
@@ -12,7 +12,7 @@ import {
getTopKMMREmbeddings,
} from "../../Embedding";
import { DEFAULT_PERSIST_DIR, DEFAULT_FS } from "../constants";
import { TextNode } from "../../Node";
import { NodeWithEmbedding } from "../../Node";
const LEARNER_MODES = new Set<VectorStoreQueryMode>([
VectorStoreQueryMode.SVM,
@@ -53,18 +53,19 @@ export class SimpleVectorStore implements VectorStore {
return this.data.embeddingDict[textId];
}
add(embeddingResults: TextNode[]): string[] {
add(embeddingResults: NodeWithEmbedding[]): string[] {
for (let result of embeddingResults) {
this.data.embeddingDict[result.id_] = result.getEmbedding();
this.data.embeddingDict[result.node.id_] = result.embedding;
if (!result.sourceNode) {
if (!result.node.sourceNode) {
console.error("Missing source node from TextNode.");
continue;
}
this.data.textIdToRefDocId[result.id_] = result.sourceNode?.nodeId;
this.data.textIdToRefDocId[result.node.id_] =
result.node.sourceNode?.nodeId;
}
return embeddingResults.map((result) => result.id_);
return embeddingResults.map((result) => result.node.id_);
}
delete(refDocId: string): void {
+6 -13
View File
@@ -1,18 +1,11 @@
import { TextNode } from "../../Node";
import { BaseNode } from "../../Node";
import { GenericFileSystem } from "../FileSystem";
export interface NodeWithEmbedding {
node: TextNode;
embedding: number[];
id(): string;
refDocId(): string;
}
import { NodeWithEmbedding } from "../../Node";
export interface VectorStoreQueryResult {
nodes?: TextNode[];
similarities?: number[];
ids?: string[];
nodes?: BaseNode[];
similarities: number[];
ids: string[];
}
export enum VectorStoreQueryMode {
@@ -68,7 +61,7 @@ export interface VectorStore {
storesText: boolean;
isEmbeddingQuery?: boolean;
client(): any;
add(embeddingResults: TextNode[]): string[];
add(embeddingResults: NodeWithEmbedding[]): string[];
delete(refDocId: string, deleteKwargs?: any): void;
query(query: VectorStoreQuery, kwargs?: any): VectorStoreQueryResult;
persist(persistPath: string, fs?: GenericFileSystem): void;
@@ -0,0 +1,208 @@
import { VectorStoreIndex } from "../BaseIndex";
import { OpenAIEmbedding } from "../Embedding";
import { ChatOpenAI } from "../LanguageModel";
import { Document } from "../Node";
import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
import {
CallbackManager,
RetrievalCallbackResponse,
StreamCallbackResponse,
} from "../callbacks/CallbackManager";
import { ListIndex } from "../index/list";
import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI";
// Stub out the "../openai" module for this suite: getOpenAISession becomes a
// Jest mock function that always returns null.
jest.mock("../openai", () => ({
  getOpenAISession: jest.fn(() => null),
}));
describe("CallbackManager: onLLMStream and onRetrieve", () => {
  const QUERY = "What is the author's name?";
  const MOCK_RESPONSE = "MOCK_TOKEN_1-MOCK_TOKEN_2";

  let serviceContext: ServiceContext;
  let document: Document;
  // Payloads captured by the callbacks; reset before every test.
  let streamCallbackData: StreamCallbackResponse[] = [];
  let retrieveCallbackData: RetrievalCallbackResponse[] = [];

  /** Expected shape of the event envelope attached to a callback payload. */
  const eventMatcher = (type: string) => ({
    id: expect.any(String),
    parentId: expect.any(String),
    type,
    tags: ["final"],
  });

  /** Expected shape of a streamed token payload at the given position. */
  const tokenChunkMatcher = (index: number) => ({
    event: eventMatcher("llmPredict"),
    index,
    token: {
      id: "id",
      object: "object",
      created: 1,
      model: "model",
      choices: expect.any(Array),
    },
  });

  /**
   * Assertions shared by every index type: two streamed tokens followed by a
   * completion marker, one retrieval event, and a common parent event.
   */
  const expectCallbacksForQuery = () => {
    expect(streamCallbackData).toEqual([
      tokenChunkMatcher(0),
      tokenChunkMatcher(1),
      {
        event: eventMatcher("llmPredict"),
        index: 2,
        isDone: true,
      },
    ]);

    expect(retrieveCallbackData).toEqual([
      {
        query: QUERY,
        nodes: expect.any(Array),
        event: eventMatcher("retrieve"),
      },
    ]);

    // both retrieval and streaming should have
    // the same parent event
    expect(streamCallbackData[0].event.parentId).toBe(
      retrieveCallbackData[0].event.parentId
    );
  };

  beforeAll(async () => {
    document = new Document({ text: "Author: My name is Paul Graham" });

    const callbackManager = new CallbackManager({
      onLLMStream: (payload) => {
        streamCallbackData.push(payload);
      },
      onRetrieve: (payload) => {
        retrieveCallbackData.push(payload);
      },
    });

    const llm = new ChatOpenAI({
      model: "gpt-3.5-turbo",
      callbackManager,
    });
    mockLlmGeneration({ languageModel: llm, callbackManager });

    const embedModel = new OpenAIEmbedding();
    mockEmbeddingModel(embedModel);

    serviceContext = serviceContextFromDefaults({
      callbackManager,
      llm,
      embedModel,
    });
  });

  beforeEach(() => {
    streamCallbackData = [];
    retrieveCallbackData = [];
  });

  afterAll(() => {
    jest.clearAllMocks();
  });

  test("For VectorStoreIndex w/ a SimpleResponseBuilder", async () => {
    const index = await VectorStoreIndex.fromDocuments(
      [document],
      undefined,
      serviceContext
    );
    const response = await index.asQueryEngine().aquery(QUERY);

    expect(response.toString()).toBe(MOCK_RESPONSE);
    expectCallbacksForQuery();
  });

  test("For ListIndex w/ a ListIndexRetriever", async () => {
    const index = await ListIndex.fromDocuments(
      [document],
      undefined,
      serviceContext
    );
    const response = await index.asQueryEngine().aquery(QUERY);

    expect(response.toString()).toBe(MOCK_RESPONSE);
    expectCallbacksForQuery();
  });
});
@@ -0,0 +1,72 @@
import { OpenAIEmbedding } from "../../Embedding";
import { globalsHelper } from "../../GlobalsHelper";
import { BaseMessage, ChatOpenAI } from "../../LanguageModel";
import { CallbackManager, Event } from "../../callbacks/CallbackManager";
/**
 * Replaces `languageModel.agenerate` with a deterministic Jest spy.
 *
 * The spy creates an "llmPredict" event (child of `parentEvent`, if given)
 * and, when `callbackManager.onLLMStream` is set, emits one stream payload per
 * "-"-separated chunk of the fixed text "MOCK_TOKEN_1-MOCK_TOKEN_2", followed
 * by a final payload with `isDone: true`. It then resolves with that full text
 * as the single generation.
 *
 * @param languageModel - model instance whose `agenerate` is spied on
 * @param callbackManager - receives the simulated streaming callbacks
 */
export function mockLlmGeneration({
  languageModel,
  callbackManager,
}: {
  languageModel: ChatOpenAI;
  callbackManager: CallbackManager;
}) {
  jest
    .spyOn(languageModel, "agenerate")
    .mockImplementation(
      async (messages: BaseMessage[], parentEvent?: Event) => {
        const text = "MOCK_TOKEN_1-MOCK_TOKEN_2";
        const event = globalsHelper.createEvent({
          parentEvent,
          type: "llmPredict",
        });

        if (callbackManager?.onLLMStream) {
          const chunks = text.split("-");
          for (let i = 0; i < chunks.length; i++) {
            const chunk = chunks[i];
            callbackManager?.onLLMStream({
              event,
              index: i,
              token: {
                id: "id",
                object: "object",
                created: 1,
                model: "model",
                choices: [
                  {
                    index: 0,
                    delta: {
                      content: chunk,
                    },
                    finish_reason: null,
                  },
                ],
              },
            });
          }
          // Completion marker: its index is one past the last token chunk.
          callbackManager?.onLLMStream({
            event,
            index: chunks.length,
            isDone: true,
          });
        }

        // An async function already wraps its return value in a promise, so
        // no explicit `new Promise` is needed here.
        return {
          generations: [[{ text }]],
        };
      }
    );
}
/**
 * Replaces the async embedding methods of `embedModel` with Jest spies that
 * resolve to fixed 6-element vectors, ignoring their input.
 *
 * `aGetTextEmbedding` resolves to `[1, 0, 0, 0, 0, 0]` and
 * `aGetQueryEmbedding` resolves to `[0, 1, 0, 0, 0, 0]`. A fresh array is
 * created on every call.
 *
 * @param embedModel - embedding model instance whose methods are spied on
 */
export function mockEmbeddingModel(embedModel: OpenAIEmbedding) {
  // The async callbacks return the vectors directly; wrapping them in
  // `new Promise` would be redundant.
  jest
    .spyOn(embedModel, "aGetTextEmbedding")
    .mockImplementation(async () => [1, 0, 0, 0, 0, 0]);
  jest
    .spyOn(embedModel, "aGetQueryEmbedding")
    .mockImplementation(async () => [0, 1, 0, 0, 0, 0]);
}
+22 -22
View File
@@ -36,8 +36,8 @@ importers:
specifier: ^29.1.0
version: 29.1.0(@babel/core@7.22.5)(jest@29.5.0)(typescript@4.9.5)
turbo:
specifier: ^1.10.5
version: 1.10.5
specifier: ^1.10.6
version: 1.10.7
apps/docs:
dependencies:
@@ -5444,65 +5444,65 @@ packages:
resolution: {integrity: sha512-C3TaO7K81YvjCgQH9Q1S3R3P3BtN3RIM8n+OvX4il1K1zgE8ZhI0op7kClgkxtutIE8hQrcrHBXvIheqKUUCxw==}
dev: true
/turbo-darwin-64@1.10.5:
resolution: {integrity: sha512-fIHu+fcW7upaZEfeneoRbZjdrcsj/NxUg7IjZZmlCjgbS9Ofl8RhRid5A1L31AUK3kkqRxzagHc4WZ5x4quBgg==}
/turbo-darwin-64@1.10.7:
resolution: {integrity: sha512-N2MNuhwrl6g7vGuz4y3fFG2aR1oCs0UZ5HKl8KSTn/VC2y2YIuLGedQ3OVbo0TfEvygAlF3QGAAKKtOCmGPNKA==}
cpu: [x64]
os: [darwin]
requiresBuild: true
dev: true
optional: true
/turbo-darwin-arm64@1.10.5:
resolution: {integrity: sha512-uv0sDWizuxVvdSjaKvWdPdX4aZ8IZeYJwTJRZwLNRxZV56/1LZD65gyQIqsSNVRHuXI199yahmB+7PMJNpZFdw==}
/turbo-darwin-arm64@1.10.7:
resolution: {integrity: sha512-WbJkvjU+6qkngp7K4EsswOriO3xrNQag7YEGRtfLoDdMTk4O4QTeU6sfg2dKfDsBpTidTvEDwgIYJhYVGzrz9Q==}
cpu: [arm64]
os: [darwin]
requiresBuild: true
dev: true
optional: true
/turbo-linux-64@1.10.5:
resolution: {integrity: sha512-hI0rErgwxNmuBCNGldhJkjSbb+mT+vjfmBVKcMI/bnBmu/KU7irCrKMe5Vas280teqBrC33GgVfXndJo2cJ1DA==}
/turbo-linux-64@1.10.7:
resolution: {integrity: sha512-x1CF2CDP1pDz/J8/B2T0hnmmOQI2+y11JGIzNP0KtwxDM7rmeg3DDTtDM/9PwGqfPotN9iVGgMiMvBuMFbsLhg==}
cpu: [x64]
os: [linux]
requiresBuild: true
dev: true
optional: true
/turbo-linux-arm64@1.10.5:
resolution: {integrity: sha512-JAygWZjTuD6e7w0KSGzy7UxYqeLIpGfZDne+4MGRc8I5VeWZ6i0HWTqhhIu2/A8AuklYcoj8LkOZxCnMOF3odQ==}
/turbo-linux-arm64@1.10.7:
resolution: {integrity: sha512-JtnBmaBSYbs7peJPkXzXxsRGSGBmBEIb6/kC8RRmyvPAMyqF8wIex0pttsI+9plghREiGPtRWv/lfQEPRlXnNQ==}
cpu: [arm64]
os: [linux]
requiresBuild: true
dev: true
optional: true
/turbo-windows-64@1.10.5:
resolution: {integrity: sha512-6w2GOKmlWEAl6QkC4c2j2ZLTwB+RK6oIDRT2KqF1m07KkY6pebEzbPZLHuP08QV+SE0t+prAn+kn7hkHYkwM+Q==}
/turbo-windows-64@1.10.7:
resolution: {integrity: sha512-7A/4CByoHdolWS8dg3DPm99owfu1aY/W0V0+KxFd0o2JQMTQtoBgIMSvZesXaWM57z3OLsietFivDLQPuzE75w==}
cpu: [x64]
os: [win32]
requiresBuild: true
dev: true
optional: true
/turbo-windows-arm64@1.10.5:
resolution: {integrity: sha512-3eeHRJPU+5zWa/iiikoBoPlNd74Y+L9lrG6ZsDZdzUYxNRTMrZbto1Bu1UF77t10TXeT9BsZRXjquKqrA7R7tg==}
/turbo-windows-arm64@1.10.7:
resolution: {integrity: sha512-D36K/3b6+hqm9IBAymnuVgyePktwQ+F0lSXr2B9JfAdFPBktSqGmp50JNC7pahxhnuCLj0Vdpe9RqfnJw5zATA==}
cpu: [arm64]
os: [win32]
requiresBuild: true
dev: true
optional: true
/turbo@1.10.5:
resolution: {integrity: sha512-4yxHTrlugJhef4eXuyrPJtrgUZWlbcwmSb8iZL/5UzNjCmx+anOm1nfW2XFrZFKy4v0+/fUlqw8LkTgGVsOKaQ==}
/turbo@1.10.7:
resolution: {integrity: sha512-xm0MPM28TWx1e6TNC3wokfE5eaDqlfi0G24kmeHupDUZt5Wd0OzHFENEHMPqEaNKJ0I+AMObL6nbSZonZBV2HA==}
hasBin: true
requiresBuild: true
optionalDependencies:
turbo-darwin-64: 1.10.5
turbo-darwin-arm64: 1.10.5
turbo-linux-64: 1.10.5
turbo-linux-arm64: 1.10.5
turbo-windows-64: 1.10.5
turbo-windows-arm64: 1.10.5
turbo-darwin-64: 1.10.7
turbo-darwin-arm64: 1.10.7
turbo-linux-64: 1.10.7
turbo-linux-arm64: 1.10.7
turbo-windows-64: 1.10.7
turbo-windows-arm64: 1.10.7
dev: true
/type-check@0.4.0: