mirror of
https://github.com/cloudstack-llc/mlx-knife.git
synced 2026-07-01 20:44:14 -04:00
7f10187bee
Core: - Run preflight passes probe/framework to audio_runtime_compatibility - STT model_type gate extended (vibevoice, audio) - MLX 0.30.x compat: catch Exception in whisper_tokenizer - Embedding-gate unit tests (3 tests) - Removed get_encoding duplication (-45 LOC) Benchmark: - GPU analysis section in reports
438 lines
16 KiB
Python
438 lines
16 KiB
Python
"""
|
|
Audio runner wrapping mlx-audio for STT transcription (ADR-020).
|
|
|
|
Dedicated AudioRunner for speech-to-text models (Whisper, Voxtral, VibeVoice).
|
|
Multimodal audio models (Gemma-3n, Qwen3-Omni) use VisionRunner instead.
|
|
|
|
Backend routing: config-based detection determines MLX_AUDIO vs MLX_VLM.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
from ..operations.workspace import is_workspace_path
|
|
|
|
|
|
# ============================================================================
|
|
# CRITICAL: Monkey-patch mlx-audio tokenizer BEFORE any imports
|
|
# ============================================================================
|
|
# Workaround for mlx-audio Issue #479: tiktoken assets were removed in commit
|
|
# f7328a4 (Jan 29, 2026), but code still tries to load them.
|
|
#
|
|
# We bundle the assets from commit 9349644 in mlxk2/assets/whisper/ and patch
|
|
# get_encoding() to use them instead.
|
|
#
|
|
# This MUST happen at module import time, before any mlx-audio code runs!
|
|
# ============================================================================
|
|
|
|
def _apply_tiktoken_patch():
|
|
"""Apply tiktoken asset patch globally at module import time.
|
|
|
|
Patches mlx-audio's get_encoding() to use our bundled tiktoken files
|
|
from mlxk2/assets/whisper/ instead of the removed upstream assets.
|
|
|
|
The actual implementation lives in mlxk2.audio.whisper_tokenizer to avoid
|
|
code duplication. This function just wires it up as a monkey-patch.
|
|
"""
|
|
try:
|
|
# Import mlx-audio's tokenizer module (target of the patch)
|
|
import mlx_audio.stt.models.whisper.tokenizer as mlx_whisper_tokenizer
|
|
|
|
# Import our bundled implementation
|
|
from mlxk2.audio.whisper_tokenizer import get_encoding
|
|
|
|
# Patch the module globally
|
|
mlx_whisper_tokenizer.get_encoding = get_encoding
|
|
|
|
except ImportError:
|
|
# mlx-audio not installed - skip patching
|
|
pass
|
|
|
|
|
|
# Apply patch immediately at module import
|
|
_apply_tiktoken_patch()
|
|
|
|
|
|
# ============================================================================
|
|
# CRITICAL: Patch Model.get_tokenizer() to use our tiktoken tokenizer
|
|
# ============================================================================
|
|
# mlx-audio 0.3.1 (PyPI) removed the get_tokenizer() function that creates
|
|
# tiktoken-based tokenizers. The Model.get_tokenizer() method now throws
|
|
# ValueError if no HuggingFace processor is available.
|
|
#
|
|
# We patch Model.get_tokenizer() to fall back to our bundled tokenizer
|
|
# (mlxk2.audio.whisper_tokenizer) when no processor is available.
|
|
#
|
|
# This MUST happen at module import time, before any Whisper model is loaded!
|
|
# ============================================================================
|
|
|
|
def _apply_whisper_tokenizer_patch():
|
|
"""Patch Model.get_tokenizer to fall back to our tiktoken tokenizer."""
|
|
try:
|
|
from mlx_audio.stt.models.whisper.whisper import Model as WhisperModel
|
|
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
|
|
|
# Store original method (if it exists and isn't already patched)
|
|
if hasattr(WhisperModel, "_mlxk_original_get_tokenizer"):
|
|
# Already patched
|
|
return
|
|
|
|
original_get_tokenizer = WhisperModel.get_tokenizer
|
|
|
|
def patched_get_tokenizer(self, language=None, task="transcribe"):
|
|
"""Patched get_tokenizer with tiktoken fallback.
|
|
|
|
First tries the original method (uses HF Processor if available).
|
|
Falls back to our bundled tiktoken-based tokenizer on failure.
|
|
"""
|
|
# Try original first (uses HF Processor if available)
|
|
if hasattr(self, "_processor") and self._processor is not None:
|
|
try:
|
|
return original_get_tokenizer(self, language, task)
|
|
except Exception:
|
|
# HF Processor failed, fall through to tiktoken
|
|
pass
|
|
|
|
# Fallback to our tiktoken-based tokenizer
|
|
return get_tokenizer(
|
|
self.is_multilingual,
|
|
num_languages=getattr(self, "num_languages", 99),
|
|
language=language,
|
|
task=task,
|
|
)
|
|
|
|
# Apply patch
|
|
WhisperModel.get_tokenizer = patched_get_tokenizer
|
|
WhisperModel._mlxk_original_get_tokenizer = original_get_tokenizer
|
|
|
|
except ImportError:
|
|
# mlx-audio not installed - skip patching
|
|
pass
|
|
|
|
|
|
# Apply Whisper tokenizer patch immediately at module import
|
|
_apply_whisper_tokenizer_patch()
|
|
|
|
|
|
class AudioRunner:
    """Wrapper around mlx-audio STT API for dedicated transcription models.

    Supports:
    - Whisper variants (large-v3-turbo, base, small, etc.)
    - Voxtral (mini, small)
    - VibeVoice-ASR

    Usage:
        with AudioRunner(model_path, model_name, verbose) as runner:
            result = runner.transcribe(audio=[("file.wav", audio_bytes)])
    """

    def __init__(self, model_path: Path, model_name: str, verbose: bool = False):
        """Store configuration; nothing is loaded until ``load_model()``.

        Args:
            model_path: Local path of the model (coerced to ``Path``).
            model_name: HF repo_id or workspace path.
            verbose: Forwarded to the transcription call as ``verbose``.
        """
        self.model_path = Path(model_path)
        self.model_name = model_name  # HF repo_id or workspace path
        self.verbose = verbose
        self.model = None
        self.processor = None
        self._generate_fn = None
        self._load_fn = None
        self._temp_files: List[str] = []  # Track created temp files for cleanup

    def __enter__(self):
        """Load the model and return the runner."""
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove leftover temp files; never suppresses the exception."""
        self._cleanup_temp_files()
        return False

    @staticmethod
    def _describe_error(exc: BaseException) -> str:
        """Return ``str(exc)``, or ``"<TypeName> (no details)"`` when it is empty."""
        return str(exc) or f"{type(exc).__name__} (no details)"

    def _cleanup_temp_files(self):
        """Remove all temporary audio files created during transcription."""
        for path in self._temp_files:
            try:
                if os.path.exists(path):
                    os.unlink(path)
            except Exception:
                # Best effort: this runs from ``finally``/``__exit__`` and must
                # not mask the error that is already propagating.
                pass
        self._temp_files.clear()

    def load_model(self):
        """Load the audio model and processor.

        Supports both HF cache models and workspace paths. The
        ``HF_HUB_DISABLE_PROGRESS_BARS`` environment variable is forced to
        ``"1"`` for the duration of the load and restored afterwards.
        """
        # Suppress HF progress bars during loading (pull shows them)
        env_key = "HF_HUB_DISABLE_PROGRESS_BARS"
        previous = os.environ.get(env_key)
        os.environ[env_key] = "1"
        try:
            self._load_model_impl()
        finally:
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous

    def _load_model_impl(self):
        """Internal model loading - called with progress bars suppressed.

        Raises:
            RuntimeError: If mlx-audio cannot be imported, or if loading a
                workspace model fails.
        """
        try:
            # Import mlx-audio STT module (0.3.0 API)
            # Note: tiktoken patch was already applied at module import time
            from mlx_audio.stt import load_model
            from mlx_audio.stt.generate import generate_transcription
        except ImportError as e:
            raise RuntimeError(
                f"Failed to import mlx-audio (audio backend): {e}\n"
                "Install with: pip install mlx-knife[audio]"
            ) from e

        self._generate_fn = generate_transcription
        self._load_fn = load_model

        if not is_workspace_path(self.model_path):
            # HF repo_id - defer loading to transcribe() (high-level API)
            self.model = None
            self.processor = None
            return

        # Workspace path - load model directly
        try:
            self.model = self._load_fn(str(self.model_path))
            self.processor = None  # Processor handled internally
        except Exception as e:
            # Some exceptions have empty messages; _describe_error covers that.
            raise RuntimeError(
                f"Failed to load audio model from workspace: {self._describe_error(e)}"
            ) from e

    def _write_temp_audio(self, filename: str, audio_bytes: bytes) -> str:
        """Write audio bytes to a temporary file.

        mlx-audio expects file paths, not bytes. We write to temp files
        and track them for cleanup.

        Args:
            filename: Original filename (for extension detection)
            audio_bytes: Raw audio data

        Returns:
            Path to temporary file
        """
        suffix = Path(filename).suffix or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Register the path before writing: with delete=False the file
            # already exists on disk, so a failed write must still be cleaned up.
            self._temp_files.append(tmp.name)
            tmp.write(audio_bytes)
        return tmp.name

    def transcribe(
        self,
        audio: Sequence[Tuple[str, bytes]],
        prompt: Optional[str] = None,
        max_tokens: int = 4096,  # Ignored (Whisper generates full transcription)
        temperature: float = 0.0,
        language: Optional[str] = None,
    ) -> str:
        """Transcribe audio files to text.

        Args:
            audio: List of (filename, bytes) tuples for audio files
            prompt: Optional context for transcription (improves domain-specific accuracy)
            max_tokens: Ignored (Whisper generates full transcription automatically)
            temperature: Sampling temperature (0.0 = deterministic, best for accuracy)
            language: Language code (e.g., 'en', 'de'). Auto-detect if None.

        Returns:
            Transcription text. If MLXK2_AUDIO_SEGMENTS=1, includes segment table.
            Results for multiple files are joined with a blank line.

        Raises:
            RuntimeError: If transcribing any file fails. Errors raised while
                writing the temporary files propagate unchanged.
        """
        if not audio:
            return ""

        try:
            # Temp files are written inside the try so that a failure part-way
            # through still removes the files already created.
            audio_paths = [
                self._write_temp_audio(filename, audio_bytes)
                for filename, audio_bytes in audio
            ]

            try:
                transcriptions = [
                    self._transcribe_single(
                        audio_path=audio_path,
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        language=language,
                    )
                    for audio_path in audio_paths
                ]
                # Combine results (blank-line separated for multiple files)
                return "\n\n".join(transcriptions).strip()
            except Exception as e:
                raise RuntimeError(
                    f"mlx-audio transcribe() failed: {self._describe_error(e)}"
                ) from e
        finally:
            # Clean up temp files after transcription
            self._cleanup_temp_files()

    def _transcribe_single(
        self,
        audio_path: str,
        prompt: Optional[str] = None,
        max_tokens: int = 4096,  # Ignored
        temperature: float = 0.0,
        language: Optional[str] = None,
    ) -> str:
        """Transcribe a single audio file.

        Uses generate_transcription() with either pre-loaded model (workspace)
        or model name (HF cache).

        Raises:
            RuntimeError: On any failure, including a model that was never loaded.
        """
        try:
            if self._generate_fn is None:
                raise RuntimeError("Model not loaded. Call load_model() first.")

            if is_workspace_path(self.model_path):
                # Workspace path - use pre-loaded model
                if self.model is None:
                    raise RuntimeError("Model not loaded. Call load_model() first.")
                model_ref = self.model
            else:
                # HF repo_id - pass model name (handles loading internally)
                model_ref = self.model_name

            gen_kwargs = {
                "audio": audio_path,
                "verbose": self.verbose,
                "model": model_ref,
            }

            # Add Whisper generation parameters (via **kwargs → model.generate())
            # These are filtered by generate_transcription() to match model.generate() signature
            if prompt:
                gen_kwargs["initial_prompt"] = prompt
            if temperature is not None:
                gen_kwargs["temperature"] = temperature
            if language:
                gen_kwargs["language"] = language

            # Optimize for batch STT (not streaming)
            # chunk_duration=30.0: Process 30s chunks (Whisper's max context window)
            gen_kwargs["chunk_duration"] = 30.0

            result = self._generate_fn(**gen_kwargs)
            text = self._extract_text(result)

            # Optionally add segment metadata (MLXK2_AUDIO_SEGMENTS=1)
            if os.environ.get("MLXK2_AUDIO_SEGMENTS") == "1":
                segments = self._extract_segments(result)
                if segments:
                    text = self._add_segment_metadata(text, segments)

            return text

        except Exception as e:
            raise RuntimeError(
                f"Transcription failed for {audio_path}: {self._describe_error(e)}"
            ) from e

    def _extract_text(self, result) -> str:
        """Extract transcription text from result object.

        Accepts ``None`` (empty string), a plain string, a dict with a string
        ``"text"`` value, or an object with a string ``text`` attribute.
        Anything else is returned as ``str(result)``.
        """
        if result is None:
            return ""
        if isinstance(result, str):
            return result

        if isinstance(result, dict):
            candidate = result.get("text", "")
            if isinstance(candidate, str):
                return candidate

        candidate = getattr(result, "text", None)
        if isinstance(candidate, str):
            return candidate

        # Fallback: string conversion
        return str(result)

    def _extract_segments(self, result) -> Optional[List[Dict]]:
        """Extract segment data from result (if available).

        Whisper models provide segments with timestamps:
            [{"start": 0.0, "end": 2.34, "text": "..."}, ...]

        VibeVoice-ASR provides speaker diarization:
            [{"start_time": 0.0, "end_time": 2.5, "text": "...", "speaker_id": 0}, ...]

        Returns:
            The segment list when it is a non-empty list whose first element is
            a dict with a ``start`` or ``start_time`` key; otherwise ``None``.
        """
        if result is None:
            return None

        if isinstance(result, dict):
            segments = result.get("segments")
        else:
            segments = getattr(result, "segments", None)

        # Type check first: truthiness of an arbitrary object may be undefined.
        if not isinstance(segments, list) or not segments:
            return None

        # Only the first segment is inspected for the expected keys.
        first = segments[0]
        if isinstance(first, dict) and ("start" in first or "start_time" in first):
            return segments
        return None

    @staticmethod
    def _segment_time(seg: Dict, *keys: str) -> float:
        """Return the first non-None value among ``keys`` as a float, else 0.0."""
        for key in keys:
            value = seg.get(key)
            # ``is not None`` rather than truthiness: 0.0 is a valid timestamp.
            if value is not None:
                return float(value)
        return 0.0

    def _add_segment_metadata(self, text: str, segments: List[Dict]) -> str:
        """Append segment timestamps as a collapsible ``<details>`` markdown table.

        Args:
            text: The transcription text; kept first in the output.
            segments: Segment dicts in Whisper (start/end) or VibeVoice
                (start_time/end_time) format.

        Returns:
            ``text`` followed by a blank line and the segment table.
        """
        count = len(segments)
        lines = [
            "<details>",
            f"<summary>Audio Segments ({count} segment{'s' if count != 1 else ''})</summary>",
            "",
            "| Start | End | Text |",
            "|-------|-----|------|",
        ]

        for seg in segments:
            # Handle both Whisper format (start/end) and VibeVoice format (start_time/end_time)
            start = self._segment_time(seg, "start", "start_time")
            end = self._segment_time(seg, "end", "end_time")

            # A table row must stay on one line; tolerate a missing/None text.
            seg_text = " ".join(str(seg.get("text") or "").strip().splitlines())

            # Escape pipe characters in text
            seg_text = seg_text.replace("|", "\\|")

            lines.append(f"| {start:.2f}s | {end:.2f}s | {seg_text} |")

        lines.append("")
        lines.append("</details>")
        lines.append("")

        # Segments go after the transcription (metadata is supplementary)
        return text + "\n\n" + "\n".join(lines)
|