mirror of
https://github.com/run-llama/voice-chat-pdf.git
synced 2026-06-30 22:27:54 -04:00
implement RAG for manual mode
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
# The provider for the AI models to use.
|
||||
MODEL_PROVIDER=openai
|
||||
|
||||
# The name of the LLM model to use.
|
||||
MODEL=gpt-4o-mini
|
||||
|
||||
# Name of the embedding model to use.
|
||||
EMBEDDING_MODEL=text-embedding-3-large
|
||||
|
||||
# Dimension of the embedding model to use.
|
||||
EMBEDDING_DIM=1024
|
||||
@@ -41,3 +41,6 @@ yarn-error.log*
|
||||
# typescript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
||||
|
||||
# llama index
|
||||
cache
|
||||
|
||||
Binary file not shown.
Generated
+6209
-46
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@
|
||||
"@openai/realtime-api-beta": "github:openai/openai-realtime-api-beta",
|
||||
"dotenv": "^16.4.5",
|
||||
"leaflet": "^1.9.4",
|
||||
"llamaindex": "^0.6.15",
|
||||
"next": "14.2.14",
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
@@ -25,6 +26,7 @@
|
||||
"lint": "next lint",
|
||||
"zip": "zip -r realtime-api-console.zip . -x 'node_modules' 'node_modules/*' 'node_modules/**' '.git' '.git/*' '.git/**' '.DS_Store' '*/.DS_Store' 'package-lock.json' '*.zip' '*.tar.gz' '*.tar' '.env'",
|
||||
"relay": "nodemon ./relay-server/index.js",
|
||||
"generate": "tsx src/pages/engine/generate.ts",
|
||||
"format": "prettier --ignore-unknown --cache --check .",
|
||||
"format:write": "prettier --ignore-unknown --write ."
|
||||
},
|
||||
@@ -61,6 +63,7 @@
|
||||
"postcss": "^8",
|
||||
"prettier": "^3.3.3",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"tsx": "^4.19.1",
|
||||
"typescript": "^5"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,12 +174,6 @@ export function ConsolePage() {
|
||||
const wavRecorder = wavRecorderRef.current;
|
||||
const wavStreamPlayer = wavStreamPlayerRef.current;
|
||||
|
||||
// Set state variables
|
||||
startTimeRef.current = new Date().toISOString();
|
||||
setIsConnected(true);
|
||||
setRealtimeEvents([]);
|
||||
setItems(client.conversation.getItems());
|
||||
|
||||
// Connect to microphone
|
||||
await wavRecorder.begin();
|
||||
|
||||
@@ -196,6 +190,12 @@ export function ConsolePage() {
|
||||
},
|
||||
]);
|
||||
|
||||
// Set state variables
|
||||
startTimeRef.current = new Date().toISOString();
|
||||
setIsConnected(true);
|
||||
setRealtimeEvents([]);
|
||||
setItems(client.conversation.getItems());
|
||||
|
||||
if (client.getTurnDetectionType() === 'server_vad') {
|
||||
await wavRecorder.record((data) => client.appendInputAudio(data.mono));
|
||||
}
|
||||
@@ -261,7 +261,12 @@ export function ConsolePage() {
|
||||
if (!client) throw new Error('RealtimeClient is not initialized');
|
||||
const wavRecorder = wavRecorderRef.current;
|
||||
await wavRecorder.pause();
|
||||
client.createResponse();
|
||||
if (client.inputAudioBuffer.byteLength > 0) {
|
||||
// commit the input audio buffer to the server
|
||||
client.realtime.send('input_audio_buffer.commit', {});
|
||||
client.conversation.queueInputAudio(client.inputAudioBuffer);
|
||||
client.inputAudioBuffer = new Int16Array(0);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -283,6 +288,36 @@ export function ConsolePage() {
|
||||
setCanPushToTalk(value === 'none');
|
||||
};
|
||||
|
||||
/**
 * Look up retrieval context for a user transcript via the `/api/context`
 * endpoint and add the returned text to the conversation as a user
 * `input_text` message.
 *
 * Does nothing (apart from logging) when the transcript is blank.
 * Rejects when the client is not initialized or the endpoint answers with a
 * non-OK HTTP status.
 */
const injectContext = async (transcript: string) => {
  const realtimeClient = clientRef.current;
  if (!realtimeClient) throw new Error('RealtimeClient is not initialized');

  const query = transcript.trim();
  if (query.length === 0) {
    console.log(`Empty transcript - can't generate context`);
    return;
  }

  console.log(`Triggering context API for ${query}`);
  const contextUrl = `/api/context?query=${encodeURIComponent(query)}`;
  const response = await fetch(contextUrl);
  if (!response.ok) {
    throw new Error(`HTTP error! status: ${response.status}`);
  }

  const payload = await response.json();
  console.log(`Received context API response: ${payload.message}`);

  realtimeClient.sendUserMessageContent([
    { type: 'input_text', text: payload.message },
  ]);

  // NOTE(review): a null turn-detection type presumably means manual
  // (push-to-talk) mode, where a response has to be requested explicitly —
  // confirm against the RealtimeClient documentation.
  const turnDetection = realtimeClient.getTurnDetectionType();
  if (turnDetection === null) {
    realtimeClient.createResponse();
  }
};
|
||||
|
||||
/**
|
||||
* Auto-scroll the event logs
|
||||
*/
|
||||
@@ -471,7 +506,18 @@ export function ConsolePage() {
|
||||
);
|
||||
|
||||
// handle realtime events from client + server for event logging
|
||||
client.on('realtime.event', (realtimeEvent: RealtimeEvent) => {
|
||||
client.on('realtime.event', async (realtimeEvent: RealtimeEvent) => {
|
||||
if (
|
||||
realtimeEvent.event.type ===
|
||||
'conversation.item.input_audio_transcription.completed'
|
||||
) {
|
||||
console.log(
|
||||
'conversation.item.input_audio_transcription.completed',
|
||||
realtimeEvent,
|
||||
);
|
||||
// transcript of a user message is available
|
||||
await injectContext(realtimeEvent.event.transcript);
|
||||
}
|
||||
setRealtimeEvents((realtimeEvents) => {
|
||||
const lastEvent = realtimeEvents[realtimeEvents.length - 1];
|
||||
if (lastEvent?.event.type === realtimeEvent.event.type) {
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||
import { MetadataMode } from 'llamaindex';
|
||||
import { getDataSource } from '../engine';
|
||||
import { extractText } from '@llamaindex/core/utils';
|
||||
import {
|
||||
PromptTemplate,
|
||||
type ContextSystemPrompt,
|
||||
} from '@llamaindex/core/prompts';
|
||||
import { createMessageContent } from '@llamaindex/core/response-synthesizers';
|
||||
import { initSettings } from '../engine/settings';
|
||||
|
||||
type ResponseData = {
|
||||
message: string;
|
||||
};
|
||||
|
||||
// Configure the model settings once, when this module is first loaded.
// `initSettings` is async and rejects when required env variables are missing,
// so the rejection is handled here instead of surfacing as an unhandled
// promise rejection.
initSettings().catch((error: unknown) => {
  console.error('[context] Failed to initialize settings:', error);
});
|
||||
|
||||
/**
 * API route that turns a `query` URL parameter into a context message.
 *
 * Retrieves nodes for the query from the stored index and renders them into
 * the context prompt template; the resulting text is returned as `message`.
 *
 * Responses:
 * - 400 when `query` is missing, not a single string, or blank.
 * - 200 with the rendered context text.
 * - 500 with the error message when the index is missing or anything throws.
 *
 * @param req - Incoming request; only `req.query.query` is read.
 * @param res - Response object used to send the JSON body.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse<ResponseData>,
) {
  try {
    const { query } = req.query;
    if (typeof query !== 'string' || query.trim() === '') {
      console.log('[context] Invalid query parameter');
      return res.status(400).json({
        message: "A valid 'query' string parameter is required in the URL",
      });
    }

    console.log(`[context] Processing query: "${query}"`);

    // getDataSource() resolves to null when the document store is empty.
    const index = await getDataSource();
    if (!index) {
      throw new Error(
        `StorageContext is empty - call 'npm run generate' to generate the storage first`,
      );
    }
    const retriever = index.asRetriever();

    const nodes = await retriever.retrieve({
      query: query,
    });
    console.log(`[context] Retrieved ${nodes.length} nodes`);

    const contextSystemPrompt: ContextSystemPrompt = new PromptTemplate({
      templateVars: ['context'],
      template: `For improving the answer to my last question use the following context:
---------------------
{context}
---------------------`,
    });

    // NOTE(review): the `as any` cast is kept from the original; presumably the
    // helper's parameter type does not accept `ContextSystemPrompt` — confirm.
    const content = await createMessageContent(
      contextSystemPrompt as any,
      nodes.map((r) => r.node),
      undefined,
      MetadataMode.LLM,
    );

    return res.status(200).json({ message: extractText(content) });
  } catch (error: unknown) {
    console.error('[context] Error:', error);
    // Anything can be thrown, not only Error instances; narrow before reading
    // `.message` so the response always carries a string message.
    const message = error instanceof Error ? error.message : String(error);
    return res.status(500).json({ message });
  }
}
|
||||
@@ -0,0 +1,40 @@
|
||||
import { VectorStoreIndex } from 'llamaindex';
|
||||
import { storageContextFromDefaults } from 'llamaindex/storage/StorageContext';
|
||||
|
||||
import * as dotenv from 'dotenv';
|
||||
|
||||
import { getDocuments } from './loader';
|
||||
import { initSettings } from './settings';
|
||||
import { STORAGE_CACHE_DIR } from './shared';
|
||||
|
||||
// Load environment variables from local .env file
|
||||
dotenv.config();
|
||||
|
||||
/**
 * Run `func`, wait for it to finish, and report how long it took.
 *
 * @param func - Callback to time; may be synchronous or return a promise.
 * @returns Elapsed wall-clock time in milliseconds, measured with `Date.now()`.
 */
async function getRuntime(func: () => unknown): Promise<number> {
  const start = Date.now();
  await func();
  return Date.now() - start;
}
|
||||
|
||||
/**
 * Build a vector index from the loaded documents using a storage context
 * persisted under `STORAGE_CACHE_DIR`, and log how long the build took.
 */
async function generateDatasource() {
  console.log(`Generating storage context...`);

  // Split documents, create embeddings and store them in the storage context
  const buildIndex = async () => {
    const storageContext = await storageContextFromDefaults({
      persistDir: STORAGE_CACHE_DIR,
    });
    const loadedDocuments = await getDocuments();
    await VectorStoreIndex.fromDocuments(loadedDocuments, { storageContext });
  };

  const elapsedMs = await getRuntime(buildIndex);
  console.log(
    `Storage context successfully generated in ${elapsedMs / 1000}s.`,
  );
}
|
||||
|
||||
// Script entry point: configure the models, then generate the storage.
// `initSettings` is async and rejects on missing env variables, so it is
// awaited here; otherwise the rejection would go unhandled and generation
// would start regardless. Any failure is reported and reflected in the
// process exit code.
(async () => {
  try {
    await initSettings();
    await generateDatasource();
    console.log('Finished generating storage.');
  } catch (error: unknown) {
    console.error('Failed to generate storage:', error);
    process.exitCode = 1;
  }
})();
|
||||
@@ -0,0 +1,19 @@
|
||||
import { SimpleDocumentStore, VectorStoreIndex } from 'llamaindex';
|
||||
import { storageContextFromDefaults } from 'llamaindex/storage/StorageContext';
|
||||
import { STORAGE_CACHE_DIR } from './shared';
|
||||
|
||||
/**
 * Open the storage context persisted under `STORAGE_CACHE_DIR` and wrap it in
 * a `VectorStoreIndex`.
 *
 * @param params - Currently unused.
 * @returns The index, or `null` when the document store holds no entries.
 */
export async function getDataSource(params?: any) {
  const storageContext = await storageContextFromDefaults({
    persistDir: STORAGE_CACHE_DIR,
  });

  const docStore = storageContext.docStore as SimpleDocumentStore;
  const storedDocIds = Object.keys(docStore.toDict());
  if (storedDocIds.length === 0) return null;

  return await VectorStoreIndex.init({ storageContext });
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import {
|
||||
FILE_EXT_TO_READER,
|
||||
SimpleDirectoryReader,
|
||||
} from 'llamaindex/readers/SimpleDirectoryReader';
|
||||
|
||||
export const DATA_DIR = './data';
|
||||
|
||||
/**
 * Expose the `FILE_EXT_TO_READER` table imported from the
 * `SimpleDirectoryReader` module.
 */
export function getExtractors() {
  const readersByExtension = FILE_EXT_TO_READER;
  return readersByExtension;
}
|
||||
|
||||
/**
 * Load documents from `DATA_DIR` with `SimpleDirectoryReader` and tag each one
 * with `private: 'false'` in its metadata.
 *
 * @returns The loaded documents with the updated metadata.
 */
export async function getDocuments() {
  const reader = new SimpleDirectoryReader();
  const loaded = await reader.loadData({ directoryPath: DATA_DIR });

  // Set private=false to mark the document as public (required for filtering)
  loaded.forEach((doc) => {
    doc.metadata = { ...doc.metadata, private: 'false' };
  });

  return loaded;
}
|
||||
@@ -0,0 +1,179 @@
|
||||
import {
|
||||
ALL_AVAILABLE_MISTRAL_MODELS,
|
||||
Anthropic,
|
||||
GEMINI_EMBEDDING_MODEL,
|
||||
GEMINI_MODEL,
|
||||
Gemini,
|
||||
GeminiEmbedding,
|
||||
Groq,
|
||||
MistralAI,
|
||||
MistralAIEmbedding,
|
||||
MistralAIEmbeddingModelType,
|
||||
OpenAI,
|
||||
OpenAIEmbedding,
|
||||
Settings,
|
||||
} from 'llamaindex';
|
||||
import { HuggingFaceEmbedding } from 'llamaindex/embeddings/HuggingFaceEmbedding';
|
||||
import { OllamaEmbedding } from 'llamaindex/embeddings/OllamaEmbedding';
|
||||
import { ALL_AVAILABLE_ANTHROPIC_MODELS } from 'llamaindex/llm/anthropic';
|
||||
import { Ollama } from 'llamaindex/llm/ollama';
|
||||
|
||||
const CHUNK_SIZE = 512;
|
||||
const CHUNK_OVERLAP = 20;
|
||||
|
||||
/**
 * Configure the global `Settings` object (LLM, embedding model, chunk size and
 * chunk overlap) from environment variables.
 *
 * `MODEL_PROVIDER` picks the provider-specific initializer; any value without
 * a dedicated initializer, including an unset variable, falls back to OpenAI.
 *
 * Rejects with an `Error` when `MODEL` or `EMBEDDING_MODEL` is not set.
 */
export const initSettings = async () => {
  // HINT: you can delete the initialization code for unused model providers
  const provider = process.env.MODEL_PROVIDER;
  console.log(`Using '${provider}' model provider`);

  const modelsConfigured =
    Boolean(process.env.MODEL) && Boolean(process.env.EMBEDDING_MODEL);
  if (!modelsConfigured) {
    throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set.");
  }

  // A Map is used so that only the listed provider names can match.
  const providerInitializers = new Map<string | undefined, () => void>([
    ['ollama', initOllama],
    ['groq', initGroq],
    ['anthropic', initAnthropic],
    ['gemini', initGemini],
    ['mistral', initMistralAI],
    ['azure-openai', initAzureOpenAI],
  ]);
  const initProvider = providerInitializers.get(provider) ?? initOpenAI;
  initProvider();

  Settings.chunkSize = CHUNK_SIZE;
  Settings.chunkOverlap = CHUNK_OVERLAP;
};
|
||||
|
||||
/**
 * Point `Settings` at OpenAI's LLM and embedding model.
 *
 * Reads `MODEL` (defaulting to 'gpt-4o-mini'), `LLM_MAX_TOKENS`,
 * `EMBEDDING_MODEL` and `EMBEDDING_DIM`; the two numeric options are passed as
 * `undefined` when their variables are unset or empty.
 */
function initOpenAI() {
  const { MODEL, LLM_MAX_TOKENS, EMBEDDING_MODEL, EMBEDDING_DIM } = process.env;

  const maxTokens = LLM_MAX_TOKENS ? Number(LLM_MAX_TOKENS) : undefined;
  const dimensions = EMBEDDING_DIM ? parseInt(EMBEDDING_DIM) : undefined;

  Settings.llm = new OpenAI({
    model: MODEL ?? 'gpt-4o-mini',
    maxTokens,
  });
  Settings.embedModel = new OpenAIEmbedding({
    model: EMBEDDING_MODEL,
    dimensions,
  });
}
|
||||
|
||||
/**
 * Point `Settings` at an Azure OpenAI deployment for both the LLM and the
 * embedding model.
 *
 * `MODEL` is treated as an Azure model name and translated to the OpenAI model
 * name through a lookup table; unknown names fall back to 'gpt-3.5-turbo'.
 * Connection details come from the `AZURE_OPENAI_*` variables, with
 * `OPENAI_API_VERSION` as a fallback for the API version.
 */
function initAzureOpenAI() {
  const env = process.env;

  // Map Azure OpenAI model names to OpenAI model names (only for TS)
  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
    'gpt-35-turbo': 'gpt-3.5-turbo',
    'gpt-35-turbo-16k': 'gpt-3.5-turbo-16k',
    'gpt-4o': 'gpt-4o',
    'gpt-4': 'gpt-4',
    'gpt-4-32k': 'gpt-4-32k',
    'gpt-4-turbo': 'gpt-4-turbo',
    'gpt-4-turbo-2024-04-09': 'gpt-4-turbo',
    'gpt-4-vision-preview': 'gpt-4-vision-preview',
    'gpt-4-1106-preview': 'gpt-4-1106-preview',
    'gpt-4o-2024-05-13': 'gpt-4o-2024-05-13',
  };

  // Connection options shared by the LLM and the embedding client.
  const sharedAzureOptions = {
    apiKey: env.AZURE_OPENAI_KEY,
    endpoint: env.AZURE_OPENAI_ENDPOINT,
    apiVersion: env.AZURE_OPENAI_API_VERSION || env.OPENAI_API_VERSION,
  };

  const azureModelName = env.MODEL ?? 'gpt-35-turbo';
  const openAiModelName =
    AZURE_OPENAI_MODEL_MAP[azureModelName] ?? 'gpt-3.5-turbo';
  const maxTokens = env.LLM_MAX_TOKENS
    ? Number(env.LLM_MAX_TOKENS)
    : undefined;
  const dimensions = env.EMBEDDING_DIM
    ? parseInt(env.EMBEDDING_DIM)
    : undefined;

  Settings.llm = new OpenAI({
    model: openAiModelName,
    maxTokens,
    azure: {
      ...sharedAzureOptions,
      deployment: env.AZURE_OPENAI_LLM_DEPLOYMENT,
    },
  });

  Settings.embedModel = new OpenAIEmbedding({
    model: env.EMBEDDING_MODEL,
    dimensions,
    azure: {
      ...sharedAzureOptions,
      deployment: env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    },
  });
}
|
||||
|
||||
/**
 * Point `Settings` at an Ollama server for both the LLM and the embedding
 * model. The host comes from `OLLAMA_BASE_URL` and defaults to
 * 'http://127.0.0.1:11434'; model names default to an empty string.
 */
function initOllama() {
  const host = process.env.OLLAMA_BASE_URL ?? 'http://127.0.0.1:11434';
  const config = { host };

  const llmModel = process.env.MODEL ?? '';
  const embeddingModel = process.env.EMBEDDING_MODEL ?? '';

  Settings.llm = new Ollama({ model: llmModel, config });
  Settings.embedModel = new OllamaEmbedding({ model: embeddingModel, config });
}
|
||||
|
||||
/**
 * Use Groq for the LLM and a HuggingFace embedding model.
 *
 * `EMBEDDING_MODEL` is translated to a 'Xenova/…' model id through a lookup
 * table; a name missing from the table yields `undefined` as `modelType`.
 */
function initGroq() {
  const huggingFaceModelIds: Record<string, string> = {
    'all-MiniLM-L6-v2': 'Xenova/all-MiniLM-L6-v2',
    'all-mpnet-base-v2': 'Xenova/all-mpnet-base-v2',
  };
  const embeddingModelId = huggingFaceModelIds[process.env.EMBEDDING_MODEL!];

  Settings.llm = new Groq({ model: process.env.MODEL! });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embeddingModelId,
  });
}
|
||||
|
||||
/**
 * Use Anthropic for the LLM and a HuggingFace embedding model.
 *
 * `EMBEDDING_MODEL` is translated to a 'Xenova/…' model id through a lookup
 * table; a name missing from the table yields `undefined` as `modelType`.
 */
function initAnthropic() {
  const huggingFaceModelIds: Record<string, string> = {
    'all-MiniLM-L6-v2': 'Xenova/all-MiniLM-L6-v2',
    'all-mpnet-base-v2': 'Xenova/all-mpnet-base-v2',
  };
  const embeddingModelId = huggingFaceModelIds[process.env.EMBEDDING_MODEL!];
  const llmModel = process.env
    .MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS;

  Settings.llm = new Anthropic({ model: llmModel });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embeddingModelId,
  });
}
|
||||
|
||||
/**
 * Use Gemini for both the LLM (`MODEL`) and the embedding model
 * (`EMBEDDING_MODEL`). The env values are cast to the library's model types
 * without runtime validation.
 */
function initGemini() {
  const llmModel = process.env.MODEL as GEMINI_MODEL;
  const embeddingModel = process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL;

  Settings.llm = new Gemini({ model: llmModel });
  Settings.embedModel = new GeminiEmbedding({ model: embeddingModel });
}
|
||||
|
||||
/**
 * Use Mistral AI for both the LLM (`MODEL`) and the embedding model
 * (`EMBEDDING_MODEL`). The env values are cast to the library's model types
 * without runtime validation.
 */
function initMistralAI() {
  const llmModel = process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS;
  const embeddingModel = process.env
    .EMBEDDING_MODEL as MistralAIEmbeddingModelType;

  Settings.llm = new MistralAI({ model: llmModel });
  Settings.embedModel = new MistralAIEmbedding({ model: embeddingModel });
}
|
||||
@@ -0,0 +1 @@
|
||||
export const STORAGE_CACHE_DIR = './cache';
|
||||
Reference in New Issue
Block a user