implement RAG for manual mode

This commit is contained in:
Marcus Schiesser
2024-10-04 16:01:22 +07:00
parent 415b8d5eeb
commit ec20f7a439
12 changed files with 6611 additions and 54 deletions
+11
View File
@@ -0,0 +1,11 @@
# The provider for the AI models to use.
MODEL_PROVIDER=openai
# Name of the LLM to use.
MODEL=gpt-4o-mini
# Name of the embedding model to use.
EMBEDDING_MODEL=text-embedding-3-large
# Dimension of the vectors produced by the embedding model.
EMBEDDING_DIM=1024
+3
View File
@@ -41,3 +41,6 @@ yarn-error.log*
# typescript
*.tsbuildinfo
next-env.d.ts
# llama index
cache
BIN
View File
Binary file not shown.
+6209 -46
View File
File diff suppressed because it is too large Load Diff
+3
View File
@@ -7,6 +7,7 @@
"@openai/realtime-api-beta": "github:openai/openai-realtime-api-beta",
"dotenv": "^16.4.5",
"leaflet": "^1.9.4",
"llamaindex": "^0.6.15",
"next": "14.2.14",
"react": "^18",
"react-dom": "^18",
@@ -25,6 +26,7 @@
"lint": "next lint",
"zip": "zip -r realtime-api-console.zip . -x 'node_modules' 'node_modules/*' 'node_modules/**' '.git' '.git/*' '.git/**' '.DS_Store' '*/.DS_Store' 'package-lock.json' '*.zip' '*.tar.gz' '*.tar' '.env'",
"relay": "nodemon ./relay-server/index.js",
"generate": "tsx src/pages/engine/generate.ts",
"format": "prettier --ignore-unknown --cache --check .",
"format:write": "prettier --ignore-unknown --write ."
},
@@ -61,6 +63,7 @@
"postcss": "^8",
"prettier": "^3.3.3",
"tailwindcss": "^3.4.1",
"tsx": "^4.19.1",
"typescript": "^5"
}
}
+54 -8
View File
@@ -174,12 +174,6 @@ export function ConsolePage() {
const wavRecorder = wavRecorderRef.current;
const wavStreamPlayer = wavStreamPlayerRef.current;
// Set state variables
startTimeRef.current = new Date().toISOString();
setIsConnected(true);
setRealtimeEvents([]);
setItems(client.conversation.getItems());
// Connect to microphone
await wavRecorder.begin();
@@ -196,6 +190,12 @@ export function ConsolePage() {
},
]);
// Set state variables
startTimeRef.current = new Date().toISOString();
setIsConnected(true);
setRealtimeEvents([]);
setItems(client.conversation.getItems());
if (client.getTurnDetectionType() === 'server_vad') {
await wavRecorder.record((data) => client.appendInputAudio(data.mono));
}
@@ -261,7 +261,12 @@ export function ConsolePage() {
if (!client) throw new Error('RealtimeClient is not initialized');
const wavRecorder = wavRecorderRef.current;
await wavRecorder.pause();
client.createResponse();
if (client.inputAudioBuffer.byteLength > 0) {
// commit the input audio buffer to the server
client.realtime.send('input_audio_buffer.commit', {});
client.conversation.queueInputAudio(client.inputAudioBuffer);
client.inputAudioBuffer = new Int16Array(0);
}
};
/**
@@ -283,6 +288,36 @@ export function ConsolePage() {
setCanPushToTalk(value === 'none');
};
/**
 * Fetch retrieval context for a user transcript and add it to the conversation.
 *
 * The trimmed transcript is sent to the `/api/context` endpoint; the `message`
 * field of the JSON reply is forwarded to the realtime client as an
 * `input_text` user message. A blank transcript is logged and ignored.
 *
 * @param transcript Transcribed text of the user's message.
 * @throws Error when the client is not initialized or the context API
 *   answers with a non-OK HTTP status.
 */
