feat: improve callback manager (#675)

This commit is contained in:
Alex Yang
2024-03-31 15:34:48 -05:00
committed by GitHub
parent 041acd11fe
commit 7a23cc6c84
43 changed files with 938 additions and 955 deletions
+7
View File
@@ -0,0 +1,7 @@
---
"llamaindex": patch
"@llamaindex/env": patch
"@llamaindex/edge": patch
---
feat: improve CallbackManager
+6 -2
View File
@@ -1,9 +1,13 @@
{
"jsc": {
"parser": {
"syntax": "typescript"
"syntax": "typescript",
"decorators": true
},
"target": "esnext"
"target": "esnext",
"transform": {
"decoratorVersion": "2022-03"
}
},
"module": {
"type": "commonjs",
+6 -2
View File
@@ -1,8 +1,12 @@
{
"jsc": {
"parser": {
"syntax": "typescript"
"syntax": "typescript",
"decorators": true
},
"target": "esnext"
"target": "esnext",
"transform": {
"decoratorVersion": "2022-03"
}
}
}
+2
View File
@@ -0,0 +1,2 @@
label: Recipes
position: 3
+14
View File
@@ -0,0 +1,14 @@
# Cost Analysis
This page shows how to track LLM cost using APIs.
## Callback Manager
The callback manager is a class that manages the callback functions.
You can register `llm-start` and `llm-end` handlers on the callback manager to track the cost.
import CodeBlock from "@theme/CodeBlock";
import CodeSource from "!raw-loader!../../../../examples/recipes/cost-analysis";
<CodeBlock language="ts">{CodeSource}</CodeBlock>
+9 -9
View File
@@ -15,9 +15,9 @@
"typecheck": "tsc"
},
"dependencies": {
"@docusaurus/core": "^3.1.1",
"@docusaurus/core": "^3.2.0",
"@llamaindex/env": "workspace:*",
"@docusaurus/remark-plugin-npm2yarn": "^3.1.1",
"@docusaurus/remark-plugin-npm2yarn": "^3.2.0",
"@mdx-js/react": "^3.0.0",
"clsx": "^2.1.0",
"postcss": "^8.4.33",
@@ -27,16 +27,16 @@
"react-dom": "^18.2.0"
},
"devDependencies": {
"@docusaurus/module-type-aliases": "3.1.0",
"@docusaurus/preset-classic": "^3.1.1",
"@docusaurus/theme-classic": "^3.1.1",
"@docusaurus/types": "^3.1.1",
"@tsconfig/docusaurus": "^2.0.2",
"@docusaurus/module-type-aliases": "3.2.0",
"@docusaurus/preset-classic": "^3.2.0",
"@docusaurus/theme-classic": "^3.2.0",
"@docusaurus/types": "^3.2.0",
"@tsconfig/docusaurus": "^2.0.3",
"@types/node": "^18.19.10",
"docusaurus-plugin-typedoc": "^0.22.0",
"typedoc": "^0.25.7",
"typedoc": "^0.25.12",
"typedoc-plugin-markdown": "^3.17.1",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
},
"browserslist": {
"production": [
+6 -4
View File
@@ -7,6 +7,7 @@ import {
OpenAI,
ServiceContext,
VectorStoreIndex,
runWithCallbackManager,
serviceContextFromDefaults,
storageContextFromDefaults,
} from "llamaindex";
@@ -38,7 +39,6 @@ async function main() {
llm,
chunkSize: 512,
chunkOverlap: 20,
callbackManager,
});
const index = await createIndex(serviceContext);
@@ -46,9 +46,11 @@ async function main() {
responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
});
const result = await queryEngine.query({
query: "Tell me more about Vincent van Gogh's famous paintings",
});
const result = await runWithCallbackManager(callbackManager, () =>
queryEngine.query({
query: "Tell me more about Vincent van Gogh's famous paintings",
}),
);
console.log(result.response, "\n");
images.forEach((image) =>
console.log(`Image retrieved and used in inference: ${image.toString()}`),
+1 -1
View File
@@ -18,7 +18,7 @@
"devDependencies": {
"@types/node": "^18.19.10",
"ts-node": "^10.9.2",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
},
"scripts": {
"lint": "eslint ."
+15 -12
View File
@@ -5,6 +5,7 @@ import {
MetadataMode,
QdrantVectorStore,
VectorStoreIndex,
runWithCallbackManager,
serviceContextFromDefaults,
storageContextFromDefaults,
} from "llamaindex";
@@ -36,19 +37,21 @@ async function main() {
const ctx = await storageContextFromDefaults({ vectorStore: qdrantVs });
console.log("Embedding documents and adding to index");
const index = await VectorStoreIndex.fromDocuments(docs, {
storageContext: ctx,
serviceContext: serviceContextFromDefaults({
callbackManager: new CallbackManager({
onRetrieve: (data) => {
console.log(
"The retrieved nodes are:",
data.nodes.map((node) => node.node.getContent(MetadataMode.NONE)),
);
},
}),
const index = await runWithCallbackManager(
new CallbackManager({
onRetrieve: (data) => {
console.log(
"The retrieved nodes are:",
data.nodes.map((node) => node.node.getContent(MetadataMode.NONE)),
);
},
}),
});
() =>
VectorStoreIndex.fromDocuments(docs, {
storageContext: ctx,
serviceContext: serviceContextFromDefaults(),
}),
);
console.log(
"Querying index with no filters: Expected output: Brown probably",
+1 -1
View File
@@ -17,6 +17,6 @@
"devDependencies": {
"@types/node": "^20.11.14",
"ts-node": "^10.9.2",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
}
}
+37
View File
@@ -0,0 +1,37 @@
import { getCurrentCallbackManager } from "llamaindex/callbacks/CallbackManager";
import { OpenAI } from "llamaindex/llm";
// Model used both for the chat request and for counting prompt tokens.
const llm = new OpenAI({
  model: "gpt-4-0125-preview",
});

// Running total of prompt tokens over every `llm-start` event seen so far.
let tokenCount = 0;

// @todo: use GlobalSetting in the future
// This script never enters `runWithCallbackManager`, so the handler below is
// registered on whatever `getCurrentCallbackManager()` returns at top level.
getCurrentCallbackManager().addHandlers("llm-start", (event) => {
  const { messages } = event.detail.payload;
  tokenCount += llm.tokens(messages);
  console.log("Token count:", tokenCount);
  // https://openai.com/pricing
  // $10.00 / 1M tokens
  console.log(`Price: $${(tokenCount / 1000000) * 10}`);
});

const question = "Hello, how are you?";
console.log("Question:", question);
llm
  .chat({
    stream: true,
    messages: [
      {
        content: question,
        role: "user",
      },
    ],
  })
  .then(async (iter) => {
    console.log("Response:");
    for await (const chunk of iter) {
      process.stdout.write(chunk.delta);
    }
  })
  .catch((error: unknown) => {
    // Report request/stream failures instead of leaving an unhandled rejection.
    console.error("Chat request failed:", error);
    process.exitCode = 1;
  });
+1 -1
View File
@@ -27,7 +27,7 @@
"prettier": "^3.2.5",
"prettier-plugin-organize-imports": "^3.2.4",
"turbo": "^1.12.3",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
},
"packageManager": "pnpm@8.15.1",
"pnpm": {
+16 -1
View File
@@ -12,7 +12,7 @@ export enum Tokenizers {
}
/**
* Helper class singleton
* @internal Helper class singleton
*/
class GlobalsHelper {
defaultTokenizer: {
@@ -58,6 +58,21 @@ class GlobalsHelper {
return this.defaultTokenizer!.decode.bind(this.defaultTokenizer);
}
/**
* @deprecated createEvent will be removed in the future,
* please use `new CustomEvent(eventType, { detail: payload })` instead.
*
* Also, `parentEvent` will not be used in the future,
* use `AsyncLocalStorage` to track parent events instead.
* @example - Usage of `AsyncLocalStorage`:
* let id = 0;
* const asyncLocalStorage = new AsyncLocalStorage<number>();
* asyncLocalStorage.run(++id, async () => {
* setTimeout(() => {
* console.log('parent event id:', asyncLocalStorage.getStore()); // 1
* }, 1000)
* });
*/
createEvent({
parentEvent,
type,
+3
View File
@@ -4,6 +4,9 @@ import type { ServiceContext } from "./ServiceContext.js";
export type RetrieveParams = {
query: string;
/**
* @deprecated will be removed in the next major version
*/
parentEvent?: Event;
preFilters?: unknown;
};
-8
View File
@@ -1,5 +1,4 @@
import { PromptHelper } from "./PromptHelper.js";
import { CallbackManager } from "./callbacks/CallbackManager.js";
import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js";
import type { BaseEmbedding } from "./embeddings/types.js";
import type { LLM } from "./llm/index.js";
@@ -15,7 +14,6 @@ export interface ServiceContext {
promptHelper: PromptHelper;
embedModel: BaseEmbedding;
nodeParser: NodeParser;
callbackManager: CallbackManager;
// llamaLogger: any;
}
@@ -24,14 +22,12 @@ export interface ServiceContextOptions {
promptHelper?: PromptHelper;
embedModel?: BaseEmbedding;
nodeParser?: NodeParser;
callbackManager?: CallbackManager;
// NodeParser arguments
chunkSize?: number;
chunkOverlap?: number;
}
export function serviceContextFromDefaults(options?: ServiceContextOptions) {
const callbackManager = options?.callbackManager ?? new CallbackManager();
const serviceContext: ServiceContext = {
llm: options?.llm ?? new OpenAI(),
embedModel: options?.embedModel ?? new OpenAIEmbedding(),
@@ -42,7 +38,6 @@ export function serviceContextFromDefaults(options?: ServiceContextOptions) {
chunkOverlap: options?.chunkOverlap,
}),
promptHelper: options?.promptHelper ?? new PromptHelper(),
callbackManager,
};
return serviceContext;
@@ -65,8 +60,5 @@ export function serviceContextFromServiceContext(
if (options.nodeParser) {
newServiceContext.nodeParser = options.nodeParser;
}
if (options.callbackManager) {
newServiceContext.callbackManager = options.callbackManager;
}
return newServiceContext;
}
-1
View File
@@ -100,7 +100,6 @@ export class SentenceSplitter {
}
this.chunkSize = chunkSize;
this.chunkOverlap = chunkOverlap;
// this._callback_manager = callback_manager || new CallbackManager([]);
this.tokenizer = tokenizer ?? globalsHelper.tokenizer();
this.tokenizerDecoder =
-5
View File
@@ -1,4 +1,3 @@
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
import type { ChatMessage } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import type { ObjectRetriever } from "../../objects/base.js";
@@ -14,7 +13,6 @@ type OpenAIAgentParams = {
verbose?: boolean;
maxFunctionCalls?: number;
defaultToolChoice?: string;
callbackManager?: CallbackManager;
toolRetriever?: ObjectRetriever;
systemPrompt?: string;
};
@@ -33,7 +31,6 @@ export class OpenAIAgent extends AgentRunner {
verbose,
maxFunctionCalls = 5,
defaultToolChoice = "auto",
callbackManager,
toolRetriever,
systemPrompt,
}: OpenAIAgentParams) {
@@ -58,7 +55,6 @@ export class OpenAIAgent extends AgentRunner {
const stepEngine = new OpenAIAgentWorker({
tools,
callbackManager,
llm,
prefixMessages,
maxFunctionCalls,
@@ -69,7 +65,6 @@ export class OpenAIAgent extends AgentRunner {
super({
agentWorker: stepEngine,
memory,
callbackManager,
defaultToolChoice,
chatHistory: prefixMessages,
});
-5
View File
@@ -1,6 +1,5 @@
import { randomUUID } from "@llamaindex/env";
import { Response } from "../../Response.js";
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
import {
AgentChatResponse,
ChatResponseMode,
@@ -79,7 +78,6 @@ type OpenAIAgentWorkerParams = {
prefixMessages?: ChatMessage[];
verbose?: boolean;
maxFunctionCalls?: number;
callbackManager?: CallbackManager | undefined;
toolRetriever?: ObjectRetriever;
};
@@ -98,7 +96,6 @@ export class OpenAIAgentWorker implements AgentWorker {
private maxFunctionCalls: number;
public prefixMessages: ChatMessage[];
public callbackManager: CallbackManager | undefined;
private _getTools: (input: string) => Promise<BaseTool[]>;
@@ -111,14 +108,12 @@ export class OpenAIAgentWorker implements AgentWorker {
prefixMessages,
verbose,
maxFunctionCalls = DEFAULT_MAX_FUNCTION_CALLS,
callbackManager,
toolRetriever,
}: OpenAIAgentWorkerParams) {
this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
this.verbose = verbose || false;
this.maxFunctionCalls = maxFunctionCalls;
this.prefixMessages = prefixMessages || [];
this.callbackManager = callbackManager || this.llm.callbackManager;
if (tools.length > 0 && toolRetriever) {
throw new Error("Cannot specify both tools and tool_retriever");
-5
View File
@@ -1,4 +1,3 @@
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
import type { ChatMessage, LLM } from "../../llm/index.js";
import type { ObjectRetriever } from "../../objects/base.js";
import type { BaseTool } from "../../types.js";
@@ -13,7 +12,6 @@ type ReActAgentParams = {
verbose?: boolean;
maxInteractions?: number;
defaultToolChoice?: string;
callbackManager?: CallbackManager;
toolRetriever?: ObjectRetriever;
};
@@ -31,12 +29,10 @@ export class ReActAgent extends AgentRunner {
verbose,
maxInteractions = 10,
defaultToolChoice = "auto",
callbackManager,
toolRetriever,
}: Partial<ReActAgentParams>) {
const stepEngine = new ReActAgentWorker({
tools: tools ?? [],
callbackManager,
llm,
maxInteractions,
toolRetriever,
@@ -46,7 +42,6 @@ export class ReActAgent extends AgentRunner {
super({
agentWorker: stepEngine,
memory,
callbackManager,
defaultToolChoice,
chatHistory: prefixMessages,
});
-6
View File
@@ -1,5 +1,4 @@
import { randomUUID } from "@llamaindex/env";
import { CallbackManager } from "../../callbacks/CallbackManager.js";
import { AgentChatResponse } from "../../engines/chat/index.js";
import type { ChatResponse, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
@@ -23,7 +22,6 @@ type ReActAgentWorkerParams = {
maxInteractions?: number;
reactChatFormatter?: ReActChatFormatter | undefined;
outputParser?: ReActOutputParser | undefined;
callbackManager?: CallbackManager | undefined;
verbose?: boolean | undefined;
toolRetriever?: ObjectRetriever | undefined;
};
@@ -69,8 +67,6 @@ export class ReActAgentWorker implements AgentWorker {
reactChatFormatter: ReActChatFormatter;
outputParser: ReActOutputParser;
callbackManager: CallbackManager;
_getTools: (message: string) => Promise<BaseTool[]>;
constructor({
@@ -79,12 +75,10 @@ export class ReActAgentWorker implements AgentWorker {
maxInteractions,
reactChatFormatter,
outputParser,
callbackManager,
verbose,
toolRetriever,
}: ReActAgentWorkerParams) {
this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
this.callbackManager = callbackManager || new CallbackManager();
this.maxInteractions = maxInteractions ?? 10;
this.reactChatFormatter = reactChatFormatter ?? new ReActChatFormatter();
-4
View File
@@ -1,5 +1,4 @@
import { randomUUID } from "@llamaindex/env";
import { CallbackManager } from "../../callbacks/CallbackManager.js";
import type { ChatEngineAgentParams } from "../../engines/chat/index.js";
import {
AgentChatResponse,
@@ -35,7 +34,6 @@ type AgentRunnerParams = {
state?: AgentState;
memory?: BaseMemory;
llm?: LLM;
callbackManager?: CallbackManager;
initTaskStateKwargs?: Record<string, any>;
deleteTaskOnFinish?: boolean;
defaultToolChoice?: string;
@@ -45,7 +43,6 @@ export class AgentRunner extends BaseAgentRunner {
agentWorker: AgentWorker;
state: AgentState;
memory: BaseMemory;
callbackManager: CallbackManager;
initTaskStateKwargs: Record<string, any>;
deleteTaskOnFinish: boolean;
defaultToolChoice: string;
@@ -63,7 +60,6 @@ export class AgentRunner extends BaseAgentRunner {
new ChatMemoryBuffer({
chatHistory: params.chatHistory,
});
this.callbackManager = params.callbackManager ?? new CallbackManager();
this.initTaskStateKwargs = params.initTaskStateKwargs ?? {};
this.deleteTaskOnFinish = params.deleteTaskOnFinish ?? false;
this.defaultToolChoice = params.defaultToolChoice ?? "auto";
+147 -16
View File
@@ -1,6 +1,26 @@
import type { Anthropic } from "@anthropic-ai/sdk";
import { AsyncLocalStorage, CustomEvent } from "@llamaindex/env";
import type { NodeWithScore } from "../Node.js";
/**
 * Registry that maps an event name to the `CustomEvent` type delivered to
 * handlers of that event. It is declared empty here and filled in through
 * `declare module "llamaindex"` augmentation, as the block right below does.
 */
export interface LlamaIndexEventMaps {}
declare module "llamaindex" {
interface LlamaIndexEventMaps {
/**
 * Legacy event whose `detail` is a `RetrievalCallbackResponse`.
 * @deprecated
 */
retrieve: CustomEvent<RetrievalCallbackResponse>;
/**
 * Legacy event whose `detail` is a `StreamCallbackResponse`.
 * @deprecated
 */
stream: CustomEvent<StreamCallbackResponse>;
}
}
//#region @deprecated remove in the next major version
/*
An event is a wrapper that groups related operations.
For example, during retrieve and synthesize,
@@ -60,25 +80,136 @@ export interface RetrievalCallbackResponse extends BaseCallbackResponse {
}
interface CallbackManagerMethods {
/*
onLLMStream is called when a token is streamed from the LLM. Defining this
callback auto sets the stream = True flag on the openAI createChatCompletion request.
*/
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
/*
onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
This callback allows you to handle the retrieved nodes even if the synthesizer
is still running.
*/
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
/**
* onLLMStream is called when a token is streamed from the LLM. Defining this
* callback auto sets the stream = True flag on the openAI createChatCompletion request.
* @deprecated will be removed in the next major version
*/
onLLMStream: (params: StreamCallbackResponse) => Promise<void> | void;
/**
* onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
* This callback allows you to handle the retrieved nodes even if the synthesizer
* is still running.
* @deprecated will be removed in the next major version
*/
onRetrieve: (params: RetrievalCallbackResponse) => Promise<void> | void;
}
//#endregion
// Do-nothing fallback: accepts any arguments and returns `undefined`.
const noop: (...args: any[]) => any = () => void 0;
// Shape of a handler: it receives the dispatched event and its return value is not used by the type.
type EventHandler<Event extends CustomEvent> = (event: Event) => void;
export class CallbackManager implements CallbackManagerMethods {
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
/**
 * Legacy entry point for stream callbacks. Returns a function that wraps the
 * given response in a "stream" `CustomEvent` per handler and waits for every
 * registered "stream" handler to settle.
 * @deprecated will be removed in the next major version
 */
get onLLMStream(): CallbackManagerMethods["onLLMStream"] {
  return async (response) => {
    // Looked up at call time, so handlers added after the getter was read are included.
    const streamHandlers = this.#handlers.get("stream")!;
    const pending = streamHandlers.map((handler) =>
      handler(new CustomEvent("stream", { detail: response })),
    );
    await Promise.all(pending);
  };
}
constructor(handlers?: CallbackManagerMethods) {
this.onLLMStream = handlers?.onLLMStream;
this.onRetrieve = handlers?.onRetrieve;
/**
 * Legacy entry point for retrieval callbacks. Returns a function that wraps
 * the given response in a "retrieve" `CustomEvent` per handler and waits for
 * every registered "retrieve" handler to settle.
 * @deprecated will be removed in the next major version
 */
get onRetrieve(): CallbackManagerMethods["onRetrieve"] {
  return async (response) => {
    // Looked up at call time, so handlers added after the getter was read are included.
    const retrieveHandlers = this.#handlers.get("retrieve")!;
    const pending = retrieveHandlers.map((handler) =>
      handler(new CustomEvent("retrieve", { detail: response })),
    );
    await Promise.all(pending);
  };
}
/**
 * Assignment is not supported: any attempt to set `onLLMStream` throws.
 * @deprecated will be removed in the next major version
 */
set onLLMStream(_: never) {
  const message =
    "onLLMStream is deprecated. Use addHandlers('stream') instead";
  throw new Error(message);
}
/**
 * Assignment is not supported: any attempt to set `onRetrieve` throws.
 * @deprecated will be removed in the next major version
 */
set onRetrieve(_: never) {
  const message =
    "onRetrieve is deprecated. Use `addHandlers('retrieve')` instead";
  throw new Error(message);
}
// Handlers registered per event name, kept in the order they were added.
#handlers = new Map<keyof LlamaIndexEventMaps, EventHandler<CustomEvent>[]>();
/**
 * Create a manager, bridging the optional legacy callbacks onto the
 * "stream" and "retrieve" events.
 * @param handlers - legacy `onLLMStream` / `onRetrieve` callbacks; a missing
 *   callback is replaced by a no-op so both events always have a handler list
 */
constructor(handlers?: Partial<CallbackManagerMethods>) {
  const streamCallback = handlers?.onLLMStream ?? noop;
  const retrieveCallback = handlers?.onRetrieve ?? noop;
  // Legacy callbacks receive the raw payload, so unwrap `detail` for them.
  this.addHandlers("stream", (event) => streamCallback(event.detail));
  this.addHandlers("retrieve", (event) => retrieveCallback(event.detail));
}
/**
 * Register `handler` for `event`, appending it after any handlers already
 * registered for that event.
 * @param event - event name, a key of `LlamaIndexEventMaps`
 * @param handler - function invoked with the event when it is dispatched
 * @returns this manager, so calls can be chained
 */
addHandlers<
  K extends keyof LlamaIndexEventMaps,
  H extends EventHandler<LlamaIndexEventMaps[K]>,
>(event: K, handler: H) {
  const registered = this.#handlers.get(event);
  if (registered === undefined) {
    // First handler for this event: start its list.
    this.#handlers.set(event, [handler]);
  } else {
    registered.push(handler);
  }
  return this;
}
/**
 * Unregister `handler` from `event`. Only the first registration of that
 * exact function reference is removed; nothing happens if it was never
 * registered.
 * @param event - event name, a key of `LlamaIndexEventMaps`
 * @param handler - the same function reference passed to `addHandlers`
 * @returns this manager, so calls can be chained
 */
removeHandlers<
  K extends keyof LlamaIndexEventMaps,
  H extends EventHandler<LlamaIndexEventMaps[K]>,
>(event: K, handler: H) {
  const handlers = this.#handlers.get(event);
  if (handlers !== undefined) {
    const index = handlers.indexOf(handler);
    if (index > -1) {
      handlers.splice(index, 1);
    }
  }
  // Always return `this` (mirroring `addHandlers`): returning `undefined` for
  // an event without a handler list would break chained calls.
  return this;
}
/**
 * Synchronously invoke every handler registered for `event`, in registration
 * order, each with its own freshly created `CustomEvent` carrying `detail`.
 * Does nothing when no handler list exists for the event.
 * @param event - event name, a key of `LlamaIndexEventMaps`
 * @param detail - payload exposed to handlers as `event.detail`
 */
dispatchEvent<K extends keyof LlamaIndexEventMaps>(
  event: K,
  detail: LlamaIndexEventMaps[K]["detail"],
) {
  const registered = this.#handlers.get(event);
  if (registered !== undefined) {
    registered.forEach((handler) => {
      handler(new CustomEvent(event, { detail }));
    });
  }
}
}
// Fallback manager returned by getCurrentCallbackManager() when no scoped manager is set.
const defaultCallbackManager = new CallbackManager();
// Async-context storage holding the manager installed by runWithCallbackManager().
const callbackAsyncLocalStorage = new AsyncLocalStorage<CallbackManager>();
/**
 * Resolve the callback manager for the current async context.
 * @returns the manager installed by `runWithCallbackManager`, or
 *   `defaultCallbackManager` when none is active
 */
export function getCurrentCallbackManager() {
  const scopedManager = callbackAsyncLocalStorage.getStore();
  return scopedManager ?? defaultCallbackManager;
}
/**
 * Run `fn` with `callbackManager` installed as the current callback manager,
 * so `getCurrentCallbackManager()` resolves to it inside `fn`.
 * @param callbackManager - manager to expose while `fn` runs
 * @param fn - function to execute inside the scope
 * @returns whatever `fn` returns
 */
export function runWithCallbackManager<Result>(
  callbackManager: CallbackManager,
  fn: () => Result,
): Result {
  const outcome = callbackAsyncLocalStorage.run(callbackManager, fn);
  return outcome;
}
+10 -10
View File
@@ -5,6 +5,7 @@ import { ObjectType, jsonToNode } from "../Node.js";
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
import type { ServiceContext } from "../ServiceContext.js";
import { serviceContextFromDefaults } from "../ServiceContext.js";
import { getCurrentCallbackManager } from "../callbacks/CallbackManager.js";
import type { ClientParams, CloudConstructorParams } from "./types.js";
import { DEFAULT_PROJECT_NAME } from "./types.js";
import { getClient } from "./utils.js";
@@ -80,16 +81,15 @@ export class LlamaCloudRetriever implements BaseRetriever {
const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);
if (this.serviceContext.callbackManager.onRetrieve) {
this.serviceContext.callbackManager.onRetrieve({
query,
nodes,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
getCurrentCallbackManager().onRetrieve({
query,
nodes,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
return nodes;
}
+17 -20
View File
@@ -6,6 +6,7 @@ import { defaultChoiceSelectPrompt } from "../../Prompt.js";
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { serviceContextFromDefaults } from "../../ServiceContext.js";
import { getCurrentCallbackManager } from "../../callbacks/CallbackManager.js";
import { RetrieverQueryEngine } from "../../engines/query/index.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import type { StorageContext } from "../../storage/StorageContext.js";
@@ -291,16 +292,14 @@ export class SummaryIndexRetriever implements BaseRetriever {
score: 1,
}));
if (this.index.serviceContext.callbackManager.onRetrieve) {
this.index.serviceContext.callbackManager.onRetrieve({
query,
nodes: result,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
getCurrentCallbackManager().onRetrieve({
query,
nodes: result,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
return result;
}
@@ -376,16 +375,14 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
results.push(...nodeWithScores);
}
if (this.serviceContext.callbackManager.onRetrieve) {
this.serviceContext.callbackManager.onRetrieve({
query,
nodes: results,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
getCurrentCallbackManager().onRetrieve({
query,
nodes: results,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
return results;
}
+12 -11
View File
@@ -14,7 +14,10 @@ import {
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
import type { ServiceContext } from "../../ServiceContext.js";
import { serviceContextFromDefaults } from "../../ServiceContext.js";
import type { Event } from "../../callbacks/CallbackManager.js";
import {
getCurrentCallbackManager,
type Event,
} from "../../callbacks/CallbackManager.js";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants.js";
import type {
BaseEmbedding,
@@ -480,16 +483,14 @@ export class VectorIndexRetriever implements BaseRetriever {
nodesWithScores: NodeWithScore<Metadata>[],
parentEvent: Event | undefined,
) {
if (this.serviceContext.callbackManager.onRetrieve) {
this.serviceContext.callbackManager.onRetrieve({
query,
nodes: nodesWithScores,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
getCurrentCallbackManager().onRetrieve({
query,
nodes: nodesWithScores,
event: globalsHelper.createEvent({
parentEvent,
type: "retrieve",
}),
});
}
protected async buildVectorStoreQuery(
+13 -22
View File
@@ -1,11 +1,11 @@
import type OpenAILLM from "openai";
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type {
CallbackManager,
Event,
EventType,
OpenAIStreamToken,
StreamCallbackResponse,
import {
getCurrentCallbackManager,
type Event,
type EventType,
type OpenAIStreamToken,
type StreamCallbackResponse,
} from "../callbacks/CallbackManager.js";
import llamaTokenizer from "llama-tokenizer-js";
@@ -36,6 +36,7 @@ import type {
LLMMetadata,
MessageType,
} from "./types.js";
import { llmEvent } from "./utils.js";
export const GPT4_MODELS = {
"gpt-4": { contextWindow: 8192 },
@@ -102,8 +103,6 @@ export class OpenAI extends BaseLLM {
"apiKey" | "maxRetries" | "timeout"
>;
callbackManager?: CallbackManager;
constructor(
init?: Partial<OpenAI> & {
azure?: AzureOpenAIConfig;
@@ -155,8 +154,6 @@ export class OpenAI extends BaseLLM {
...this.additionalSessionOptions,
});
}
this.callbackManager = init?.callbackManager;
}
get metadata() {
@@ -230,6 +227,7 @@ export class OpenAI extends BaseLLM {
params: LLMChatParamsStreaming,
): Promise<AsyncIterable<ChatResponseChunk>>;
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
@llmEvent
async chat(
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
@@ -293,9 +291,7 @@ export class OpenAI extends BaseLLM {
};
//Now let's wrap our stream in a callback
const onLLMStream = this.callbackManager?.onLLMStream
? this.callbackManager.onLLMStream
: () => {};
const onLLMStream = getCurrentCallbackManager().onLLMStream;
const chunk_stream: AsyncIterable<OpenAIStreamToken> =
await this.session.openai.chat.completions.create({
@@ -566,6 +562,7 @@ If a question does not make any sense, or is not factually coherent, explain why
params: LLMChatParamsStreaming,
): Promise<AsyncIterable<ChatResponseChunk>>;
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
@llmEvent
async chat(
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
@@ -653,8 +650,6 @@ export class Anthropic extends BaseLLM {
timeout?: number;
session: AnthropicSession;
callbackManager?: CallbackManager;
constructor(init?: Partial<Anthropic>) {
super();
this.model = init?.model ?? "claude-3-opus";
@@ -672,8 +667,6 @@ export class Anthropic extends BaseLLM {
maxRetries: this.maxRetries,
timeout: this.timeout,
});
this.callbackManager = init?.callbackManager;
}
tokens(messages: ChatMessage[]): number {
@@ -715,6 +708,7 @@ export class Anthropic extends BaseLLM {
params: LLMChatParamsStreaming,
): Promise<AsyncIterable<ChatResponseChunk>>;
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
@llmEvent
async chat(
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
@@ -790,7 +784,6 @@ export class Portkey extends BaseLLM {
mode?: string = undefined;
llms?: [LLMOptions] | null = undefined;
session: PortkeySession;
callbackManager?: CallbackManager;
constructor(init?: Partial<Portkey>) {
super();
@@ -804,7 +797,6 @@ export class Portkey extends BaseLLM {
llms: this.llms,
mode: this.mode,
});
this.callbackManager = init?.callbackManager;
}
tokens(messages: ChatMessage[]): number {
@@ -819,6 +811,7 @@ export class Portkey extends BaseLLM {
params: LLMChatParamsStreaming,
): Promise<AsyncIterable<ChatResponseChunk>>;
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
@llmEvent
async chat(
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
@@ -844,9 +837,7 @@ export class Portkey extends BaseLLM {
params?: Record<string, any>,
): AsyncIterable<ChatResponseChunk> {
// Wrapping the stream in a callback.
const onLLMStream = this.callbackManager?.onLLMStream
? this.callbackManager.onLLMStream
: () => {};
const onLLMStream = getCurrentCallbackManager().onLLMStream;
const chunkStream = await this.session.portkey.chatCompletions.create({
messages,
+6 -10
View File
@@ -1,9 +1,9 @@
import { getEnv } from "@llamaindex/env";
import type {
CallbackManager,
Event,
EventType,
StreamCallbackResponse,
import {
getCurrentCallbackManager,
type Event,
type EventType,
type StreamCallbackResponse,
} from "../callbacks/CallbackManager.js";
import { BaseLLM } from "./base.js";
import type {
@@ -54,7 +54,6 @@ export class MistralAI extends BaseLLM {
topP: number;
maxTokens?: number;
apiKey?: string;
callbackManager?: CallbackManager;
safeMode: boolean;
randomSeed?: number;
@@ -66,7 +65,6 @@ export class MistralAI extends BaseLLM {
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
this.callbackManager = init?.callbackManager;
this.safeMode = init?.safeMode ?? false;
this.randomSeed = init?.randomSeed ?? undefined;
this.session = new MistralAISession(init);
@@ -125,9 +123,7 @@ export class MistralAI extends BaseLLM {
parentEvent,
}: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> {
//Now let's wrap our stream in a callback
const onLLMStream = this.callbackManager?.onLLMStream
? this.callbackManager.onLLMStream
: () => {};
const onLLMStream = getCurrentCallbackManager().onLLMStream;
const client = await this.session.getClient();
const chunkStream = await client.chatStream(this.buildParams(messages));
+1 -2
View File
@@ -1,5 +1,5 @@
import { ok } from "@llamaindex/env";
import type { CallbackManager, Event } from "../callbacks/CallbackManager.js";
import type { Event } from "../callbacks/CallbackManager.js";
import { BaseEmbedding } from "../embeddings/types.js";
import type {
ChatMessage,
@@ -35,7 +35,6 @@ export class Ollama extends BaseEmbedding implements LLM {
contextWindow: number = 4096;
requestTimeout: number = 60 * 1000; // Default is 60 seconds
additionalChatOptions?: Record<string, unknown>;
callbackManager?: CallbackManager;
protected modelMetadata: Partial<LLMMetadata>;
+38 -5
View File
@@ -1,15 +1,49 @@
import type { Tokenizers } from "../GlobalsHelper.js";
import type { Event } from "../callbacks/CallbackManager.js";
import { type Event } from "../callbacks/CallbackManager.js";
// Common shape of LLM lifecycle events: a CustomEvent whose `detail` wraps the payload.
// NOTE(review): the `Type` parameter is not referenced in the resulting type;
// it appears to serve only as a label at the use sites below — confirm.
type LLMBaseEvent<
Type extends string,
Payload extends Record<string, unknown>,
> = CustomEvent<{
payload: Payload;
}>;
// Event labelled "llm-start"; its payload carries the chat messages.
export type LLMStartEvent = LLMBaseEvent<
"llm-start",
{
messages: ChatMessage[];
}
>;
// Event labelled "llm-end"; its payload carries the chat response.
export type LLMEndEvent = LLMBaseEvent<
"llm-end",
{
response: ChatResponse;
}
>;
// Adds the two LLM events to the LlamaIndexEventMaps interface of the "llamaindex" module.
declare module "llamaindex" {
interface LlamaIndexEventMaps {
"llm-start": LLMStartEvent;
"llm-end": LLMEndEvent;
}
}
/**
* @internal
*/
export interface LLMChat {
chat(
params: LLMChatParamsStreaming | LLMChatParamsNonStreaming,
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>>;
}
/**
* Unified language model interface
*/
export interface LLM {
export interface LLM extends LLMChat {
metadata: LLMMetadata;
/**
* Get a chat response from the LLM
*
* @param params
*/
chat(
params: LLMChatParamsStreaming,
@@ -18,7 +52,6 @@ export interface LLM {
/**
* Get a prompt completion from the LLM
* @param params
*/
complete(
params: LLMCompletionParamsStreaming,
+57 -1
View File
@@ -1,4 +1,5 @@
import type { MessageContent } from "./types.js";
import { getCurrentCallbackManager } from "../callbacks/CallbackManager.js";
import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js";
export async function* streamConverter<S, D>(
stream: AsyncIterable<S>,
@@ -42,3 +43,58 @@ export function extractText(message: MessageContent): string {
}
return message;
}
/**
 * Method decorator for `chat` implementations. It reports every call to the
 * callback manager that is current when `chat` is invoked: `llm-start` is
 * dispatched before the wrapped method runs, `llm-end` once the response is
 * known.
 *
 * For a streaming response the async iterator is wrapped: chunk deltas are
 * concatenated while the caller iterates, and `llm-end` is dispatched after
 * the stream has been consumed to completion.
 *
 * @param originalMethod - the undecorated `chat` method
 * @param _context - decorator context (unused)
 * @returns the replacement `chat` method
 * @internal
 */
export function llmEvent(
  originalMethod: LLMChat["chat"],
  _context: ClassMethodDecoratorContext,
) {
  return async function withLLMEvent(
    this: LLM,
    ...params: Parameters<LLMChat["chat"]>
  ): ReturnType<LLMChat["chat"]> {
    // Resolve the manager once, while the caller's async context is active.
    // A stream is consumed after this function has returned, possibly outside
    // the `runWithCallbackManager` scope; a second lookup at that point could
    // deliver `llm-end` to a different manager than the one that received
    // `llm-start`.
    const callbackManager = getCurrentCallbackManager();
    callbackManager.dispatchEvent("llm-start", {
      payload: {
        messages: params[0].messages,
      },
    });
    const response = await originalMethod.call(this, ...params);
    if (Symbol.asyncIterator in response) {
      // Keep a handle on the real iterator before it is replaced below.
      const originalAsyncIterator = {
        [Symbol.asyncIterator]: response[Symbol.asyncIterator].bind(response),
      };
      response[Symbol.asyncIterator] = async function* () {
        const finalResponse: ChatResponse = {
          message: {
            content: "",
            role: "assistant",
          },
        };
        for await (const chunk of originalAsyncIterator) {
          // Content starts as "", so plain concatenation covers the first chunk too.
          finalResponse.message.content += chunk.delta;
          yield chunk;
        }
        callbackManager.dispatchEvent("llm-end", {
          payload: {
            response: finalResponse,
          },
        });
      };
    } else {
      callbackManager.dispatchEvent("llm-end", {
        payload: {
          response,
        },
      });
    }
    return response;
  };
}
+13 -6
View File
@@ -15,7 +15,10 @@ import type {
RetrievalCallbackResponse,
StreamCallbackResponse,
} from "llamaindex/callbacks/CallbackManager";
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
import {
CallbackManager,
runWithCallbackManager,
} from "llamaindex/callbacks/CallbackManager";
import { OpenAIEmbedding } from "llamaindex/embeddings/index";
import { SummaryIndex } from "llamaindex/indices/summary/index";
import { VectorStoreIndex } from "llamaindex/indices/vectorStore/index";
@@ -38,10 +41,11 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
let streamCallbackData: StreamCallbackResponse[] = [];
let retrieveCallbackData: RetrievalCallbackResponse[] = [];
let document: Document;
let callbackManager: CallbackManager;
beforeAll(async () => {
document = new Document({ text: "Author: My name is Paul Graham" });
const callbackManager = new CallbackManager({
callbackManager = new CallbackManager({
onLLMStream: (data) => {
streamCallbackData.push(data);
},
@@ -52,7 +56,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
const languageModel = new OpenAI({
model: "gpt-3.5-turbo",
callbackManager,
});
mockLlmGeneration({ languageModel, callbackManager });
@@ -60,7 +63,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
mockEmbeddingModel(embedModel);
serviceContext = serviceContextFromDefaults({
callbackManager,
llm: languageModel,
embedModel,
});
@@ -81,7 +83,10 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
});
const queryEngine = vectorStoreIndex.asQueryEngine();
const query = "What is the author's name?";
const response = await queryEngine.query({ query });
const response = await runWithCallbackManager(callbackManager, () => {
return queryEngine.query({ query });
});
expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
expect(streamCallbackData).toEqual([
{
@@ -159,7 +164,9 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
responseSynthesizer,
});
const query = "What is the author's name?";
const response = await queryEngine.query({ query });
const response = await runWithCallbackManager(callbackManager, async () =>
queryEngine.query({ query }),
);
expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
expect(streamCallbackData).toEqual([
{
+2 -33
View File
@@ -1,11 +1,6 @@
import { Document } from "llamaindex/Node";
import type { ServiceContext } from "llamaindex/ServiceContext";
import { serviceContextFromDefaults } from "llamaindex/ServiceContext";
import type {
RetrievalCallbackResponse,
StreamCallbackResponse,
} from "llamaindex/callbacks/CallbackManager";
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
import { OpenAIEmbedding } from "llamaindex/embeddings/index";
import {
KeywordExtractor,
@@ -15,15 +10,7 @@ import {
} from "llamaindex/extractors/index";
import { OpenAI } from "llamaindex/llm/LLM";
import { SimpleNodeParser } from "llamaindex/nodeParsers/index";
import {
afterAll,
beforeAll,
beforeEach,
describe,
expect,
test,
vi,
} from "vitest";
import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
import {
DEFAULT_LLM_TEXT_OUTPUT,
mockEmbeddingModel,
@@ -39,42 +26,24 @@ vi.mock("llamaindex/llm/open_ai", () => {
describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
let serviceContext: ServiceContext;
let streamCallbackData: StreamCallbackResponse[] = [];
let retrieveCallbackData: RetrievalCallbackResponse[] = [];
beforeAll(async () => {
const callbackManager = new CallbackManager({
onLLMStream: (data) => {
streamCallbackData.push(data);
},
onRetrieve: (data) => {
retrieveCallbackData.push(data);
},
});
const languageModel = new OpenAI({
model: "gpt-3.5-turbo",
callbackManager,
});
mockLlmGeneration({ languageModel, callbackManager });
mockLlmGeneration({ languageModel });
const embedModel = new OpenAIEmbedding();
mockEmbeddingModel(embedModel);
serviceContext = serviceContextFromDefaults({
callbackManager,
llm: languageModel,
embedModel,
});
});
beforeEach(() => {
streamCallbackData = [];
retrieveCallbackData = [];
});
afterAll(() => {
vi.clearAllMocks();
});
-1
View File
@@ -22,7 +22,6 @@ describe("LLMSelector", () => {
mocStructuredkLlmGeneration({
languageModel,
callbackManager: serviceContext.callbackManager,
});
const selector = new LLMSingleSelector({
@@ -1,5 +1,4 @@
import { OpenAIAgent } from "llamaindex/agent/index";
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
import { OpenAI } from "llamaindex/llm/index";
import { FunctionTool } from "llamaindex/tools/index";
import { beforeEach, describe, expect, it, vi } from "vitest";
@@ -35,16 +34,12 @@ describe("OpenAIAgent", () => {
let openaiAgent: OpenAIAgent;
beforeEach(() => {
const callbackManager = new CallbackManager({});
const languageModel = new OpenAI({
model: "gpt-3.5-turbo",
callbackManager,
});
mockLlmToolCallGeneration({
languageModel,
callbackManager,
});
const sumFunctionTool = new FunctionTool(sumNumbers, {
@@ -1,6 +1,5 @@
import { OpenAIAgentWorker } from "llamaindex/agent/index";
import { AgentRunner } from "llamaindex/agent/runner/base";
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
import { OpenAI } from "llamaindex/llm/LLM";
import { beforeEach, describe, expect, it, vi } from "vitest";
@@ -19,16 +18,12 @@ describe("Agent Runner", () => {
let agentRunner: AgentRunner;
beforeEach(() => {
const callbackManager = new CallbackManager({});
const languageModel = new OpenAI({
model: "gpt-3.5-turbo",
callbackManager,
});
mockLlmGeneration({
languageModel,
callbackManager,
});
agentRunner = new AgentRunner({
+2 -2
View File
@@ -68,7 +68,7 @@ export function mockLlmToolCallGeneration({
callbackManager,
}: {
languageModel: OpenAI;
callbackManager: CallbackManager;
callbackManager?: CallbackManager;
}) {
vi.spyOn(languageModel, "chat").mockImplementation(
() =>
@@ -119,7 +119,7 @@ export function mocStructuredkLlmGeneration({
callbackManager,
}: {
languageModel: OpenAI;
callbackManager: CallbackManager;
callbackManager?: CallbackManager;
}) {
vi.spyOn(languageModel, "chat").mockImplementation(
async ({ messages, parentEvent }: LLMChatParamsBase) => {
+1 -1
View File
@@ -39,4 +39,4 @@ export function randomUUID(): string {
return crypto.randomUUID();
}
export * from "./type.js";
export { getEnv } from "./utils.js";
export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js";
+1 -1
View File
@@ -35,5 +35,5 @@ export const defaultFS: CompleteFileSystem = {
};
export type * from "./type.js";
export { getEnv } from "./utils.js";
export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js";
export { EOL, ok, path, randomUUID };
+20
View File
@@ -10,3 +10,23 @@ export function getEnv(name: string): string | undefined {
}
return process.env[name];
}
// Browsers do not provide AsyncLocalStorage; it is re-exported here from
// Node's built-in async_hooks module.
export { AsyncLocalStorage } from "node:async_hooks";
/**
 * Minimal `CustomEvent` implementation built on the global `Event` class: an
 * event that carries an arbitrary payload in its read-only `detail` property.
 *
 * @typeParam T - Type of the payload exposed through `detail`.
 */
class CustomEvent<T = any> extends Event {
  readonly #detail: T;

  /** Payload supplied at construction time, or `null` when none was given. */
  get detail(): T {
    return this.#detail;
  }

  /**
   * @param event - Event type name, forwarded to `Event`.
   * @param options - Standard event init options plus the optional `detail`
   * payload.
   */
  constructor(event: string, options?: CustomEventInit) {
    super(event, options);
    // The DOM standard defaults `detail` to null rather than undefined; do the
    // same so this class behaves like a platform-provided CustomEvent.
    this.#detail = options?.detail ?? null;
  }
}
// Not every Node.js runtime exposes a global CustomEvent, so use the runtime's
// own implementation when it exists and fall back to the local class otherwise.
// Refs: https://github.com/nodejs/node/issues/40678
const defaultCustomEvent = globalThis.CustomEvent || CustomEvent;
// Exported under the standard name so importers can use it like the global.
export { defaultCustomEvent as CustomEvent };
+1 -1
View File
@@ -9,7 +9,7 @@
"devDependencies": {
"@types/node": "^18.19.10",
"ts-node": "^10.9.2",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
},
"scripts": {
"lint": "eslint ."
+1 -1
View File
@@ -11,7 +11,7 @@
"assemblyscript": "^0.19.9",
"@swc/cli": "^0.3.9",
"@swc/core": "^1.4.2",
"typescript": "^5.3.3"
"typescript": "^5.4.3"
},
"engines": {
"node": ">=18.0.0"
+471 -735
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -8,6 +8,7 @@
"forceConsistentCasingInFileNames": true,
"strict": true,
"skipLibCheck": true,
"stripInternal": true,
"outDir": "./lib",
"tsBuildInfoFile": "./lib/.tsbuildinfo",
"incremental": true,