Compare commits

...

2 Commits

Author SHA1 Message Date
Marcus Schiesser b622f49845 add docs and unit test 2024-09-20 15:01:08 +07:00
Marcus Schiesser f02ecbff14 fix: context not working in contextchatengine 2024-09-20 12:58:26 +07:00
9 changed files with 136 additions and 28 deletions
+6
View File
@@ -0,0 +1,6 @@
---
"llamaindex": patch
"@llamaindex/core": patch
---
Fix context not being sent using ContextChatEngine
+23 -2
View File
@@ -1,5 +1,5 @@
import { Settings } from "../global";
import type { ChatMessage, MessageContent } from "../llms";
import type { ChatMessage } from "../llms";
import { type BaseChatStore, SimpleChatStore } from "../storage/chat-store";
import { extractText } from "../utils";
@@ -12,15 +12,36 @@ export const DEFAULT_CHAT_STORE_KEY = "chat_history";
export abstract class BaseMemory<
AdditionalMessageOptions extends object = object,
> {
/**
* Retrieves messages from the memory, optionally including transient messages.
* Compared to getAllMessages, this method a) allows for transient messages to be included in the retrieval and b) may return a subset of the total messages by applying a token limit.
* @param transientMessages Optional array of temporary messages to be included in the retrieval.
* These messages are not stored in the memory but are considered for the current interaction.
* @returns An array of chat messages, either synchronously or as a Promise.
*/
abstract getMessages(
input?: MessageContent | undefined,
transientMessages?: ChatMessage<AdditionalMessageOptions>[] | undefined,
):
| ChatMessage<AdditionalMessageOptions>[]
| Promise<ChatMessage<AdditionalMessageOptions>[]>;
/**
* Retrieves all messages stored in the memory.
* @returns An array of all chat messages, either synchronously or as a Promise.
*/
abstract getAllMessages():
| ChatMessage<AdditionalMessageOptions>[]
| Promise<ChatMessage<AdditionalMessageOptions>[]>;
/**
* Adds a new message to the memory.
* @param messages The chat message to be added to the memory.
*/
abstract put(messages: ChatMessage<AdditionalMessageOptions>): void;
/**
* Clears all messages from the memory.
*/
abstract reset(): void;
protected _tokenCountForMessages(messages: ChatMessage[]): number {
+14 -8
View File
@@ -1,5 +1,5 @@
import { Settings } from "../global";
import type { ChatMessage, LLM, MessageContent } from "../llms";
import type { ChatMessage, LLM } from "../llms";
import { type BaseChatStore } from "../storage/chat-store";
import { BaseChatStoreMemory, DEFAULT_TOKEN_LIMIT_RATIO } from "./base";
@@ -34,7 +34,7 @@ export class ChatMemoryBuffer<
}
getMessages(
input?: MessageContent | undefined,
transientMessages?: ChatMessage<AdditionalMessageOptions>[] | undefined,
initialTokenCount: number = 0,
) {
const messages = this.getAllMessages();
@@ -43,16 +43,22 @@ export class ChatMemoryBuffer<
throw new Error("Initial token count exceeds token limit");
}
let messageCount = messages.length;
let currentMessages = messages.slice(-messageCount);
let tokenCount = this._tokenCountForMessages(messages) + initialTokenCount;
// Add input messages as transient messages
const messagesWithInput = transientMessages
? [...transientMessages, ...messages]
: messages;
let messageCount = messagesWithInput.length;
let currentMessages = messagesWithInput.slice(-messageCount);
let tokenCount =
this._tokenCountForMessages(messagesWithInput) + initialTokenCount;
while (tokenCount > this.tokenLimit && messageCount > 1) {
messageCount -= 1;
if (messages.at(-messageCount)!.role === "assistant") {
if (messagesWithInput.at(-messageCount)!.role === "assistant") {
messageCount -= 1;
}
currentMessages = messages.slice(-messageCount);
currentMessages = messagesWithInput.slice(-messageCount);
tokenCount =
this._tokenCountForMessages(currentMessages) + initialTokenCount;
}
@@ -60,6 +66,6 @@ export class ChatMemoryBuffer<
if (tokenCount > this.tokenLimit && messageCount <= 0) {
return [];
}
return messages.slice(-messageCount);
return messagesWithInput.slice(-messageCount);
}
}
+10 -6
View File
@@ -114,18 +114,22 @@ export class ChatSummaryMemoryBuffer extends BaseMemory {
}
}
private calcCurrentRequestMessages() {
// TODO: check order: currently, we're sending:
private calcCurrentRequestMessages(transientMessages?: ChatMessage[]) {
// currently, we're sending:
// system messages first, then transient messages and then the messages that describe the conversation so far
return [...this.systemMessages, ...this.calcConversationMessages(true)];
return [
...this.systemMessages,
...(transientMessages ? transientMessages : []),
...this.calcConversationMessages(true),
];
}
reset() {
this.messages = [];
}
async getMessages(): Promise<ChatMessage[]> {
const requestMessages = this.calcCurrentRequestMessages();
async getMessages(transientMessages?: ChatMessage[]): Promise<ChatMessage[]> {
const requestMessages = this.calcCurrentRequestMessages(transientMessages);
// get tokens of current request messages and the transient messages
const tokens = requestMessages.reduce(
@@ -149,7 +153,7 @@ export class ChatSummaryMemoryBuffer extends BaseMemory {
// TODO: we still might have too many tokens
// e.g. too large system messages or transient messages
// how should we deal with that?
return this.calcCurrentRequestMessages();
return this.calcCurrentRequestMessages(transientMessages);
}
return requestMessages;
}
@@ -0,0 +1,74 @@
import { Settings } from "@llamaindex/core/global";
import type { ChatMessage } from "@llamaindex/core/llms";
import { ChatMemoryBuffer } from "@llamaindex/core/memory";
import { beforeEach, describe, expect, test } from "vitest";
describe("ChatMemoryBuffer", () => {
beforeEach(() => {
// Mock the Settings.llm
(Settings.llm as any) = {
metadata: {
contextWindow: 1000,
},
};
});
test("constructor initializes with custom token limit", () => {
const buffer = new ChatMemoryBuffer({ tokenLimit: 500 });
expect(buffer.tokenLimit).toBe(500);
});
test("getMessages returns all messages when under token limit", () => {
const messages: ChatMessage[] = [
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there!" },
{ role: "user", content: "How are you?" },
];
const buffer = new ChatMemoryBuffer({
tokenLimit: 1000,
chatHistory: messages,
});
const result = buffer.getMessages();
expect(result).toEqual(messages);
});
test("getMessages truncates messages when over token limit", () => {
const messages: ChatMessage[] = [
{ role: "user", content: "This is a long message" },
{ role: "assistant", content: "This is also a long reply" },
{ role: "user", content: "Short" },
];
const buffer = new ChatMemoryBuffer({
tokenLimit: 5, // limit to only allow the last message
chatHistory: messages,
});
const result = buffer.getMessages();
expect(result).toEqual([{ role: "user", content: "Short" }]);
});
test("getMessages handles input messages", () => {
const storedMessages: ChatMessage[] = [
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there!" },
];
const buffer = new ChatMemoryBuffer({
tokenLimit: 50,
chatHistory: storedMessages,
});
const inputMessages: ChatMessage[] = [
{ role: "user", content: "New message" },
];
const result = buffer.getMessages(inputMessages);
expect(result).toEqual([...inputMessages, ...storedMessages]);
});
test("getMessages throws error when initial token count exceeds limit", () => {
const buffer = new ChatMemoryBuffer({ tokenLimit: 10 });
expect(() => buffer.getMessages(undefined, 20)).toThrow(
"Initial token count exceeds token limit",
);
});
});
+2 -3
View File
@@ -356,9 +356,8 @@ export abstract class AgentRunner<
let chatHistory: ChatMessage<AdditionalMessageOptions>[] = [];
if (params.chatHistory instanceof BaseMemory) {
chatHistory = (await params.chatHistory.getMessages(
params.message,
)) as ChatMessage<AdditionalMessageOptions>[];
chatHistory =
(await params.chatHistory.getMessages()) as ChatMessage<AdditionalMessageOptions>[];
} else {
chatHistory =
params.chatHistory as ChatMessage<AdditionalMessageOptions>[];
@@ -78,9 +78,7 @@ export class CondenseQuestionChatEngine
}
private async condenseQuestion(chatHistory: BaseMemory, question: string) {
const chatHistoryStr = messagesToHistory(
await chatHistory.getMessages(question),
);
const chatHistoryStr = messagesToHistory(await chatHistory.getMessages());
return this.llm.complete({
prompt: this.condenseMessagePrompt.format({
@@ -103,7 +101,7 @@ export class CondenseQuestionChatEngine
? new ChatMemoryBuffer({
chatHistory:
params.chatHistory instanceof BaseMemory
? await params.chatHistory.getMessages(message)
? await params.chatHistory.getMessages()
: params.chatHistory,
})
: this.chatHistory;
@@ -92,7 +92,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
? new ChatMemoryBuffer({
chatHistory:
params.chatHistory instanceof BaseMemory
? await params.chatHistory.getMessages(message)
? await params.chatHistory.getMessages()
: params.chatHistory,
})
: this.chatHistory;
@@ -139,7 +139,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
const textOnly = extractText(message);
const context = await this.contextGenerator.generate(textOnly);
const systemMessage = this.prependSystemPrompt(context.message);
const messages = await chatHistory.getMessages(systemMessage.content);
const messages = await chatHistory.getMessages([systemMessage]);
return { nodes: context.nodes, messages };
}
@@ -40,7 +40,7 @@ export class SimpleChatEngine implements ChatEngine {
? new ChatMemoryBuffer({
chatHistory:
params.chatHistory instanceof BaseMemory
? await params.chatHistory.getMessages(message)
? await params.chatHistory.getMessages()
: params.chatHistory,
})
: this.chatHistory;
@@ -48,7 +48,7 @@ export class SimpleChatEngine implements ChatEngine {
if (stream) {
const stream = await this.llm.chat({
messages: await chatHistory.getMessages(params.message),
messages: await chatHistory.getMessages(),
stream: true,
});
return streamConverter(
@@ -66,7 +66,7 @@ export class SimpleChatEngine implements ChatEngine {
const response = await this.llm.chat({
stream: false,
messages: await chatHistory.getMessages(params.message),
messages: await chatHistory.getMessages(),
});
chatHistory.put(response.message);
return EngineResponse.fromChatResponse(response);