const injectContext = async (transcript: string) => {
  const realtimeClient = clientRef.current;
  if (!realtimeClient) {
    throw new Error('RealtimeClient is not initialized');
  }

  const query = transcript.trim();
  if (query === '') {
    console.log(`Empty transcript - can't generate context`);
    return;
  }

  console.log(`Triggering context API for ${query}`);
  const contextUrl = `/api/context?query=${encodeURIComponent(query)}`;
  const apiResponse = await fetch(contextUrl);
  if (!apiResponse.ok) {
    throw new Error(`HTTP error! status: ${apiResponse.status}`);
  }

  const payload = await apiResponse.json();
  console.log(`Received context API response: ${payload.message}`);

  const contextMessage = {
    type: 'input_text',
    text: payload.message,
  };
  realtimeClient.sendUserMessageContent([contextMessage]);

  // The turn detection type is null (i.e. not 'server_vad'), so a response
  // is requested explicitly once the context has been injected.
  const turnDetectionType = realtimeClient.getTurnDetectionType();
  if (turnDetectionType === null) {
    realtimeClient.createResponse();
  }
};
/**
* Auto-scroll the event logs
*/
@@ -471,7 +506,18 @@ export function ConsolePage() {
);
// handle realtime events from client + server for event logging
client.on('realtime.event', (realtimeEvent: RealtimeEvent) => {
client.on('realtime.event', async (realtimeEvent: RealtimeEvent) => {
if (
realtimeEvent.event.type ===
'conversation.item.input_audio_transcription.completed'
) {
console.log(
'conversation.item.input_audio_transcription.completed',
realtimeEvent,
);
// transcript of a user message is available
await injectContext(realtimeEvent.event.transcript);
}
setRealtimeEvents((realtimeEvents) => {
const lastEvent = realtimeEvents[realtimeEvents.length - 1];
if (lastEvent?.event.type === realtimeEvent.event.type) {
+68
View File
@@ -0,0 +1,68 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { MetadataMode } from 'llamaindex';
import { getDataSource } from '../engine';
import { extractText } from '@llamaindex/core/utils';
import {
PromptTemplate,
type ContextSystemPrompt,
} from '@llamaindex/core/prompts';
import { createMessageContent } from '@llamaindex/core/response-synthesizers';
import { initSettings } from '../engine/settings';
// JSON body returned by this API route; the handler below uses the same
// shape for success (200) and error (400 / 500) responses.
type ResponseData = {
message: string;
};
// Configure the llamaindex Settings once, when this module is loaded.
// NOTE(review): initSettings is declared `async` and its promise is not
// awaited here, so a rejection (e.g. MODEL / EMBEDDING_MODEL not set) is not
// caught at this call site.
initSettings();
/**
 * API route that returns retrieval context for a query.
 *
 * Reads the `query` URL parameter, retrieves matching nodes from the persisted
 * index and renders them into a context prompt, which is returned as
 * `{ message }`.
 *
 * Responses:
 * - 200: `message` holds the rendered context prompt.
 * - 400: `query` is missing, not a single string, or blank.
 * - 500: anything else failed (including an empty storage context);
 *   `message` holds the error text.
 *
 * @param req Incoming request; only `req.query.query` is read.
 * @param res Response used to send the JSON body.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse<ResponseData>,
) {
  try {
    const { query } = req.query;
    if (typeof query !== 'string' || query.trim() === '') {
      console.log('[context] Invalid query parameter');
      return res.status(400).json({
        message: "A valid 'query' string parameter is required in the URL",
      });
    }
    console.log(`[context] Processing query: "${query}"`);

    const index = await getDataSource();
    if (!index) {
      // getDataSource returns null when the document store has no documents.
      throw new Error(
        `StorageContext is empty - call 'npm run generate' to generate the storage first`,
      );
    }

    const retriever = index.asRetriever();
    const nodes = await retriever.retrieve({
      query: query,
    });
    console.log(`[context] Retrieved ${nodes.length} nodes`);

    const contextSystemPrompt: ContextSystemPrompt = new PromptTemplate({
      templateVars: ['context'],
      template: `For improving the answer to my last question use the following context:
---------------------
{context}
---------------------`,
    });
    const content = await createMessageContent(
      contextSystemPrompt as any,
      nodes.map((r) => r.node),
      undefined,
      MetadataMode.LLM,
    );
    // Return here as well so every branch ends the handler the same way.
    return res.status(200).json({ message: extractText(content) });
  } catch (error) {
    console.error('[context] Error:', error);
    // A thrown value is not guaranteed to be an Error; narrow it instead of
    // asserting, so the response always carries a string message.
    const message = error instanceof Error ? error.message : String(error);
    return res.status(500).json({ message });
  }
}
+40
View File
@@ -0,0 +1,40 @@
import { VectorStoreIndex } from 'llamaindex';
import { storageContextFromDefaults } from 'llamaindex/storage/StorageContext';
import * as dotenv from 'dotenv';
import { getDocuments } from './loader';
import { initSettings } from './settings';
import { STORAGE_CACHE_DIR } from './shared';
// Load environment variables from local .env file
dotenv.config();
/**
 * Measure how long a function takes to complete.
 *
 * @param func Zero-argument function to run. Its result is awaited, so both
 *   synchronous and promise-returning functions are supported; the resolved
 *   value is discarded.
 * @returns Elapsed time in milliseconds, computed from `Date.now()`.
 *   Rejects with whatever `func` throws or rejects with.
 */
async function getRuntime(func: () => unknown): Promise<number> {
  const start = Date.now();
  await func();
  return Date.now() - start;
}
/**
 * Build the vector index from the loaded documents and persist it.
 *
 * Creates a storage context backed by `STORAGE_CACHE_DIR`, loads the documents
 * via `getDocuments()` and indexes them with `VectorStoreIndex.fromDocuments`.
 * The time taken is measured with `getRuntime` and logged in seconds.
 */
async function generateDatasource() {
  console.log(`Generating storage context...`);

  // Split documents, create embeddings and store them in the storage context
  const buildIndex = async () => {
    const documents = await (async () => {
      // The storage context is created before the documents are loaded.
      return null;
    })().then(async () => {
      const context = await storageContextFromDefaults({
        persistDir: STORAGE_CACHE_DIR,
      });
      const loaded = await getDocuments();
      return { context, loaded };
    });
    await VectorStoreIndex.fromDocuments(documents.loaded, {
      storageContext: documents.context,
    });
  };

  const elapsedMs = await getRuntime(buildIndex);
  const elapsedSeconds = elapsedMs / 1000;
  console.log(`Storage context successfully generated in ${elapsedSeconds}s.`);
}
// Script entry point: configure the models, then build and persist the index.
(async () => {
  // initSettings is async; await it so a configuration error rejects this
  // chain instead of becoming a detached promise rejection.
  await initSettings();
  await generateDatasource();
  console.log('Finished generating storage.');
})().catch((error: unknown) => {
  // Report the failure and signal it to the caller via the exit code.
  console.error('Failed to generate storage:', error);
  process.exitCode = 1;
});
+19
View File
@@ -0,0 +1,19 @@
import { SimpleDocumentStore, VectorStoreIndex } from 'llamaindex';
import { storageContextFromDefaults } from 'llamaindex/storage/StorageContext';
import { STORAGE_CACHE_DIR } from './shared';
/**
 * Load the vector index from the storage persisted in `STORAGE_CACHE_DIR`.
 *
 * @param params Currently unused.
 * @returns The initialized `VectorStoreIndex`, or `null` when the document
 *   store contains no documents.
 */
export async function getDataSource(params?: any) {
  const storageContext = await storageContextFromDefaults({
    persistDir: STORAGE_CACHE_DIR,
  });

  // An empty document store means there is nothing to build an index from.
  const docStore = storageContext.docStore as SimpleDocumentStore;
  const storedDocIds = Object.keys(docStore.toDict());
  if (storedDocIds.length === 0) {
    return null;
  }

  const index = await VectorStoreIndex.init({ storageContext });
  return index;
}
+24
View File
@@ -0,0 +1,24 @@
import {
FILE_EXT_TO_READER,
SimpleDirectoryReader,
} from 'llamaindex/readers/SimpleDirectoryReader';
// Directory that getDocuments() passes to SimpleDirectoryReader as directoryPath.
export const DATA_DIR = './data';
/**
 * Expose llamaindex's `FILE_EXT_TO_READER` table unchanged.
 *
 * NOTE(review): presumably maps file extensions to reader instances; confirm
 * against the llamaindex SimpleDirectoryReader documentation.
 */
export function getExtractors() {
  const readerTable = FILE_EXT_TO_READER;
  return readerTable;
}
/**
 * Load every document found in `DATA_DIR` and mark it as public.
 *
 * @returns The loaded documents, each with `private: 'false'` merged into
 *   its metadata.
 */
export async function getDocuments() {
  const reader = new SimpleDirectoryReader();
  const loadedDocuments = await reader.loadData({ directoryPath: DATA_DIR });

  // Set private=false to mark the document as public (required for filtering)
  loadedDocuments.forEach((doc) => {
    doc.metadata = { ...doc.metadata, private: 'false' };
  });

  return loadedDocuments;
}
+179
View File
@@ -0,0 +1,179 @@
import {
ALL_AVAILABLE_MISTRAL_MODELS,
Anthropic,
GEMINI_EMBEDDING_MODEL,
GEMINI_MODEL,
Gemini,
GeminiEmbedding,
Groq,
MistralAI,
MistralAIEmbedding,
MistralAIEmbeddingModelType,
OpenAI,
OpenAIEmbedding,
Settings,
} from 'llamaindex';
import { HuggingFaceEmbedding } from 'llamaindex/embeddings/HuggingFaceEmbedding';
import { OllamaEmbedding } from 'llamaindex/embeddings/OllamaEmbedding';
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from 'llamaindex/llm/anthropic';
import { Ollama } from 'llamaindex/llm/ollama';
// Value assigned to Settings.chunkSize by initSettings().
const CHUNK_SIZE = 512;
// Value assigned to Settings.chunkOverlap by initSettings().
const CHUNK_OVERLAP = 20;
/**
 * Configure the global llamaindex `Settings` from environment variables.
 *
 * Chooses the provider-specific initializer based on `MODEL_PROVIDER`
 * (any unrecognized or missing value falls back to OpenAI), then applies the
 * chunk size and chunk overlap constants.
 *
 * @throws Error when `MODEL` or `EMBEDDING_MODEL` is not set.
 */
export const initSettings = async () => {
  // HINT: you can delete the initialization code for unused model providers
  const provider = process.env.MODEL_PROVIDER;
  console.log(`Using '${provider}' model provider`);

  const hasModelConfig = Boolean(
    process.env.MODEL && process.env.EMBEDDING_MODEL,
  );
  if (!hasModelConfig) {
    throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
  }

  // A Map (rather than a plain object) keeps inherited keys such as
  // 'constructor' from ever matching a provider name.
  const initializers = new Map<string, () => void>([
    ['ollama', initOllama],
    ['groq', initGroq],
    ['anthropic', initAnthropic],
    ['gemini', initGemini],
    ['mistral', initMistralAI],
    ['azure-openai', initAzureOpenAI],
  ]);
  const initProvider =
    (provider ? initializers.get(provider) : undefined) ?? initOpenAI;
  initProvider();

  Settings.chunkSize = CHUNK_SIZE;
  Settings.chunkOverlap = CHUNK_OVERLAP;
};
/**
 * Point the global llamaindex `Settings` at OpenAI for both the LLM and the
 * embedding model.
 *
 * Environment variables read:
 * - `MODEL`: LLM name, defaults to 'gpt-4o-mini'.
 * - `LLM_MAX_TOKENS`: optional, converted with `Number`.
 * - `EMBEDDING_MODEL`: embedding model name.
 * - `EMBEDDING_DIM`: optional embedding dimension, parsed as a base-10 integer.
 */
function initOpenAI() {
  const { MODEL, LLM_MAX_TOKENS, EMBEDDING_MODEL, EMBEDDING_DIM } = process.env;

  Settings.llm = new OpenAI({
    model: MODEL ?? 'gpt-4o-mini',
    maxTokens: LLM_MAX_TOKENS ? Number(LLM_MAX_TOKENS) : undefined,
  });

  Settings.embedModel = new OpenAIEmbedding({
    model: EMBEDDING_MODEL,
    // Explicit radix: without it a value such as '0x10' is parsed as hex.
    dimensions: EMBEDDING_DIM ? parseInt(EMBEDDING_DIM, 10) : undefined,
  });
}
/**
 * Point the global llamaindex `Settings` at Azure OpenAI for both the LLM and
 * the embedding model.
 *
 * Environment variables read:
 * - `MODEL`: Azure model name, translated through the map below; names not
 *   in the map fall back to 'gpt-3.5-turbo'.
 * - `LLM_MAX_TOKENS`: optional, converted with `Number`.
 * - `EMBEDDING_MODEL`: embedding model name.
 * - `EMBEDDING_DIM`: optional embedding dimension, parsed as a base-10 integer.
 * - `AZURE_OPENAI_KEY`, `AZURE_OPENAI_ENDPOINT`: shared Azure connection settings.
 * - `AZURE_OPENAI_API_VERSION`, falling back to `OPENAI_API_VERSION`.
 * - `AZURE_OPENAI_LLM_DEPLOYMENT`, `AZURE_OPENAI_EMBEDDING_DEPLOYMENT`:
 *   deployment names for the LLM and the embedding model.
 */
function initAzureOpenAI() {
  // Map Azure OpenAI model names to OpenAI model names (only for TS)
  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
    'gpt-35-turbo': 'gpt-3.5-turbo',
    'gpt-35-turbo-16k': 'gpt-3.5-turbo-16k',
    'gpt-4o': 'gpt-4o',
    'gpt-4': 'gpt-4',
    'gpt-4-32k': 'gpt-4-32k',
    'gpt-4-turbo': 'gpt-4-turbo',
    'gpt-4-turbo-2024-04-09': 'gpt-4-turbo',
    'gpt-4-vision-preview': 'gpt-4-vision-preview',
    'gpt-4-1106-preview': 'gpt-4-1106-preview',
    'gpt-4o-2024-05-13': 'gpt-4o-2024-05-13',
  };

  // Connection settings shared by the LLM and the embedding client.
  const azureConfig = {
    apiKey: process.env.AZURE_OPENAI_KEY,
    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
    apiVersion:
      process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
  };

  Settings.llm = new OpenAI({
    model:
      AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? 'gpt-35-turbo'] ??
      'gpt-3.5-turbo',
    maxTokens: process.env.LLM_MAX_TOKENS
      ? Number(process.env.LLM_MAX_TOKENS)
      : undefined,
    azure: {
      ...azureConfig,
      deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
    },
  });

  Settings.embedModel = new OpenAIEmbedding({
    model: process.env.EMBEDDING_MODEL,
    // Explicit radix: without it a value such as '0x10' is parsed as hex.
    dimensions: process.env.EMBEDDING_DIM
      ? parseInt(process.env.EMBEDDING_DIM, 10)
      : undefined,
    azure: {
      ...azureConfig,
      deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    },
  });
}
/**
 * Point the global llamaindex `Settings` at Ollama for both the LLM and the
 * embedding model.
 *
 * The host comes from `OLLAMA_BASE_URL` (default 'http://127.0.0.1:11434');
 * model names come from `MODEL` and `EMBEDDING_MODEL` (empty string if unset).
 */
function initOllama() {
  const host = process.env.OLLAMA_BASE_URL ?? 'http://127.0.0.1:11434';
  // The same connection settings are used by the LLM and the embedding client.
  const ollamaConfig = { host };

  const llmModel = process.env.MODEL ?? '';
  const embeddingModel = process.env.EMBEDDING_MODEL ?? '';

  Settings.llm = new Ollama({ model: llmModel, config: ollamaConfig });
  Settings.embedModel = new OllamaEmbedding({
    model: embeddingModel,
    config: ollamaConfig,
  });
}
/**
 * Use Groq as the LLM and a HuggingFace model for embeddings.
 *
 * `MODEL` selects the Groq model. `EMBEDDING_MODEL` is looked up in the table
 * below and the result is passed as `modelType`; a name that is not in the
 * table yields `undefined`.
 */
function initGroq() {
  // Short embedding model names -> 'Xenova/...' identifiers.
  const xenovaModelIds: Record<string, string> = {
    'all-MiniLM-L6-v2': 'Xenova/all-MiniLM-L6-v2',
    'all-mpnet-base-v2': 'Xenova/all-mpnet-base-v2',
  };
  const llmModel = process.env.MODEL!;
  const embeddingModelType = xenovaModelIds[process.env.EMBEDDING_MODEL!];

  Settings.llm = new Groq({ model: llmModel });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embeddingModelType,
  });
}
/**
 * Use Anthropic as the LLM and a HuggingFace model for embeddings.
 *
 * `MODEL` selects the Anthropic model. `EMBEDDING_MODEL` is looked up in the
 * table below and the result is passed as `modelType`; a name that is not in
 * the table yields `undefined`.
 */
function initAnthropic() {
  // Short embedding model names -> 'Xenova/...' identifiers.
  const xenovaModelIds: Record<string, string> = {
    'all-MiniLM-L6-v2': 'Xenova/all-MiniLM-L6-v2',
    'all-mpnet-base-v2': 'Xenova/all-mpnet-base-v2',
  };
  const llmModel = process.env
    .MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS;
  const embeddingModelType = xenovaModelIds[process.env.EMBEDDING_MODEL!];

  Settings.llm = new Anthropic({ model: llmModel });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embeddingModelType,
  });
}
/**
 * Point the global llamaindex `Settings` at Gemini for both the LLM and the
 * embedding model, using `MODEL` and `EMBEDDING_MODEL` as the model names.
 */
function initGemini() {
  // The env values are plain strings; they are asserted to the Gemini model
  // types without runtime validation.
  const llmModel = process.env.MODEL as GEMINI_MODEL;
  const embeddingModel = process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL;

  Settings.llm = new Gemini({ model: llmModel });
  Settings.embedModel = new GeminiEmbedding({ model: embeddingModel });
}
/**
 * Point the global llamaindex `Settings` at Mistral AI for both the LLM and
 * the embedding model, using `MODEL` and `EMBEDDING_MODEL` as the model names.
 */
function initMistralAI() {
  // The env values are plain strings; they are asserted to the Mistral model
  // types without runtime validation.
  const llmModel = process.env
    .MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS;
  const embeddingModel = process.env
    .EMBEDDING_MODEL as MistralAIEmbeddingModelType;

  Settings.llm = new MistralAI({ model: llmModel });
  Settings.embedModel = new MistralAIEmbedding({ model: embeddingModel });
}
+1
View File
@@ -0,0 +1 @@
// Directory used as `persistDir` for the llamaindex storage context, both
// when generating the index and when loading it.
export const STORAGE_CACHE_DIR = './cache';