mirror of
https://github.com/run-llama/LlamaIndexTS.git
synced 2026-06-30 22:17:54 -04:00
feat: improve callback manager (#675)
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
---
|
||||
"llamaindex": patch
|
||||
"@llamaindex/env": patch
|
||||
"@llamaindex/edge": patch
|
||||
---
|
||||
|
||||
feat: improve CallbackManager
|
||||
+6
-2
@@ -1,9 +1,13 @@
|
||||
{
|
||||
"jsc": {
|
||||
"parser": {
|
||||
"syntax": "typescript"
|
||||
"syntax": "typescript",
|
||||
"decorators": true
|
||||
},
|
||||
"target": "esnext"
|
||||
"target": "esnext",
|
||||
"transform": {
|
||||
"decoratorVersion": "2022-03"
|
||||
}
|
||||
},
|
||||
"module": {
|
||||
"type": "commonjs",
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
{
|
||||
"jsc": {
|
||||
"parser": {
|
||||
"syntax": "typescript"
|
||||
"syntax": "typescript",
|
||||
"decorators": true
|
||||
},
|
||||
"target": "esnext"
|
||||
"target": "esnext",
|
||||
"transform": {
|
||||
"decoratorVersion": "2022-03"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
label: Recipes
|
||||
position: 3
|
||||
@@ -0,0 +1,14 @@
|
||||
# Cost Analysis
|
||||
|
||||
This page shows how to track LLM cost using APIs.
|
||||
|
||||
## Callback Manager
|
||||
|
||||
The callback manager is a class that manages the callback functions.
|
||||
|
||||
You can register `llm-start` and `llm-end` callbacks with the callback manager to track the cost.
|
||||
|
||||
import CodeBlock from "@theme/CodeBlock";
|
||||
import CodeSource from "!raw-loader!../../../../examples/recipes/cost-analysis";
|
||||
|
||||
<CodeBlock language="ts">{CodeSource}</CodeBlock>
|
||||
@@ -15,9 +15,9 @@
|
||||
"typecheck": "tsc"
|
||||
},
|
||||
"dependencies": {
|
||||
"@docusaurus/core": "^3.1.1",
|
||||
"@docusaurus/core": "^3.2.0",
|
||||
"@llamaindex/env": "workspace:*",
|
||||
"@docusaurus/remark-plugin-npm2yarn": "^3.1.1",
|
||||
"@docusaurus/remark-plugin-npm2yarn": "^3.2.0",
|
||||
"@mdx-js/react": "^3.0.0",
|
||||
"clsx": "^2.1.0",
|
||||
"postcss": "^8.4.33",
|
||||
@@ -27,16 +27,16 @@
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@docusaurus/module-type-aliases": "3.1.0",
|
||||
"@docusaurus/preset-classic": "^3.1.1",
|
||||
"@docusaurus/theme-classic": "^3.1.1",
|
||||
"@docusaurus/types": "^3.1.1",
|
||||
"@tsconfig/docusaurus": "^2.0.2",
|
||||
"@docusaurus/module-type-aliases": "3.2.0",
|
||||
"@docusaurus/preset-classic": "^3.2.0",
|
||||
"@docusaurus/theme-classic": "^3.2.0",
|
||||
"@docusaurus/types": "^3.2.0",
|
||||
"@tsconfig/docusaurus": "^2.0.3",
|
||||
"@types/node": "^18.19.10",
|
||||
"docusaurus-plugin-typedoc": "^0.22.0",
|
||||
"typedoc": "^0.25.7",
|
||||
"typedoc": "^0.25.12",
|
||||
"typedoc-plugin-markdown": "^3.17.1",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
OpenAI,
|
||||
ServiceContext,
|
||||
VectorStoreIndex,
|
||||
runWithCallbackManager,
|
||||
serviceContextFromDefaults,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
@@ -38,7 +39,6 @@ async function main() {
|
||||
llm,
|
||||
chunkSize: 512,
|
||||
chunkOverlap: 20,
|
||||
callbackManager,
|
||||
});
|
||||
const index = await createIndex(serviceContext);
|
||||
|
||||
@@ -46,9 +46,11 @@ async function main() {
|
||||
responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
|
||||
retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
|
||||
});
|
||||
const result = await queryEngine.query({
|
||||
query: "Tell me more about Vincent van Gogh's famous paintings",
|
||||
});
|
||||
const result = await runWithCallbackManager(callbackManager, () =>
|
||||
queryEngine.query({
|
||||
query: "Tell me more about Vincent van Gogh's famous paintings",
|
||||
}),
|
||||
);
|
||||
console.log(result.response, "\n");
|
||||
images.forEach((image) =>
|
||||
console.log(`Image retrieved and used in inference: ${image.toString()}`),
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"devDependencies": {
|
||||
"@types/node": "^18.19.10",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
},
|
||||
"scripts": {
|
||||
"lint": "eslint ."
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
MetadataMode,
|
||||
QdrantVectorStore,
|
||||
VectorStoreIndex,
|
||||
runWithCallbackManager,
|
||||
serviceContextFromDefaults,
|
||||
storageContextFromDefaults,
|
||||
} from "llamaindex";
|
||||
@@ -36,19 +37,21 @@ async function main() {
|
||||
const ctx = await storageContextFromDefaults({ vectorStore: qdrantVs });
|
||||
|
||||
console.log("Embedding documents and adding to index");
|
||||
const index = await VectorStoreIndex.fromDocuments(docs, {
|
||||
storageContext: ctx,
|
||||
serviceContext: serviceContextFromDefaults({
|
||||
callbackManager: new CallbackManager({
|
||||
onRetrieve: (data) => {
|
||||
console.log(
|
||||
"The retrieved nodes are:",
|
||||
data.nodes.map((node) => node.node.getContent(MetadataMode.NONE)),
|
||||
);
|
||||
},
|
||||
}),
|
||||
const index = await runWithCallbackManager(
|
||||
new CallbackManager({
|
||||
onRetrieve: (data) => {
|
||||
console.log(
|
||||
"The retrieved nodes are:",
|
||||
data.nodes.map((node) => node.node.getContent(MetadataMode.NONE)),
|
||||
);
|
||||
},
|
||||
}),
|
||||
});
|
||||
() =>
|
||||
VectorStoreIndex.fromDocuments(docs, {
|
||||
storageContext: ctx,
|
||||
serviceContext: serviceContextFromDefaults(),
|
||||
}),
|
||||
);
|
||||
|
||||
console.log(
|
||||
"Querying index with no filters: Expected output: Brown probably",
|
||||
|
||||
@@ -17,6 +17,6 @@
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.11.14",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import { getCurrentCallbackManager } from "llamaindex/callbacks/CallbackManager";
|
||||
import { OpenAI } from "llamaindex/llm";
|
||||
|
||||
const llm = new OpenAI({ model: "gpt-4-0125-preview" });

// Running total of tokens reported by `llm.tokens` across all `llm-start` events.
let tokenCount = 0;

// @todo: use GlobalSetting in the future
getCurrentCallbackManager().addHandlers("llm-start", (event) => {
  const promptMessages = event.detail.payload.messages;
  tokenCount = tokenCount + llm.tokens(promptMessages);
  console.log("Token count:", tokenCount);
  // https://openai.com/pricing
  // $10.00 / 1M tokens
  console.log(`Price: $${(tokenCount / 1000000) * 10}`);
});

const question = "Hello, how are you?";
console.log("Question:", question);

// Request a streamed chat completion and echo each delta as it arrives.
llm
  .chat({
    stream: true,
    messages: [{ content: question, role: "user" }],
  })
  .then(async (stream) => {
    console.log("Response:");
    for await (const { delta } of stream) {
      process.stdout.write(delta);
    }
  });
|
||||
+1
-1
@@ -27,7 +27,7 @@
|
||||
"prettier": "^3.2.5",
|
||||
"prettier-plugin-organize-imports": "^3.2.4",
|
||||
"turbo": "^1.12.3",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
},
|
||||
"packageManager": "pnpm@8.15.1",
|
||||
"pnpm": {
|
||||
|
||||
@@ -12,7 +12,7 @@ export enum Tokenizers {
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper class singleton
|
||||
* @internal Helper class singleton
|
||||
*/
|
||||
class GlobalsHelper {
|
||||
defaultTokenizer: {
|
||||
@@ -58,6 +58,21 @@ class GlobalsHelper {
|
||||
return this.defaultTokenizer!.decode.bind(this.defaultTokenizer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated createEvent will be removed in the future,
|
||||
* please use `new CustomEvent(eventType, { detail: payload })` instead.
|
||||
*
|
||||
* Also, `parentEvent` will not be used in the future,
|
||||
* use `AsyncLocalStorage` to track parent events instead.
|
||||
* @example - Usage of `AsyncLocalStorage`:
|
||||
* let id = 0;
|
||||
* const asyncLocalStorage = new AsyncLocalStorage<number>();
|
||||
* asyncLocalStorage.run(++id, async () => {
|
||||
* setTimeout(() => {
|
||||
* console.log('parent event id:', asyncLocalStorage.getStore()); // 1
|
||||
* }, 1000)
|
||||
* });
|
||||
*/
|
||||
createEvent({
|
||||
parentEvent,
|
||||
type,
|
||||
|
||||
@@ -4,6 +4,9 @@ import type { ServiceContext } from "./ServiceContext.js";
|
||||
|
||||
export type RetrieveParams = {
|
||||
query: string;
|
||||
/**
|
||||
* @deprecated will be removed in the next major version
|
||||
*/
|
||||
parentEvent?: Event;
|
||||
preFilters?: unknown;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { PromptHelper } from "./PromptHelper.js";
|
||||
import { CallbackManager } from "./callbacks/CallbackManager.js";
|
||||
import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js";
|
||||
import type { BaseEmbedding } from "./embeddings/types.js";
|
||||
import type { LLM } from "./llm/index.js";
|
||||
@@ -15,7 +14,6 @@ export interface ServiceContext {
|
||||
promptHelper: PromptHelper;
|
||||
embedModel: BaseEmbedding;
|
||||
nodeParser: NodeParser;
|
||||
callbackManager: CallbackManager;
|
||||
// llamaLogger: any;
|
||||
}
|
||||
|
||||
@@ -24,14 +22,12 @@ export interface ServiceContextOptions {
|
||||
promptHelper?: PromptHelper;
|
||||
embedModel?: BaseEmbedding;
|
||||
nodeParser?: NodeParser;
|
||||
callbackManager?: CallbackManager;
|
||||
// NodeParser arguments
|
||||
chunkSize?: number;
|
||||
chunkOverlap?: number;
|
||||
}
|
||||
|
||||
export function serviceContextFromDefaults(options?: ServiceContextOptions) {
|
||||
const callbackManager = options?.callbackManager ?? new CallbackManager();
|
||||
const serviceContext: ServiceContext = {
|
||||
llm: options?.llm ?? new OpenAI(),
|
||||
embedModel: options?.embedModel ?? new OpenAIEmbedding(),
|
||||
@@ -42,7 +38,6 @@ export function serviceContextFromDefaults(options?: ServiceContextOptions) {
|
||||
chunkOverlap: options?.chunkOverlap,
|
||||
}),
|
||||
promptHelper: options?.promptHelper ?? new PromptHelper(),
|
||||
callbackManager,
|
||||
};
|
||||
|
||||
return serviceContext;
|
||||
@@ -65,8 +60,5 @@ export function serviceContextFromServiceContext(
|
||||
if (options.nodeParser) {
|
||||
newServiceContext.nodeParser = options.nodeParser;
|
||||
}
|
||||
if (options.callbackManager) {
|
||||
newServiceContext.callbackManager = options.callbackManager;
|
||||
}
|
||||
return newServiceContext;
|
||||
}
|
||||
|
||||
@@ -100,7 +100,6 @@ export class SentenceSplitter {
|
||||
}
|
||||
this.chunkSize = chunkSize;
|
||||
this.chunkOverlap = chunkOverlap;
|
||||
// this._callback_manager = callback_manager || new CallbackManager([]);
|
||||
|
||||
this.tokenizer = tokenizer ?? globalsHelper.tokenizer();
|
||||
this.tokenizerDecoder =
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import type { ChatMessage } from "../../llm/index.js";
|
||||
import { OpenAI } from "../../llm/index.js";
|
||||
import type { ObjectRetriever } from "../../objects/base.js";
|
||||
@@ -14,7 +13,6 @@ type OpenAIAgentParams = {
|
||||
verbose?: boolean;
|
||||
maxFunctionCalls?: number;
|
||||
defaultToolChoice?: string;
|
||||
callbackManager?: CallbackManager;
|
||||
toolRetriever?: ObjectRetriever;
|
||||
systemPrompt?: string;
|
||||
};
|
||||
@@ -33,7 +31,6 @@ export class OpenAIAgent extends AgentRunner {
|
||||
verbose,
|
||||
maxFunctionCalls = 5,
|
||||
defaultToolChoice = "auto",
|
||||
callbackManager,
|
||||
toolRetriever,
|
||||
systemPrompt,
|
||||
}: OpenAIAgentParams) {
|
||||
@@ -58,7 +55,6 @@ export class OpenAIAgent extends AgentRunner {
|
||||
|
||||
const stepEngine = new OpenAIAgentWorker({
|
||||
tools,
|
||||
callbackManager,
|
||||
llm,
|
||||
prefixMessages,
|
||||
maxFunctionCalls,
|
||||
@@ -69,7 +65,6 @@ export class OpenAIAgent extends AgentRunner {
|
||||
super({
|
||||
agentWorker: stepEngine,
|
||||
memory,
|
||||
callbackManager,
|
||||
defaultToolChoice,
|
||||
chatHistory: prefixMessages,
|
||||
});
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { randomUUID } from "@llamaindex/env";
|
||||
import { Response } from "../../Response.js";
|
||||
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import {
|
||||
AgentChatResponse,
|
||||
ChatResponseMode,
|
||||
@@ -79,7 +78,6 @@ type OpenAIAgentWorkerParams = {
|
||||
prefixMessages?: ChatMessage[];
|
||||
verbose?: boolean;
|
||||
maxFunctionCalls?: number;
|
||||
callbackManager?: CallbackManager | undefined;
|
||||
toolRetriever?: ObjectRetriever;
|
||||
};
|
||||
|
||||
@@ -98,7 +96,6 @@ export class OpenAIAgentWorker implements AgentWorker {
|
||||
private maxFunctionCalls: number;
|
||||
|
||||
public prefixMessages: ChatMessage[];
|
||||
public callbackManager: CallbackManager | undefined;
|
||||
|
||||
private _getTools: (input: string) => Promise<BaseTool[]>;
|
||||
|
||||
@@ -111,14 +108,12 @@ export class OpenAIAgentWorker implements AgentWorker {
|
||||
prefixMessages,
|
||||
verbose,
|
||||
maxFunctionCalls = DEFAULT_MAX_FUNCTION_CALLS,
|
||||
callbackManager,
|
||||
toolRetriever,
|
||||
}: OpenAIAgentWorkerParams) {
|
||||
this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
|
||||
this.verbose = verbose || false;
|
||||
this.maxFunctionCalls = maxFunctionCalls;
|
||||
this.prefixMessages = prefixMessages || [];
|
||||
this.callbackManager = callbackManager || this.llm.callbackManager;
|
||||
|
||||
if (tools.length > 0 && toolRetriever) {
|
||||
throw new Error("Cannot specify both tools and tool_retriever");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { CallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import type { ChatMessage, LLM } from "../../llm/index.js";
|
||||
import type { ObjectRetriever } from "../../objects/base.js";
|
||||
import type { BaseTool } from "../../types.js";
|
||||
@@ -13,7 +12,6 @@ type ReActAgentParams = {
|
||||
verbose?: boolean;
|
||||
maxInteractions?: number;
|
||||
defaultToolChoice?: string;
|
||||
callbackManager?: CallbackManager;
|
||||
toolRetriever?: ObjectRetriever;
|
||||
};
|
||||
|
||||
@@ -31,12 +29,10 @@ export class ReActAgent extends AgentRunner {
|
||||
verbose,
|
||||
maxInteractions = 10,
|
||||
defaultToolChoice = "auto",
|
||||
callbackManager,
|
||||
toolRetriever,
|
||||
}: Partial<ReActAgentParams>) {
|
||||
const stepEngine = new ReActAgentWorker({
|
||||
tools: tools ?? [],
|
||||
callbackManager,
|
||||
llm,
|
||||
maxInteractions,
|
||||
toolRetriever,
|
||||
@@ -46,7 +42,6 @@ export class ReActAgent extends AgentRunner {
|
||||
super({
|
||||
agentWorker: stepEngine,
|
||||
memory,
|
||||
callbackManager,
|
||||
defaultToolChoice,
|
||||
chatHistory: prefixMessages,
|
||||
});
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { randomUUID } from "@llamaindex/env";
|
||||
import { CallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import { AgentChatResponse } from "../../engines/chat/index.js";
|
||||
import type { ChatResponse, LLM } from "../../llm/index.js";
|
||||
import { OpenAI } from "../../llm/index.js";
|
||||
@@ -23,7 +22,6 @@ type ReActAgentWorkerParams = {
|
||||
maxInteractions?: number;
|
||||
reactChatFormatter?: ReActChatFormatter | undefined;
|
||||
outputParser?: ReActOutputParser | undefined;
|
||||
callbackManager?: CallbackManager | undefined;
|
||||
verbose?: boolean | undefined;
|
||||
toolRetriever?: ObjectRetriever | undefined;
|
||||
};
|
||||
@@ -69,8 +67,6 @@ export class ReActAgentWorker implements AgentWorker {
|
||||
reactChatFormatter: ReActChatFormatter;
|
||||
outputParser: ReActOutputParser;
|
||||
|
||||
callbackManager: CallbackManager;
|
||||
|
||||
_getTools: (message: string) => Promise<BaseTool[]>;
|
||||
|
||||
constructor({
|
||||
@@ -79,12 +75,10 @@ export class ReActAgentWorker implements AgentWorker {
|
||||
maxInteractions,
|
||||
reactChatFormatter,
|
||||
outputParser,
|
||||
callbackManager,
|
||||
verbose,
|
||||
toolRetriever,
|
||||
}: ReActAgentWorkerParams) {
|
||||
this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
|
||||
this.callbackManager = callbackManager || new CallbackManager();
|
||||
|
||||
this.maxInteractions = maxInteractions ?? 10;
|
||||
this.reactChatFormatter = reactChatFormatter ?? new ReActChatFormatter();
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { randomUUID } from "@llamaindex/env";
|
||||
import { CallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import type { ChatEngineAgentParams } from "../../engines/chat/index.js";
|
||||
import {
|
||||
AgentChatResponse,
|
||||
@@ -35,7 +34,6 @@ type AgentRunnerParams = {
|
||||
state?: AgentState;
|
||||
memory?: BaseMemory;
|
||||
llm?: LLM;
|
||||
callbackManager?: CallbackManager;
|
||||
initTaskStateKwargs?: Record<string, any>;
|
||||
deleteTaskOnFinish?: boolean;
|
||||
defaultToolChoice?: string;
|
||||
@@ -45,7 +43,6 @@ export class AgentRunner extends BaseAgentRunner {
|
||||
agentWorker: AgentWorker;
|
||||
state: AgentState;
|
||||
memory: BaseMemory;
|
||||
callbackManager: CallbackManager;
|
||||
initTaskStateKwargs: Record<string, any>;
|
||||
deleteTaskOnFinish: boolean;
|
||||
defaultToolChoice: string;
|
||||
@@ -63,7 +60,6 @@ export class AgentRunner extends BaseAgentRunner {
|
||||
new ChatMemoryBuffer({
|
||||
chatHistory: params.chatHistory,
|
||||
});
|
||||
this.callbackManager = params.callbackManager ?? new CallbackManager();
|
||||
this.initTaskStateKwargs = params.initTaskStateKwargs ?? {};
|
||||
this.deleteTaskOnFinish = params.deleteTaskOnFinish ?? false;
|
||||
this.defaultToolChoice = params.defaultToolChoice ?? "auto";
|
||||
|
||||
@@ -1,6 +1,26 @@
|
||||
import type { Anthropic } from "@anthropic-ai/sdk";
|
||||
import { AsyncLocalStorage, CustomEvent } from "@llamaindex/env";
|
||||
import type { NodeWithScore } from "../Node.js";
|
||||
|
||||
/**
|
||||
* This type is used to define the event maps for the Llamaindex package.
|
||||
*/
|
||||
export interface LlamaIndexEventMaps {}
|
||||
|
||||
declare module "llamaindex" {
|
||||
interface LlamaIndexEventMaps {
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
retrieve: CustomEvent<RetrievalCallbackResponse>;
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
stream: CustomEvent<StreamCallbackResponse>;
|
||||
}
|
||||
}
|
||||
|
||||
//#region @deprecated remove in the next major version
|
||||
/*
|
||||
An event is a wrapper that groups related operations.
|
||||
For example, during retrieve and synthesize,
|
||||
@@ -60,25 +80,136 @@ export interface RetrievalCallbackResponse extends BaseCallbackResponse {
|
||||
}
|
||||
|
||||
interface CallbackManagerMethods {
|
||||
/*
|
||||
onLLMStream is called when a token is streamed from the LLM. Defining this
|
||||
callback auto sets the stream = True flag on the openAI createChatCompletion request.
|
||||
*/
|
||||
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
|
||||
/*
|
||||
onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
|
||||
This callback allows you to handle the retrieved nodes even if the synthesizer
|
||||
is still running.
|
||||
*/
|
||||
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
|
||||
/**
|
||||
* onLLMStream is called when a token is streamed from the LLM. Defining this
|
||||
* callback auto sets the stream = True flag on the openAI createChatCompletion request.
|
||||
* @deprecated will be removed in the next major version
|
||||
*/
|
||||
onLLMStream: (params: StreamCallbackResponse) => Promise<void> | void;
|
||||
/**
|
||||
* onRetrieve is called as soon as the retriever finishes fetching relevant nodes.
|
||||
* This callback allows you to handle the retrieved nodes even if the synthesizer
|
||||
* is still running.
|
||||
* @deprecated will be removed in the next major version
|
||||
*/
|
||||
onRetrieve: (params: RetrievalCallbackResponse) => Promise<void> | void;
|
||||
}
|
||||
//#endregion
|
||||
|
||||
const noop: (...args: any[]) => any = () => void 0;
|
||||
|
||||
type EventHandler<Event extends CustomEvent> = (event: Event) => void;
|
||||
|
||||
export class CallbackManager implements CallbackManagerMethods {
|
||||
onLLMStream?: (params: StreamCallbackResponse) => Promise<void> | void;
|
||||
onRetrieve?: (params: RetrievalCallbackResponse) => Promise<void> | void;
|
||||
/**
 * Legacy accessor: returns an async function that wraps the given response in
 * a `CustomEvent("stream")` (one per handler) and hands it to every handler
 * registered under `"stream"`, resolving once all of them have settled.
 * @deprecated will be removed in the next major version
 */
get onLLMStream(): CallbackManagerMethods["onLLMStream"] {
  return async (response) => {
    // The handler list is looked up at call time, not when the getter is read.
    const pending = this.#handlers
      .get("stream")!
      .map((handler) => handler(new CustomEvent("stream", { detail: response })));
    await Promise.all(pending);
  };
}
|
||||
|
||||
constructor(handlers?: CallbackManagerMethods) {
|
||||
this.onLLMStream = handlers?.onLLMStream;
|
||||
this.onRetrieve = handlers?.onRetrieve;
|
||||
/**
 * Legacy accessor: returns an async function that wraps the given response in
 * a `CustomEvent("retrieve")` (one per handler) and hands it to every handler
 * registered under `"retrieve"`, resolving once all of them have settled.
 * @deprecated will be removed in the next major version
 */
get onRetrieve(): CallbackManagerMethods["onRetrieve"] {
  return async (response) => {
    // The handler list is looked up at call time, not when the getter is read.
    const pending = this.#handlers
      .get("retrieve")!
      .map((handler) =>
        handler(new CustomEvent("retrieve", { detail: response })),
      );
    await Promise.all(pending);
  };
}
|
||||
|
||||
/**
 * Assigning `onLLMStream` is not supported: every assignment throws.
 * @deprecated will be removed in the next major version
 */
set onLLMStream(_: never) {
  const message =
    "onLLMStream is deprecated. Use addHandlers('stream') instead";
  throw new Error(message);
}
|
||||
|
||||
/**
 * Assigning `onRetrieve` is not supported: every assignment throws.
 * @deprecated will be removed in the next major version
 */
set onRetrieve(_: never) {
  const message =
    "onRetrieve is deprecated. Use `addHandlers('retrieve')` instead";
  throw new Error(message);
}
|
||||
|
||||
#handlers = new Map<keyof LlamaIndexEventMaps, EventHandler<CustomEvent>[]>();
|
||||
|
||||
/**
 * Registers the optional legacy callbacks as event handlers: `onLLMStream`
 * under `"stream"` and `onRetrieve` under `"retrieve"`. A missing callback is
 * replaced by `noop`, so both events always get one handler registered here.
 */
constructor(handlers?: Partial<CallbackManagerMethods>) {
  const legacyStream = handlers?.onLLMStream ?? noop;
  const legacyRetrieve = handlers?.onRetrieve ?? noop;
  // Legacy callbacks receive the event's `detail`, not the event itself.
  this.addHandlers("stream", (event) => legacyStream(event.detail));
  this.addHandlers("retrieve", (event) => legacyRetrieve(event.detail));
}
|
||||
|
||||
/**
 * Appends `handler` to the list of handlers for `event`, creating the list on
 * first use. The same handler may be registered more than once.
 * @returns this manager, so calls can be chained
 */
addHandlers<
  K extends keyof LlamaIndexEventMaps,
  H extends EventHandler<LlamaIndexEventMaps[K]>,
>(event: K, handler: H) {
  const registered = this.#handlers.get(event);
  if (registered) {
    registered.push(handler);
  } else {
    this.#handlers.set(event, [handler]);
  }
  return this;
}
|
||||
|
||||
/**
 * Removes a handler previously registered for `event`.
 *
 * Only the first entry that is the same function reference as `handler` is
 * removed; if the handler was added several times, the other registrations
 * stay in place. Removing a handler that is not registered is a no-op.
 *
 * @returns this manager in every case (including when nothing has ever been
 * registered for `event`), so calls can always be chained
 */
removeHandlers<
  K extends keyof LlamaIndexEventMaps,
  H extends EventHandler<LlamaIndexEventMaps[K]>,
>(event: K, handler: H) {
  const registered = this.#handlers.get(event);
  if (registered) {
    const position = registered.indexOf(handler);
    if (position > -1) {
      registered.splice(position, 1);
    }
  }
  // Previously the "no list for this event" path returned undefined, which
  // made `manager.removeHandlers(...).addHandlers(...)` throw.
  return this;
}
|
||||
|
||||
/**
 * Synchronously calls every handler registered for `event`, in registration
 * order, passing each one its own `CustomEvent` whose `detail` is `detail`.
 * Does nothing when no handler list exists for `event`.
 *
 * Handler return values are ignored here, so promises returned by async
 * handlers are not awaited.
 */
dispatchEvent<K extends keyof LlamaIndexEventMaps>(
  event: K,
  detail: LlamaIndexEventMaps[K]["detail"],
) {
  const registered = this.#handlers.get(event);
  if (!registered) {
    return;
  }
  // Iterate over a snapshot: a handler that removes itself (or another
  // handler) while being dispatched would otherwise shift the live array and
  // cause the following handler to be skipped for this dispatch.
  for (const handler of [...registered]) {
    handler(new CustomEvent(event, { detail }));
  }
}
|
||||
}
|
||||
|
||||
const defaultCallbackManager = new CallbackManager();
|
||||
const callbackAsyncLocalStorage = new AsyncLocalStorage<CallbackManager>();
|
||||
|
||||
/**
 * Returns the callback manager held in `callbackAsyncLocalStorage` for the
 * current execution context.
 * @default defaultCallbackManager when the store holds no callback manager
 */
export function getCurrentCallbackManager() {
  const scopedManager = callbackAsyncLocalStorage.getStore();
  return scopedManager ?? defaultCallbackManager;
}
|
||||
|
||||
/**
 * Invokes `fn` through `callbackAsyncLocalStorage.run` with `callbackManager`
 * as the store, so that `getCurrentCallbackManager()` called from within `fn`
 * yields `callbackManager` instead of the default one.
 *
 * NOTE(review): how far the store propagates into asynchronous work started by
 * `fn` depends on the `AsyncLocalStorage` implementation exported by
 * `@llamaindex/env` — confirm for non-Node runtimes.
 *
 * @param callbackManager manager to expose while `fn` runs
 * @param fn function to execute
 * @returns whatever `fn` returns
 */
export function runWithCallbackManager<Result>(
  callbackManager: CallbackManager,
  fn: () => Result,
): Result {
  const outcome = callbackAsyncLocalStorage.run(callbackManager, fn);
  return outcome;
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { ObjectType, jsonToNode } from "../Node.js";
|
||||
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
|
||||
import type { ServiceContext } from "../ServiceContext.js";
|
||||
import { serviceContextFromDefaults } from "../ServiceContext.js";
|
||||
import { getCurrentCallbackManager } from "../callbacks/CallbackManager.js";
|
||||
import type { ClientParams, CloudConstructorParams } from "./types.js";
|
||||
import { DEFAULT_PROJECT_NAME } from "./types.js";
|
||||
import { getClient } from "./utils.js";
|
||||
@@ -80,16 +81,15 @@ export class LlamaCloudRetriever implements BaseRetriever {
|
||||
|
||||
const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);
|
||||
|
||||
if (this.serviceContext.callbackManager.onRetrieve) {
|
||||
this.serviceContext.callbackManager.onRetrieve({
|
||||
query,
|
||||
nodes,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
getCurrentCallbackManager().onRetrieve({
|
||||
query,
|
||||
nodes,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { defaultChoiceSelectPrompt } from "../../Prompt.js";
|
||||
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
|
||||
import type { ServiceContext } from "../../ServiceContext.js";
|
||||
import { serviceContextFromDefaults } from "../../ServiceContext.js";
|
||||
import { getCurrentCallbackManager } from "../../callbacks/CallbackManager.js";
|
||||
import { RetrieverQueryEngine } from "../../engines/query/index.js";
|
||||
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
|
||||
import type { StorageContext } from "../../storage/StorageContext.js";
|
||||
@@ -291,16 +292,14 @@ export class SummaryIndexRetriever implements BaseRetriever {
|
||||
score: 1,
|
||||
}));
|
||||
|
||||
if (this.index.serviceContext.callbackManager.onRetrieve) {
|
||||
this.index.serviceContext.callbackManager.onRetrieve({
|
||||
query,
|
||||
nodes: result,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
getCurrentCallbackManager().onRetrieve({
|
||||
query,
|
||||
nodes: result,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -376,16 +375,14 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
|
||||
results.push(...nodeWithScores);
|
||||
}
|
||||
|
||||
if (this.serviceContext.callbackManager.onRetrieve) {
|
||||
this.serviceContext.callbackManager.onRetrieve({
|
||||
query,
|
||||
nodes: results,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
getCurrentCallbackManager().onRetrieve({
|
||||
query,
|
||||
nodes: results,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
@@ -14,7 +14,10 @@ import {
|
||||
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
|
||||
import type { ServiceContext } from "../../ServiceContext.js";
|
||||
import { serviceContextFromDefaults } from "../../ServiceContext.js";
|
||||
import type { Event } from "../../callbacks/CallbackManager.js";
|
||||
import {
|
||||
getCurrentCallbackManager,
|
||||
type Event,
|
||||
} from "../../callbacks/CallbackManager.js";
|
||||
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants.js";
|
||||
import type {
|
||||
BaseEmbedding,
|
||||
@@ -480,16 +483,14 @@ export class VectorIndexRetriever implements BaseRetriever {
|
||||
nodesWithScores: NodeWithScore<Metadata>[],
|
||||
parentEvent: Event | undefined,
|
||||
) {
|
||||
if (this.serviceContext.callbackManager.onRetrieve) {
|
||||
this.serviceContext.callbackManager.onRetrieve({
|
||||
query,
|
||||
nodes: nodesWithScores,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
getCurrentCallbackManager().onRetrieve({
|
||||
query,
|
||||
nodes: nodesWithScores,
|
||||
event: globalsHelper.createEvent({
|
||||
parentEvent,
|
||||
type: "retrieve",
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
protected async buildVectorStoreQuery(
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type OpenAILLM from "openai";
|
||||
import type { ClientOptions as OpenAIClientOptions } from "openai";
|
||||
import type {
|
||||
CallbackManager,
|
||||
Event,
|
||||
EventType,
|
||||
OpenAIStreamToken,
|
||||
StreamCallbackResponse,
|
||||
import {
|
||||
getCurrentCallbackManager,
|
||||
type Event,
|
||||
type EventType,
|
||||
type OpenAIStreamToken,
|
||||
type StreamCallbackResponse,
|
||||
} from "../callbacks/CallbackManager.js";
|
||||
|
||||
import llamaTokenizer from "llama-tokenizer-js";
|
||||
@@ -36,6 +36,7 @@ import type {
|
||||
LLMMetadata,
|
||||
MessageType,
|
||||
} from "./types.js";
|
||||
import { llmEvent } from "./utils.js";
|
||||
|
||||
export const GPT4_MODELS = {
|
||||
"gpt-4": { contextWindow: 8192 },
|
||||
@@ -102,8 +103,6 @@ export class OpenAI extends BaseLLM {
|
||||
"apiKey" | "maxRetries" | "timeout"
|
||||
>;
|
||||
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
constructor(
|
||||
init?: Partial<OpenAI> & {
|
||||
azure?: AzureOpenAIConfig;
|
||||
@@ -155,8 +154,6 @@ export class OpenAI extends BaseLLM {
|
||||
...this.additionalSessionOptions,
|
||||
});
|
||||
}
|
||||
|
||||
this.callbackManager = init?.callbackManager;
|
||||
}
|
||||
|
||||
get metadata() {
|
||||
@@ -230,6 +227,7 @@ export class OpenAI extends BaseLLM {
|
||||
params: LLMChatParamsStreaming,
|
||||
): Promise<AsyncIterable<ChatResponseChunk>>;
|
||||
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
|
||||
@llmEvent
|
||||
async chat(
|
||||
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
|
||||
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
|
||||
@@ -293,9 +291,7 @@ export class OpenAI extends BaseLLM {
|
||||
};
|
||||
|
||||
//Now let's wrap our stream in a callback
|
||||
const onLLMStream = this.callbackManager?.onLLMStream
|
||||
? this.callbackManager.onLLMStream
|
||||
: () => {};
|
||||
const onLLMStream = getCurrentCallbackManager().onLLMStream;
|
||||
|
||||
const chunk_stream: AsyncIterable<OpenAIStreamToken> =
|
||||
await this.session.openai.chat.completions.create({
|
||||
@@ -566,6 +562,7 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
params: LLMChatParamsStreaming,
|
||||
): Promise<AsyncIterable<ChatResponseChunk>>;
|
||||
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
|
||||
@llmEvent
|
||||
async chat(
|
||||
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
|
||||
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
|
||||
@@ -653,8 +650,6 @@ export class Anthropic extends BaseLLM {
|
||||
timeout?: number;
|
||||
session: AnthropicSession;
|
||||
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
constructor(init?: Partial<Anthropic>) {
|
||||
super();
|
||||
this.model = init?.model ?? "claude-3-opus";
|
||||
@@ -672,8 +667,6 @@ export class Anthropic extends BaseLLM {
|
||||
maxRetries: this.maxRetries,
|
||||
timeout: this.timeout,
|
||||
});
|
||||
|
||||
this.callbackManager = init?.callbackManager;
|
||||
}
|
||||
|
||||
tokens(messages: ChatMessage[]): number {
|
||||
@@ -715,6 +708,7 @@ export class Anthropic extends BaseLLM {
|
||||
params: LLMChatParamsStreaming,
|
||||
): Promise<AsyncIterable<ChatResponseChunk>>;
|
||||
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
|
||||
@llmEvent
|
||||
async chat(
|
||||
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
|
||||
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
|
||||
@@ -790,7 +784,6 @@ export class Portkey extends BaseLLM {
|
||||
mode?: string = undefined;
|
||||
llms?: [LLMOptions] | null = undefined;
|
||||
session: PortkeySession;
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
constructor(init?: Partial<Portkey>) {
|
||||
super();
|
||||
@@ -804,7 +797,6 @@ export class Portkey extends BaseLLM {
|
||||
llms: this.llms,
|
||||
mode: this.mode,
|
||||
});
|
||||
this.callbackManager = init?.callbackManager;
|
||||
}
|
||||
|
||||
tokens(messages: ChatMessage[]): number {
|
||||
@@ -819,6 +811,7 @@ export class Portkey extends BaseLLM {
|
||||
params: LLMChatParamsStreaming,
|
||||
): Promise<AsyncIterable<ChatResponseChunk>>;
|
||||
chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
|
||||
@llmEvent
|
||||
async chat(
|
||||
params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
|
||||
): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
|
||||
@@ -844,9 +837,7 @@ export class Portkey extends BaseLLM {
|
||||
params?: Record<string, any>,
|
||||
): AsyncIterable<ChatResponseChunk> {
|
||||
// Wrapping the stream in a callback.
|
||||
const onLLMStream = this.callbackManager?.onLLMStream
|
||||
? this.callbackManager.onLLMStream
|
||||
: () => {};
|
||||
const onLLMStream = getCurrentCallbackManager().onLLMStream;
|
||||
|
||||
const chunkStream = await this.session.portkey.chatCompletions.create({
|
||||
messages,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { getEnv } from "@llamaindex/env";
|
||||
import type {
|
||||
CallbackManager,
|
||||
Event,
|
||||
EventType,
|
||||
StreamCallbackResponse,
|
||||
import {
|
||||
getCurrentCallbackManager,
|
||||
type Event,
|
||||
type EventType,
|
||||
type StreamCallbackResponse,
|
||||
} from "../callbacks/CallbackManager.js";
|
||||
import { BaseLLM } from "./base.js";
|
||||
import type {
|
||||
@@ -54,7 +54,6 @@ export class MistralAI extends BaseLLM {
|
||||
topP: number;
|
||||
maxTokens?: number;
|
||||
apiKey?: string;
|
||||
callbackManager?: CallbackManager;
|
||||
safeMode: boolean;
|
||||
randomSeed?: number;
|
||||
|
||||
@@ -66,7 +65,6 @@ export class MistralAI extends BaseLLM {
|
||||
this.temperature = init?.temperature ?? 0.1;
|
||||
this.topP = init?.topP ?? 1;
|
||||
this.maxTokens = init?.maxTokens ?? undefined;
|
||||
this.callbackManager = init?.callbackManager;
|
||||
this.safeMode = init?.safeMode ?? false;
|
||||
this.randomSeed = init?.randomSeed ?? undefined;
|
||||
this.session = new MistralAISession(init);
|
||||
@@ -125,9 +123,7 @@ export class MistralAI extends BaseLLM {
|
||||
parentEvent,
|
||||
}: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> {
|
||||
//Now let's wrap our stream in a callback
|
||||
const onLLMStream = this.callbackManager?.onLLMStream
|
||||
? this.callbackManager.onLLMStream
|
||||
: () => {};
|
||||
const onLLMStream = getCurrentCallbackManager().onLLMStream;
|
||||
|
||||
const client = await this.session.getClient();
|
||||
const chunkStream = await client.chatStream(this.buildParams(messages));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { ok } from "@llamaindex/env";
|
||||
import type { CallbackManager, Event } from "../callbacks/CallbackManager.js";
|
||||
import type { Event } from "../callbacks/CallbackManager.js";
|
||||
import { BaseEmbedding } from "../embeddings/types.js";
|
||||
import type {
|
||||
ChatMessage,
|
||||
@@ -35,7 +35,6 @@ export class Ollama extends BaseEmbedding implements LLM {
|
||||
contextWindow: number = 4096;
|
||||
requestTimeout: number = 60 * 1000; // Default is 60 seconds
|
||||
additionalChatOptions?: Record<string, unknown>;
|
||||
callbackManager?: CallbackManager;
|
||||
|
||||
protected modelMetadata: Partial<LLMMetadata>;
|
||||
|
||||
|
||||
@@ -1,15 +1,49 @@
|
||||
import type { Tokenizers } from "../GlobalsHelper.js";
|
||||
import type { Event } from "../callbacks/CallbackManager.js";
|
||||
import { type Event } from "../callbacks/CallbackManager.js";
|
||||
|
||||
/**
 * Common shape of the LLM lifecycle events declared below: a `CustomEvent`
 * whose `detail` is an object with a single `payload` member.
 *
 * `Type` labels the event at the declaration site; it does not appear in the
 * resulting type.
 */
type LLMBaseEvent<
  Type extends string,
  Payload extends Record<string, unknown>,
> = CustomEvent<{ payload: Payload }>;

/** Payload of the `llm-start` event: the chat messages of the request. */
type LLMStartPayload = { messages: ChatMessage[] };

/** Payload of the `llm-end` event: the chat response. */
type LLMEndPayload = { response: ChatResponse };

export type LLMStartEvent = LLMBaseEvent<"llm-start", LLMStartPayload>;

export type LLMEndEvent = LLMBaseEvent<"llm-end", LLMEndPayload>;

// Register both events on the package-wide event map via module augmentation.
declare module "llamaindex" {
  interface LlamaIndexEventMaps {
    "llm-start": LLMStartEvent;
    "llm-end": LLMEndEvent;
  }
}
|
||||
|
||||
/**
 * Narrow view of an LLM exposing only `chat`, declared with a single signature
 * that accepts either streaming or non-streaming parameters and resolves to
 * the matching response kind.
 *
 * @internal
 */
export interface LLMChat {
  chat(
    params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse>;
}
|
||||
|
||||
/**
|
||||
* Unified language model interface
|
||||
*/
|
||||
export interface LLM {
|
||||
export interface LLM extends LLMChat {
|
||||
metadata: LLMMetadata;
|
||||
/**
|
||||
* Get a chat response from the LLM
|
||||
*
|
||||
* @param params
|
||||
*/
|
||||
chat(
|
||||
params: LLMChatParamsStreaming,
|
||||
@@ -18,7 +52,6 @@ export interface LLM {
|
||||
|
||||
/**
|
||||
* Get a prompt completion from the LLM
|
||||
* @param params
|
||||
*/
|
||||
complete(
|
||||
params: LLMCompletionParamsStreaming,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { MessageContent } from "./types.js";
|
||||
import { getCurrentCallbackManager } from "../callbacks/CallbackManager.js";
|
||||
import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js";
|
||||
|
||||
export async function* streamConverter<S, D>(
|
||||
stream: AsyncIterable<S>,
|
||||
@@ -42,3 +43,58 @@ export function extractText(message: MessageContent): string {
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
/**
 * Method decorator (standard 2022-03 decorator form) that brackets an LLM
 * `chat` call with `llm-start` and `llm-end` callback events.
 *
 * - `llm-start` is dispatched before the wrapped method runs and carries the
 *   request messages.
 * - For a non-streaming response, `llm-end` is dispatched as soon as the
 *   wrapped method resolves and carries that response.
 * - For a streaming response, the response object's async iterator is replaced
 *   by one that forwards each chunk unchanged, concatenates the chunk deltas,
 *   and dispatches `llm-end` with the assembled assistant message once the
 *   underlying stream is exhausted.
 *
 * @param originalMethod the `chat` implementation being decorated
 * @param _context decorator context supplied by the runtime (unused)
 * @returns a replacement method with the same signature as `chat`
 * @internal
 */
export function llmEvent(
  originalMethod: LLMChat["chat"],
  _context: ClassMethodDecoratorContext,
) {
  return async function withLLMEvent(
    this: LLM,
    ...params: Parameters<LLMChat["chat"]>
  ): ReturnType<LLMChat["chat"]> {
    // Resolve the manager exactly once per call. The streaming generator below
    // runs lazily, whenever the consumer iterates, which may be after the
    // caller has left the scope in which the current manager was installed.
    // Looking it up again at that point could deliver `llm-end` to a different
    // manager than the one that received `llm-start`.
    const callbackManager = getCurrentCallbackManager();

    callbackManager.dispatchEvent("llm-start", {
      payload: {
        messages: params[0].messages,
      },
    });

    const response = await originalMethod.call(this, ...params);

    if (Symbol.asyncIterator in response) {
      // Keep a handle on the original iterator before overriding it, bound to
      // the response so it still iterates the underlying stream.
      const originalAsyncIterator = {
        [Symbol.asyncIterator]: response[Symbol.asyncIterator].bind(response),
      };
      response[Symbol.asyncIterator] = async function* () {
        const finalResponse: ChatResponse = {
          message: {
            content: "",
            role: "assistant",
          },
        };
        // The accumulator starts as the empty string, so appending every delta
        // (including the first) yields the full text without a special case.
        for await (const chunk of originalAsyncIterator) {
          finalResponse.message.content += chunk.delta;
          yield chunk;
        }
        callbackManager.dispatchEvent("llm-end", {
          payload: {
            response: finalResponse,
          },
        });
      };
    } else {
      callbackManager.dispatchEvent("llm-end", {
        payload: {
          response,
        },
      });
    }
    return response;
  };
}
|
||||
|
||||
@@ -15,7 +15,10 @@ import type {
|
||||
RetrievalCallbackResponse,
|
||||
StreamCallbackResponse,
|
||||
} from "llamaindex/callbacks/CallbackManager";
|
||||
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
|
||||
import {
|
||||
CallbackManager,
|
||||
runWithCallbackManager,
|
||||
} from "llamaindex/callbacks/CallbackManager";
|
||||
import { OpenAIEmbedding } from "llamaindex/embeddings/index";
|
||||
import { SummaryIndex } from "llamaindex/indices/summary/index";
|
||||
import { VectorStoreIndex } from "llamaindex/indices/vectorStore/index";
|
||||
@@ -38,10 +41,11 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
|
||||
let streamCallbackData: StreamCallbackResponse[] = [];
|
||||
let retrieveCallbackData: RetrievalCallbackResponse[] = [];
|
||||
let document: Document;
|
||||
let callbackManager: CallbackManager;
|
||||
|
||||
beforeAll(async () => {
|
||||
document = new Document({ text: "Author: My name is Paul Graham" });
|
||||
const callbackManager = new CallbackManager({
|
||||
callbackManager = new CallbackManager({
|
||||
onLLMStream: (data) => {
|
||||
streamCallbackData.push(data);
|
||||
},
|
||||
@@ -52,7 +56,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
|
||||
|
||||
const languageModel = new OpenAI({
|
||||
model: "gpt-3.5-turbo",
|
||||
callbackManager,
|
||||
});
|
||||
mockLlmGeneration({ languageModel, callbackManager });
|
||||
|
||||
@@ -60,7 +63,6 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
|
||||
mockEmbeddingModel(embedModel);
|
||||
|
||||
serviceContext = serviceContextFromDefaults({
|
||||
callbackManager,
|
||||
llm: languageModel,
|
||||
embedModel,
|
||||
});
|
||||
@@ -81,7 +83,10 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
|
||||
});
|
||||
const queryEngine = vectorStoreIndex.asQueryEngine();
|
||||
const query = "What is the author's name?";
|
||||
const response = await queryEngine.query({ query });
|
||||
const response = await runWithCallbackManager(callbackManager, () => {
|
||||
return queryEngine.query({ query });
|
||||
});
|
||||
|
||||
expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
|
||||
expect(streamCallbackData).toEqual([
|
||||
{
|
||||
@@ -159,7 +164,9 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
|
||||
responseSynthesizer,
|
||||
});
|
||||
const query = "What is the author's name?";
|
||||
const response = await queryEngine.query({ query });
|
||||
const response = await runWithCallbackManager(callbackManager, async () =>
|
||||
queryEngine.query({ query }),
|
||||
);
|
||||
expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");
|
||||
expect(streamCallbackData).toEqual([
|
||||
{
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import { Document } from "llamaindex/Node";
|
||||
import type { ServiceContext } from "llamaindex/ServiceContext";
|
||||
import { serviceContextFromDefaults } from "llamaindex/ServiceContext";
|
||||
import type {
|
||||
RetrievalCallbackResponse,
|
||||
StreamCallbackResponse,
|
||||
} from "llamaindex/callbacks/CallbackManager";
|
||||
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
|
||||
import { OpenAIEmbedding } from "llamaindex/embeddings/index";
|
||||
import {
|
||||
KeywordExtractor,
|
||||
@@ -15,15 +10,7 @@ import {
|
||||
} from "llamaindex/extractors/index";
|
||||
import { OpenAI } from "llamaindex/llm/LLM";
|
||||
import { SimpleNodeParser } from "llamaindex/nodeParsers/index";
|
||||
import {
|
||||
afterAll,
|
||||
beforeAll,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
test,
|
||||
vi,
|
||||
} from "vitest";
|
||||
import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
|
||||
import {
|
||||
DEFAULT_LLM_TEXT_OUTPUT,
|
||||
mockEmbeddingModel,
|
||||
@@ -39,42 +26,24 @@ vi.mock("llamaindex/llm/open_ai", () => {
|
||||
|
||||
describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
|
||||
let serviceContext: ServiceContext;
|
||||
let streamCallbackData: StreamCallbackResponse[] = [];
|
||||
let retrieveCallbackData: RetrievalCallbackResponse[] = [];
|
||||
|
||||
beforeAll(async () => {
|
||||
const callbackManager = new CallbackManager({
|
||||
onLLMStream: (data) => {
|
||||
streamCallbackData.push(data);
|
||||
},
|
||||
onRetrieve: (data) => {
|
||||
retrieveCallbackData.push(data);
|
||||
},
|
||||
});
|
||||
|
||||
const languageModel = new OpenAI({
|
||||
model: "gpt-3.5-turbo",
|
||||
callbackManager,
|
||||
});
|
||||
|
||||
mockLlmGeneration({ languageModel, callbackManager });
|
||||
mockLlmGeneration({ languageModel });
|
||||
|
||||
const embedModel = new OpenAIEmbedding();
|
||||
|
||||
mockEmbeddingModel(embedModel);
|
||||
|
||||
serviceContext = serviceContextFromDefaults({
|
||||
callbackManager,
|
||||
llm: languageModel,
|
||||
embedModel,
|
||||
});
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
streamCallbackData = [];
|
||||
retrieveCallbackData = [];
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
@@ -22,7 +22,6 @@ describe("LLMSelector", () => {
|
||||
|
||||
mocStructuredkLlmGeneration({
|
||||
languageModel,
|
||||
callbackManager: serviceContext.callbackManager,
|
||||
});
|
||||
|
||||
const selector = new LLMSingleSelector({
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { OpenAIAgent } from "llamaindex/agent/index";
|
||||
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
|
||||
import { OpenAI } from "llamaindex/llm/index";
|
||||
import { FunctionTool } from "llamaindex/tools/index";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
@@ -35,16 +34,12 @@ describe("OpenAIAgent", () => {
|
||||
let openaiAgent: OpenAIAgent;
|
||||
|
||||
beforeEach(() => {
|
||||
const callbackManager = new CallbackManager({});
|
||||
|
||||
const languageModel = new OpenAI({
|
||||
model: "gpt-3.5-turbo",
|
||||
callbackManager,
|
||||
});
|
||||
|
||||
mockLlmToolCallGeneration({
|
||||
languageModel,
|
||||
callbackManager,
|
||||
});
|
||||
|
||||
const sumFunctionTool = new FunctionTool(sumNumbers, {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { OpenAIAgentWorker } from "llamaindex/agent/index";
|
||||
import { AgentRunner } from "llamaindex/agent/runner/base";
|
||||
import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
|
||||
import { OpenAI } from "llamaindex/llm/LLM";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
@@ -19,16 +18,12 @@ describe("Agent Runner", () => {
|
||||
let agentRunner: AgentRunner;
|
||||
|
||||
beforeEach(() => {
|
||||
const callbackManager = new CallbackManager({});
|
||||
|
||||
const languageModel = new OpenAI({
|
||||
model: "gpt-3.5-turbo",
|
||||
callbackManager,
|
||||
});
|
||||
|
||||
mockLlmGeneration({
|
||||
languageModel,
|
||||
callbackManager,
|
||||
});
|
||||
|
||||
agentRunner = new AgentRunner({
|
||||
|
||||
@@ -68,7 +68,7 @@ export function mockLlmToolCallGeneration({
|
||||
callbackManager,
|
||||
}: {
|
||||
languageModel: OpenAI;
|
||||
callbackManager: CallbackManager;
|
||||
callbackManager?: CallbackManager;
|
||||
}) {
|
||||
vi.spyOn(languageModel, "chat").mockImplementation(
|
||||
() =>
|
||||
@@ -119,7 +119,7 @@ export function mocStructuredkLlmGeneration({
|
||||
callbackManager,
|
||||
}: {
|
||||
languageModel: OpenAI;
|
||||
callbackManager: CallbackManager;
|
||||
callbackManager?: CallbackManager;
|
||||
}) {
|
||||
vi.spyOn(languageModel, "chat").mockImplementation(
|
||||
async ({ messages, parentEvent }: LLMChatParamsBase) => {
|
||||
|
||||
Vendored
+1
-1
@@ -39,4 +39,4 @@ export function randomUUID(): string {
|
||||
return crypto.randomUUID();
|
||||
}
|
||||
export * from "./type.js";
|
||||
export { getEnv } from "./utils.js";
|
||||
export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js";
|
||||
|
||||
Vendored
+1
-1
@@ -35,5 +35,5 @@ export const defaultFS: CompleteFileSystem = {
|
||||
};
|
||||
|
||||
export type * from "./type.js";
|
||||
export { getEnv } from "./utils.js";
|
||||
export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js";
|
||||
export { EOL, ok, path, randomUUID };
|
||||
|
||||
Vendored
+20
@@ -10,3 +10,23 @@ export function getEnv(name: string): string | undefined {
|
||||
}
|
||||
return process.env[name];
|
||||
}
|
||||
|
||||
// Browser doesn't support AsyncLocalStorage
|
||||
export { AsyncLocalStorage } from "node:async_hooks";
|
||||
|
||||
/**
 * Minimal stand-in for the DOM `CustomEvent`, used when the runtime does not
 * provide one globally. It extends `Event` and exposes the `detail` value
 * passed in the init options.
 */
class CustomEvent<T = any> extends Event {
  readonly #detail: T;

  /** The value supplied as `options.detail`, or `null` when none was given. */
  get detail(): T {
    return this.#detail;
  }

  /**
   * @param event the event type name, forwarded to `Event`
   * @param options event init options; `detail` is stored on the instance
   */
  constructor(event: string, options?: CustomEventInit) {
    super(event, options);
    // The DOM `CustomEventInit.detail` member defaults to `null`. Mirror that
    // so the fallback and a native implementation agree when `detail` is
    // omitted (`??` keeps falsy values such as 0 or "").
    this.#detail = options?.detail ?? null;
  }
}

// Node.js doesn't have CustomEvent by default
// Refs: https://github.com/nodejs/node/issues/40678
const defaultCustomEvent = globalThis.CustomEvent || CustomEvent;

export { defaultCustomEvent as CustomEvent };
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"devDependencies": {
|
||||
"@types/node": "^18.19.10",
|
||||
"ts-node": "^10.9.2",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
},
|
||||
"scripts": {
|
||||
"lint": "eslint ."
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"assemblyscript": "^0.19.9",
|
||||
"@swc/cli": "^0.3.9",
|
||||
"@swc/core": "^1.4.2",
|
||||
"typescript": "^5.3.3"
|
||||
"typescript": "^5.4.3"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
|
||||
Generated
+471
-735
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"stripInternal": true,
|
||||
"outDir": "./lib",
|
||||
"tsBuildInfoFile": "./lib/.tsbuildinfo",
|
||||
"incremental": true,
|
||||
|
||||
Reference in New Issue
Block a user