mirror of
https://github.com/cloudstack-llc/mlx-knife.git
synced 2026-07-01 20:44:14 -04:00
bf7480d042
Major Features: - Audio transcription via mlx-audio backend (Whisper, >10min duration) - OpenAI /v1/audio/transcriptions endpoint - Memory Gate System (Vision: 8GB, Audio: 4GB) - Config-based backend routing (ADR-020) - Benchmark toolchain (memmon/memplot, Schema v0.2.2) Key Fixes: - EuroLLM tokenizer decoding - Vision-model text-only routing regression - Multimodal model context length detection - Memory cleanup bug (mx.metal.clear_cache) - Orphan process bug Test Results: - Unit tests: 647 passed, 11 skipped (Python 3.10-3.12) - wet-umbrella: 171 passed total See CHANGELOG.md for complete details and known issues.
336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""
|
|
Audio runner wrapping mlx-audio for STT transcription (ADR-020).
|
|
|
|
Dedicated AudioRunner for speech-to-text models (Whisper, Voxtral, VibeVoice).
|
|
Multimodal audio models (Gemma-3n, Qwen3-Omni) use VisionRunner instead.
|
|
|
|
Backend routing: config-based detection determines MLX_AUDIO vs MLX_VLM.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
from ..operations.workspace import is_workspace_path
|
|
|
|
|
|
class AudioRunner:
|
|
"""Wrapper around mlx-audio STT API for dedicated transcription models.
|
|
|
|
Supports:
|
|
- Whisper variants (large-v3-turbo, base, small, etc.)
|
|
- Voxtral (mini, small)
|
|
- VibeVoice-ASR
|
|
|
|
Usage:
|
|
with AudioRunner(model_path, model_name, verbose) as runner:
|
|
result = runner.transcribe(audio=[("file.wav", audio_bytes)])
|
|
"""
|
|
|
|
def __init__(self, model_path: Path, model_name: str, verbose: bool = False):
|
|
self.model_path = Path(model_path)
|
|
self.model_name = model_name # HF repo_id or workspace path
|
|
self.verbose = verbose
|
|
self.model = None
|
|
self.processor = None
|
|
self._generate_fn = None
|
|
self._load_fn = None
|
|
self._temp_files: List[str] = [] # Track created temp files for cleanup
|
|
|
|
def __enter__(self):
|
|
self.load_model()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
self._cleanup_temp_files()
|
|
return False
|
|
|
|
def _cleanup_temp_files(self):
|
|
"""Remove all temporary audio files created during transcription."""
|
|
for path in self._temp_files:
|
|
try:
|
|
if os.path.exists(path):
|
|
os.unlink(path)
|
|
except Exception:
|
|
# Ignore cleanup errors (best effort)
|
|
pass
|
|
self._temp_files.clear()
|
|
|
|
def load_model(self):
|
|
"""Load the audio model and processor.
|
|
|
|
Supports both HF cache models and workspace paths.
|
|
"""
|
|
# Suppress HF progress bars during loading (pull shows them)
|
|
prev_pbar = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
|
|
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
|
try:
|
|
self._load_model_impl()
|
|
finally:
|
|
if prev_pbar is None:
|
|
os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
|
|
else:
|
|
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = prev_pbar
|
|
|
|
def _load_model_impl(self):
|
|
"""Internal model loading - called with progress bars suppressed."""
|
|
try:
|
|
# Import mlx-audio STT module (0.3.0 API)
|
|
from mlx_audio.stt import load_model
|
|
from mlx_audio.stt.generate import generate_transcription
|
|
except ImportError as e:
|
|
raise RuntimeError(
|
|
f"Failed to import mlx-audio (audio backend): {e}\n"
|
|
"Install with: pip install mlx-knife[audio]"
|
|
) from e
|
|
|
|
self._generate_fn = generate_transcription
|
|
self._load_fn = load_model
|
|
|
|
# Check if model_path is a workspace directory
|
|
if is_workspace_path(self.model_path):
|
|
# Workspace path - load model directly
|
|
model_ref = str(self.model_path)
|
|
try:
|
|
self.model = self._load_fn(model_ref)
|
|
self.processor = None # Processor handled internally
|
|
except Exception as e:
|
|
# Extract error details (some exceptions have empty messages)
|
|
error_type = type(e).__name__
|
|
error_msg = str(e) if str(e) else f"{error_type} (no details)"
|
|
raise RuntimeError(f"Failed to load audio model from workspace: {error_msg}") from e
|
|
else:
|
|
# HF repo_id - defer loading to transcribe() (high-level API)
|
|
self.model = None
|
|
self.processor = None
|
|
|
|
def _write_temp_audio(self, filename: str, audio_bytes: bytes) -> str:
|
|
"""Write audio bytes to a temporary file.
|
|
|
|
mlx-audio expects file paths, not bytes. We write to temp files
|
|
and track them for cleanup.
|
|
|
|
Args:
|
|
filename: Original filename (for extension detection)
|
|
audio_bytes: Raw audio data
|
|
|
|
Returns:
|
|
Path to temporary file
|
|
"""
|
|
suffix = Path(filename).suffix or ".wav"
|
|
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
|
tmp.write(audio_bytes)
|
|
tmp.flush()
|
|
tmp.close()
|
|
self._temp_files.append(tmp.name)
|
|
return tmp.name
|
|
|
|
def transcribe(
|
|
self,
|
|
audio: Sequence[Tuple[str, bytes]],
|
|
prompt: Optional[str] = None,
|
|
max_tokens: int = 4096, # Ignored (Whisper generates full transcription)
|
|
temperature: float = 0.0,
|
|
language: Optional[str] = None,
|
|
) -> str:
|
|
"""Transcribe audio files to text.
|
|
|
|
Args:
|
|
audio: List of (filename, bytes) tuples for audio files
|
|
prompt: Optional context for transcription (improves domain-specific accuracy)
|
|
max_tokens: Ignored (Whisper generates full transcription automatically)
|
|
temperature: Sampling temperature (0.0 = deterministic, best for accuracy)
|
|
language: Language code (e.g., 'en', 'de'). Auto-detect if None.
|
|
|
|
Returns:
|
|
Transcription text. If MLXK2_AUDIO_SEGMENTS=1, includes segment table.
|
|
"""
|
|
if not audio:
|
|
return ""
|
|
|
|
# Prepare audio file paths
|
|
audio_paths = []
|
|
for filename, audio_bytes in audio:
|
|
path = self._write_temp_audio(filename, audio_bytes)
|
|
audio_paths.append(path)
|
|
|
|
try:
|
|
all_transcriptions = []
|
|
|
|
for audio_path in audio_paths:
|
|
result = self._transcribe_single(
|
|
audio_path=audio_path,
|
|
prompt=prompt,
|
|
max_tokens=max_tokens,
|
|
temperature=temperature,
|
|
language=language,
|
|
)
|
|
all_transcriptions.append(result)
|
|
|
|
# Combine results (newline-separated for multiple files)
|
|
combined = "\n\n".join(all_transcriptions)
|
|
return combined.strip()
|
|
|
|
except Exception as e:
|
|
error_type = type(e).__name__
|
|
error_msg = str(e) if str(e) else f"{error_type} (no details)"
|
|
raise RuntimeError(f"mlx-audio transcribe() failed: {error_msg}") from e
|
|
finally:
|
|
# Clean up temp files after transcription
|
|
self._cleanup_temp_files()
|
|
|
|
def _transcribe_single(
|
|
self,
|
|
audio_path: str,
|
|
prompt: Optional[str] = None,
|
|
max_tokens: int = 4096, # Ignored
|
|
temperature: float = 0.0,
|
|
language: Optional[str] = None,
|
|
) -> str:
|
|
"""Transcribe a single audio file.
|
|
|
|
Uses generate_transcription() with either pre-loaded model (workspace)
|
|
or model name (HF cache).
|
|
"""
|
|
try:
|
|
# Build kwargs for generate_transcription
|
|
gen_kwargs = {
|
|
"audio": audio_path,
|
|
"verbose": self.verbose,
|
|
}
|
|
|
|
if is_workspace_path(self.model_path):
|
|
# Workspace path - use pre-loaded model
|
|
if self.model is None:
|
|
raise RuntimeError("Model not loaded. Call load_model() first.")
|
|
gen_kwargs["model"] = self.model
|
|
else:
|
|
# HF repo_id - pass model name (handles loading internally)
|
|
gen_kwargs["model"] = self.model_name
|
|
|
|
# Add Whisper generation parameters (via **kwargs → model.generate())
|
|
# These are filtered by generate_transcription() to match model.generate() signature
|
|
if prompt:
|
|
gen_kwargs["initial_prompt"] = prompt
|
|
if temperature is not None:
|
|
gen_kwargs["temperature"] = temperature
|
|
if language:
|
|
gen_kwargs["language"] = language
|
|
|
|
# Optimize for batch STT (not streaming)
|
|
# chunk_duration=30.0: Process 30s chunks (Whisper's max context window)
|
|
# For 60min podcasts, this provides best accuracy vs latency balance
|
|
gen_kwargs["chunk_duration"] = 30.0
|
|
|
|
# Call generate_transcription
|
|
result = self._generate_fn(**gen_kwargs)
|
|
|
|
# Extract transcription text
|
|
text = self._extract_text(result)
|
|
|
|
# Optionally add segment metadata (MLXK2_AUDIO_SEGMENTS=1)
|
|
if os.environ.get("MLXK2_AUDIO_SEGMENTS") == "1":
|
|
segments = self._extract_segments(result)
|
|
if segments:
|
|
text = self._add_segment_metadata(text, segments)
|
|
|
|
return text
|
|
|
|
except Exception as e:
|
|
error_type = type(e).__name__
|
|
error_msg = str(e) if str(e) else f"{error_type} (no details)"
|
|
raise RuntimeError(f"Transcription failed for {audio_path}: {error_msg}") from e
|
|
|
|
def _extract_text(self, result) -> str:
|
|
"""Extract transcription text from result object.
|
|
|
|
mlx-audio returns various formats depending on model/version.
|
|
"""
|
|
if result is None:
|
|
return ""
|
|
|
|
# String result
|
|
if isinstance(result, str):
|
|
return result
|
|
|
|
# Dict with 'text' key
|
|
if isinstance(result, dict):
|
|
text = result.get("text", "")
|
|
if isinstance(text, str):
|
|
return text
|
|
|
|
# Object with 'text' attribute
|
|
if hasattr(result, "text"):
|
|
text = result.text
|
|
if isinstance(text, str):
|
|
return text
|
|
|
|
# Fallback: string conversion
|
|
return str(result)
|
|
|
|
def _extract_segments(self, result) -> Optional[List[Dict]]:
|
|
"""Extract segment data from result (if available).
|
|
|
|
Whisper models provide segments with timestamps:
|
|
[{"start": 0.0, "end": 2.34, "text": "..."}, ...]
|
|
|
|
VibeVoice-ASR provides speaker diarization:
|
|
[{"start_time": 0.0, "end_time": 2.5, "text": "...", "speaker_id": 0}, ...]
|
|
"""
|
|
if result is None:
|
|
return None
|
|
|
|
segments = None
|
|
|
|
# Dict with 'segments' key
|
|
if isinstance(result, dict):
|
|
segments = result.get("segments")
|
|
|
|
# Object with 'segments' attribute
|
|
elif hasattr(result, "segments"):
|
|
segments = result.segments
|
|
|
|
# Validate segments format
|
|
if segments and isinstance(segments, list) and len(segments) > 0:
|
|
# Check if first segment has expected keys
|
|
first = segments[0]
|
|
if isinstance(first, dict) and ("start" in first or "start_time" in first):
|
|
return segments
|
|
|
|
return None
|
|
|
|
def _add_segment_metadata(self, text: str, segments: List[Dict]) -> str:
|
|
"""Add segment timestamps as collapsible HTML table.
|
|
|
|
Format matches VisionRunner's image metadata table (collapsible).
|
|
"""
|
|
count = len(segments)
|
|
lines = [
|
|
"<details>",
|
|
f"<summary>Audio Segments ({count} segment{'s' if count != 1 else ''})</summary>",
|
|
"",
|
|
"| Start | End | Text |",
|
|
"|-------|-----|------|",
|
|
]
|
|
|
|
for seg in segments:
|
|
# Handle both Whisper format (start/end) and VibeVoice format (start_time/end_time)
|
|
start = seg.get("start") or seg.get("start_time", 0)
|
|
end = seg.get("end") or seg.get("end_time", 0)
|
|
seg_text = seg.get("text", "").strip()
|
|
|
|
# Escape pipe characters in text
|
|
seg_text = seg_text.replace("|", "\\|")
|
|
|
|
lines.append(f"| {start:.2f}s | {end:.2f}s | {seg_text} |")
|
|
|
|
lines.append("")
|
|
lines.append("</details>")
|
|
lines.append("")
|
|
|
|
# Segments go after the transcription (metadata is supplementary)
|
|
return text + "\n\n" + "\n".join(lines)
|