mirror of
https://github.com/cloudstack-llc/mlx-knife.git
synced 2026-07-01 20:44:14 -04:00
fix: P0 bugfixes + test infrastructure + benchmark metadata sync
P0 Bugfixes: - cache.py: Handle empty HF_HOME strings in get_current_cache_root() - clone.py: Remove obsolete _validate_same_volume() check - common.py: Use importlib.metadata instead of importing transformers Test Infrastructure: - runner/__init__.py: Replace "mock" fallback with clear RuntimeError - Fix mock paths in test_runner_core, test_token_limits, etc. - Add VISION_TEST_MODELS + AUDIO_TEST_MODELS fallbacks - Portfolio fixtures work with and without HF_HOME Benchmark Fixes: - Sort models/tests alphabetically instead of by regression % - Fix vision metadata drift: pixtral-12b-8bit → pixtral-12b-4bit Documentation: - ADR-022: Workspace-First Paradigm (draft) - ADR-018: Phase 2 details expanded - TESTING.md/TESTING-DETAILS.md: Fallback docs updated
This commit is contained in:
+39
-56
@@ -67,67 +67,27 @@ Total: 171 passed across all phases
|
||||
| Live list | `pytest -m live_list -v` | `live_list` (subset of `wet`) + Env: `HF_HOME` (user cache with models) | Tests list/health against user cache models | No (uses local cache) |
|
||||
| Clone offline | `pytest -k clone -v` | — | Clone offline tests (APFS validation, temp cache, CoW workflow); no network needed | No |
|
||||
| Live clone (ADR-007) | `pytest -m live_clone -v` | `live_clone` + Env: `MLXK2_LIVE_CLONE=1`, `HF_TOKEN`, `MLXK2_LIVE_CLONE_MODEL`, `MLXK2_LIVE_CLONE_WORKSPACE` | Real clone workflow: pull→temp cache→APFS same-volume clone→workspace (ADR-007 Phase 1 constraints: same volume + APFS required) | Yes |
|
||||
| Live stop tokens (ADR-009) | `pytest -m live_stop_tokens -v` | `live_stop_tokens` (required); Optional: `HF_HOME` (enables portfolio discovery) | Issue #32: Validates stop token behavior with real models. **With HF_HOME:** Portfolio Discovery auto-discovers all MLX chat models (filter: MLX+healthy+runtime+chat), RAM-aware skip, empirical report. **Without HF_HOME:** Uses 3 predefined models (see "Optional Setup" section for model requirements). | No (uses local cache) |
|
||||
| Live run | `pytest -m live_run -v` | `live_run` + Env: `MLXK2_USER_HF_HOME` or `HF_HOME` (user cache with `mlx-community/Phi-3-mini-4k-instruct-4bit`) | Regression tests for Issue #37: Validates private/org MLX model framework detection in run command (renames Phi-3 to simulate private-org model) | No (uses local cache) |
|
||||
| Live E2E (ADR-011) | `HF_HOME=/path/to/cache pytest -m live_e2e -v` | `live_e2e` (required) + Env: `HF_HOME` (optional, enables Portfolio Discovery); Requires: `httpx` installed | **✅ Working:** Server/HTTP/CLI validation with real models. Portfolio Discovery auto-discovers all MLX chat models via `mlxk list --json` (filter: MLX+healthy+runtime+chat), parametrized tests (one server per model), RAM-aware skip. | No (uses local cache) |
|
||||
| Vision CLI E2E (ADR-012) | `HF_HOME=/path/to/cache pytest -m live_e2e tests_2.0/live/test_vision_e2e_live.py -v` | `live_e2e` (required) + Env: `HF_HOME` (vision model in cache, e.g., pixtral-12b-8bit or Llama-3.2-Vision); Requires: `mlx-vlm` installed (Python 3.10+) | **✅ Working:** Deterministic vision queries validate actual image understanding (not hallucination). Tests: chess position reading (e6=black king), OCR text extraction (contract name), color recognition (blue mug), chart label reading (Y-axis), large image support (2.7MB). | No (uses local cache) |
|
||||
| Vision Server E2E (ADR-012 Phase 3) | `HF_HOME=/path/to/cache pytest -m live_e2e tests_2.0/live/test_vision_server_e2e.py -v` | `live_e2e` (required) + Env: `HF_HOME` (vision model in cache); Requires: `mlx-vlm` installed (Python 3.10+), `httpx` | **✅ Working:** Vision API over HTTP. Tests: Base64 image chat completion, streaming graceful degradation (SSE emulation), text request on vision model server. | No (uses local cache) |
|
||||
| Audio CLI E2E (ADR-020) | `HF_HOME=/path/to/cache pytest -m live_e2e tests_2.0/live/test_audio_e2e_live.py -v` | `live_e2e` (required) + Env: `HF_HOME` (audio model in cache, e.g., whisper-large-v3-turbo-4bit); Requires: `mlx-audio` installed (Python 3.10+) | **✅ Working:** Audio transcription with Whisper models (mlx-audio backend). Portfolio Discovery auto-discovers audio-capable models (`model_type: audio`). Tests: WAV/MP3 transcription, Server `/v1/audio/transcriptions` endpoint. **Note:** Gemma-3n requires workspace repair (not in portfolio). | No (uses local cache) |
|
||||
| Live stop tokens (ADR-009) | `pytest -m live_stop_tokens -v` | `live_stop_tokens`; Optional: `HF_HOME` | Issue #32: Stop token behavior. Uses Portfolio Discovery or fallback models (see below). | No |
|
||||
| Live run | `pytest -m live_run -v` | `live_run` + `HF_HOME` (needs Phi-3-mini) | Issue #37: Private/org MLX model framework detection. | No |
|
||||
| Live E2E (ADR-011) | `pytest -m live_e2e -v` | `live_e2e`; Optional: `HF_HOME`; Requires: `httpx` | Server/HTTP/CLI validation. Uses Portfolio Discovery or fallback models. | No |
|
||||
| Vision E2E (ADR-012) | `pytest -m live_e2e tests_2.0/live/test_vision*.py -v` | `live_e2e`; Optional: `HF_HOME`; Requires: `mlx-vlm` | Vision CLI + Server. Uses Portfolio Discovery or `pixtral-12b-4bit` fallback. | No |
|
||||
| Audio E2E (ADR-020) | `pytest -m live_e2e tests_2.0/live/test_audio*.py -v` | `live_e2e`; Optional: `HF_HOME`; Requires: `mlx-audio` | Audio transcription + Server. Uses Portfolio Discovery or `whisper` fallback. | No |
|
||||
| Resumable Pull | `MLXK2_TEST_RESUMABLE_DOWNLOAD=1 pytest -m live_pull tests_2.0/test_resumable_pull.py -v` | `live_pull` (required) + Env: `MLXK2_TEST_RESUMABLE_DOWNLOAD=1` (opt-in for network test) | **✅ Working:** Real network download with controlled interruption (45s timer). Tests unhealthy detection → `requires_confirmation` status → resume with `force_resume=True` → final health check. Validates resumable pull feature (interrupted downloads can be resumed). Uses isolated cache (no impact on user cache). | Yes (HuggingFace download) |
|
||||
| Show E2E portfolios | `HF_HOME=/path/to/cache python tests_2.0/show_portfolios.py` OR `pytest -m show_model_portfolio -s` | Env: `HF_HOME` | Displays TEXT and VISION portfolios separately. Shows model keys (text_XX, vision_XX), RAM requirements, and test/skip status. Diagnostic tool for understanding portfolio separation. Use script for detailed output, or pytest marker for quick check. | No (uses local cache) |
|
||||
| Manual debug mode | `mlxk run <model> "test prompt" --verbose` | Manual CLI usage with `--verbose` flag | Shows token generation details including multiple EOS token warnings. Use this for manual debugging of model quality issues. Output includes `[DEBUG] Token generation analysis` and `⚠️ WARNING: Multiple EOS tokens detected` for broken models. | No (uses local cache) |
|
||||
| Issue #27 real-model | `pytest -m issue27 tests_2.0/test_issue_27.py -v` | Marker: `issue27`; Env (required): `MLXK2_USER_HF_HOME` or `HF_HOME` (user cache, read-only). Env (optional): `MLXK2_ISSUE27_MODEL`, `MLXK2_ISSUE27_INDEX_MODEL`, `MLXK2_SUBSET_COUNT=0`. | Copies real models from user cache into isolated test cache; validates strict health policy on index-based models (no network) | No (uses local cache) |
|
||||
| Server tests | `pytest -k server -v` | — | Basic server API tests (minimal, uses MLX stubs) | No |
|
||||
|
||||
**Useful commands:**
|
||||
**Quick reference (not in table above):**
|
||||
```bash
|
||||
# Only Spec
|
||||
pytest -m spec -v
|
||||
# All live tests (umbrella marker)
|
||||
pytest -m wet -v
|
||||
|
||||
# Push tests (offline)
|
||||
pytest -k "push and not live" -v
|
||||
# Show which models will be tested
|
||||
pytest -m live_e2e --collect-only -q
|
||||
|
||||
# Clone tests (offline)
|
||||
pytest -k "clone and not live" -v
|
||||
|
||||
# Exclude Spec
|
||||
pytest -m "not spec" -v
|
||||
|
||||
# Live Push only
|
||||
MLXK2_LIVE_PUSH=1 HF_TOKEN=... MLXK2_LIVE_REPO=... MLXK2_LIVE_WORKSPACE=... pytest -m live_push -v
|
||||
|
||||
# Live Clone only
|
||||
MLXK2_LIVE_CLONE=1 HF_TOKEN=... MLXK2_LIVE_CLONE_MODEL=... MLXK2_LIVE_CLONE_WORKSPACE=... pytest -m live_clone -v
|
||||
|
||||
# Live List only
|
||||
HF_HOME=/path/to/user/cache pytest -m live_list -v
|
||||
|
||||
# Live Stop Tokens only (ADR-009)
|
||||
pytest -m live_stop_tokens -v # Optional: HF_HOME=/path/to/cache for portfolio discovery
|
||||
|
||||
# Live Run only
|
||||
HF_HOME=/path/to/user/cache pytest -m live_run -v
|
||||
|
||||
# Live E2E only (ADR-011)
|
||||
HF_HOME=/path/to/user/cache pytest -m live_e2e -v # See model list: pytest tests_2.0/live/test_server_e2e.py::TestChatCompletionsBatch --collect-only -q
|
||||
|
||||
# Resumable Pull only (separate run - uses isolated cache)
|
||||
MLXK2_TEST_RESUMABLE_DOWNLOAD=1 pytest -m live_pull tests_2.0/test_resumable_pull.py -v
|
||||
|
||||
# Empirical Mapping only (model benchmarking - excluded from wet due to RAM)
|
||||
# Empirical Mapping (heavy, excluded from wet)
|
||||
pytest -m live_stop_tokens tests_2.0/test_stop_tokens_live.py::TestStopTokensEmpiricalMapping -v
|
||||
|
||||
# Issue #27 only
|
||||
MLXK2_USER_HF_HOME=/path/to/user/cache pytest -m issue27 tests_2.0/test_issue_27.py -v
|
||||
|
||||
# All live tests (umbrella)
|
||||
MLXK2_ENABLE_ALPHA_FEATURES=1 pytest -m wet -v
|
||||
|
||||
# Vision→Geo pipe only (ADR-012 Phase 1c + Pipe integration)
|
||||
MLXK2_ENABLE_PIPES=1 pytest -m live_vision_pipe -v
|
||||
|
||||
# With custom batch size (optional)
|
||||
MLXK2_ENABLE_PIPES=1 MLXK2_VISION_BATCH_SIZE=3 pytest -m live_vision_pipe -v
|
||||
```
|
||||
|
||||
---
|
||||
@@ -1175,15 +1135,38 @@ pytest -m live_stop_tokens -v # → Runs if models present, else fails
|
||||
- ✅ **Portfolio Discovery:** Uses `mlxk list --json` to discover all qualifying models (refactored: production command, ~70 LOC eliminated)
|
||||
- ✅ **RAM-Aware:** Progressive budgets prevent OOM (40%-70% of system RAM)
|
||||
- ✅ **Empirical Report:** Generates `stop_token_config_report.json` with findings
|
||||
- ✅ **Fallback:** Uses 3 predefined models (MXFP4, Qwen, Llama) if HF_HOME not set - models must exist in HF cache
|
||||
- ✅ **Fallback:** Uses predefined models when no qualifying models discovered (regardless of HF_HOME setting)
|
||||
|
||||
**Required Models for Live Tests:**
|
||||
|
||||
Live tests use **either** Portfolio Discovery **or** these fallback models:
|
||||
|
||||
| Scenario | Models tested |
|
||||
|----------|---------------|
|
||||
| Portfolio Discovery finds models | Only discovered models (dynamic) |
|
||||
| Portfolio Discovery finds nothing | Only fallback models (this list) |
|
||||
|
||||
**Fallback models** (only needed when Discovery finds nothing — any qualifying MLX model in cache replaces these):
|
||||
|
||||
| Type | Model | RAM | Fallback for |
|
||||
|------|-------|-----|--------------|
|
||||
| Text | `mlx-community/gpt-oss-20b-MXFP4-Q8` | ~12 GB | Text tests |
|
||||
| Text | `mlx-community/Qwen2.5-0.5B-Instruct-4bit` | ~1 GB | Text tests |
|
||||
| Text | `mlx-community/Llama-3.2-3B-Instruct-4bit` | ~4 GB | Text tests |
|
||||
| Vision | `mlx-community/pixtral-12b-4bit` | ~7 GB | Vision tests (or any vision model) |
|
||||
| Audio | `mlx-community/whisper-large-v3-turbo-4bit` | ~1.5 GB | Audio tests (or any audio model) |
|
||||
|
||||
**Required models for fallback (without HF_HOME):**
|
||||
```bash
|
||||
mlxk pull mlx-community/gpt-oss-20b-MXFP4-Q8 # ~12GB RAM
|
||||
mlxk pull mlx-community/Qwen2.5-0.5B-Instruct-4bit # ~1GB RAM
|
||||
mlxk pull mlx-community/Llama-3.2-3B-Instruct-4bit # ~4GB RAM
|
||||
# Pull all minimum required models (~25 GB total)
|
||||
mlxk pull mlx-community/gpt-oss-20b-MXFP4-Q8
|
||||
mlxk pull mlx-community/Qwen2.5-0.5B-Instruct-4bit
|
||||
mlxk pull mlx-community/Llama-3.2-3B-Instruct-4bit
|
||||
mlxk pull mlx-community/pixtral-12b-4bit
|
||||
mlxk pull mlx-community/whisper-large-v3-turbo-4bit
|
||||
```
|
||||
|
||||
**Note:** These models are defined in `tests_2.0/live/test_utils.py` (`TEST_MODELS`, `VISION_TEST_MODELS`, `AUDIO_TEST_MODELS`) and `tests_2.0/test_stop_tokens_live.py` (`TEST_MODELS`).
|
||||
|
||||
### E2E Tests with Portfolio Separation (ADR-011 + Portfolio Separation)
|
||||
|
||||
**Status:** ✅ Working (Portfolio Separation complete)
|
||||
|
||||
+2
-5
@@ -271,12 +271,9 @@ HF_HOME=/path/to/cache pytest -m live_e2e -v
|
||||
|
||||
**Stop token validation** (ADR-009):
|
||||
```bash
|
||||
# Option A: Portfolio Discovery (recommended)
|
||||
export HF_HOME=/path/to/cache
|
||||
pytest -m live_stop_tokens -v
|
||||
|
||||
# Option B: Hardcoded models (requires 3 specific models in cache)
|
||||
# See TESTING-DETAILS.md for model list
|
||||
# Uses Portfolio Discovery if models found, else fallback models
|
||||
# See TESTING-DETAILS.md "Required Models for Live Tests"
|
||||
```
|
||||
|
||||
**Push/Clone tests** (alpha features):
|
||||
|
||||
@@ -597,20 +597,13 @@ Quality Flags (Thresholds: RAM <5 GB free, zombies >0):
|
||||
|
||||
"""
|
||||
|
||||
# Sort models by total time (descending), or by change if comparing
|
||||
sorted_models = sorted(stats['models'].values(), key=lambda m: m['total_time'], reverse=True)
|
||||
# Sort models alphabetically (stable ordering across reports)
|
||||
sorted_models = sorted(stats['models'].values(), key=lambda m: m['id'].lower())
|
||||
|
||||
# Build comparison lookup if available
|
||||
compare_models = {}
|
||||
if compare_stats:
|
||||
compare_models = {m['id']: m for m in compare_stats['models'].values()}
|
||||
# Re-sort by change percentage (biggest regression first)
|
||||
def get_change_pct(model):
|
||||
old = compare_models.get(model['id'])
|
||||
if old and old['total_time'] > 0:
|
||||
return (model['total_time'] - old['total_time']) / old['total_time'] * 100
|
||||
return 0
|
||||
sorted_models = sorted(stats['models'].values(), key=get_change_pct, reverse=True)
|
||||
|
||||
if compare_stats:
|
||||
md += f"""```
|
||||
@@ -889,8 +882,8 @@ Quality Flags (Thresholds: RAM <5 GB free, zombies >0):
|
||||
md += "## Per-Test Statistics\n\n"
|
||||
md += "Shows performance range across models for each test.\n\n"
|
||||
|
||||
# Sort tests by model count (descending) - most representative tests first
|
||||
sorted_tests = sorted(stats['tests'].values(), key=lambda t: t['model_count'], reverse=True)
|
||||
# Sort tests alphabetically (stable ordering across reports)
|
||||
sorted_tests = sorted(stats['tests'].values(), key=lambda t: t['name'].lower())
|
||||
|
||||
# Build comparison lookup for tests (key: (name, modality))
|
||||
compare_tests = {}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# ADR-018: Convert Operation
|
||||
|
||||
**Status:** Implemented (Phases 0a-0c + 1 complete in 2.0.4-beta.6)
|
||||
**Status:** Implemented (Phases 0a-0c + 1 complete in 2.0.4-beta.6), Phase 2 planned for 2.0.5
|
||||
**Created:** 2025-12-18
|
||||
**Updated:** 2026-02-01 (Added: Known Model Defects & Repair Strategies survey)
|
||||
**Updated:** 2026-02-08 (Added: Phase 2 details for 2.0.5)
|
||||
**Context:** Users need to (a) quantize MLX workspaces locally without polluting the HF cache and (b) repair MLX/HF compliance issues (notably safetensors index/shard mismatches) in a deterministic way.
|
||||
|
||||
**Phase Status:**
|
||||
@@ -10,11 +10,12 @@
|
||||
- **Phase 0b:** Resumable clone — ✅ Implemented (2.0.4-beta.6)
|
||||
- **Phase 0c:** Workspace run/show/server support — ✅ Implemented (2.0.4-beta.6)
|
||||
- **Phase 1:** `--repair-index` — ✅ Implemented (2.0.4-beta.5)
|
||||
- **Phase 2:** `--quantize` + content_hash — 🚧 Planned (2.0.5)
|
||||
|
||||
**Feature Gates (2.0.4-beta.7+):**
|
||||
**Feature Gates:**
|
||||
- `clone`, `push`: **Production** (no gate required)
|
||||
- `convert`: **Experimental** (requires `MLXK2_ENABLE_ALPHA_FEATURES=1`)
|
||||
- Rationale: `--quantize` not yet implemented, only `--repair-index` available
|
||||
- `convert`: **Experimental** until 2.0.5 (requires `MLXK2_ENABLE_ALPHA_FEATURES=1`)
|
||||
- Gate removed in 2.0.5 when `--quantize` ships
|
||||
|
||||
**Note:** Complete workspace infrastructure shipped in 2.0.4-beta.6. Full `clone → convert → run/show/server` workflow with resume support, no HF push requirement.
|
||||
|
||||
@@ -681,9 +682,138 @@ mlxk convert ./ws ./ws-fixed --repair-all # Apply all safe repairs
|
||||
- **Files:** `mlxk2/operations/convert.py` (NEW), `cli.py` (convert subparser), `output/human.py` (render_convert)
|
||||
- **Tests:** 11 new tests, all passing
|
||||
|
||||
- [ ] **Phase 2 (future):** `--quantize <bits>` for text models (mlx-lm)
|
||||
- [ ] **Phase 2 (2.0.5):** `--quantize <bits>` for text models + content_hash
|
||||
- [ ] **Phase 3 (future):** Mixed recipes / advanced quant options
|
||||
- [ ] **Phase 4 (future):** Vision model support (mlx-vlm) once stable and dependency policy allows
|
||||
- [ ] **Phase 4 (2.0.5 or 2.0.6):** Vision quantize (mlx-vlm) — timing depends on upstream stability and 2.0.6 focus
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Quantize & Content Hash (2.0.5)
|
||||
|
||||
**Goal:** Production-ready convert with quantization and workspace integrity tracking.
|
||||
|
||||
#### 2a: Remove ALPHA Gate
|
||||
|
||||
**Current:** `convert` requires `MLXK2_ENABLE_ALPHA_FEATURES=1`
|
||||
|
||||
**Change:** Remove gate — `--repair-index` has proven stable since 2.0.4-beta.5
|
||||
|
||||
**Files:**
|
||||
- `mlxk2/operations/convert.py` — Remove alpha check
|
||||
- `mlxk2/cli.py` — Remove gate from convert subparser help
|
||||
|
||||
#### 2b: Quantize Implementation
|
||||
|
||||
**CLI:**
|
||||
```bash
|
||||
mlxk convert ./ws-bf16 ./ws-4bit --quantize 4
|
||||
mlxk convert ./ws-bf16 ./ws-4bit --quantize 4 --q-group-size 128
|
||||
```
|
||||
|
||||
**Implementation:**
|
||||
```python
|
||||
# mlxk2/operations/convert.py
|
||||
|
||||
def _quantize_text_model(source: Path, target: Path, bits: int, group_size: int = 64):
|
||||
"""Quantize text model using mlx-lm."""
|
||||
from mlx_lm import convert as mlx_lm_convert
|
||||
|
||||
# mlx-lm expects these parameters
|
||||
mlx_lm_convert(
|
||||
hf_path=str(source),
|
||||
mlx_path=str(target),
|
||||
quantize=True,
|
||||
q_bits=bits,
|
||||
q_group_size=group_size,
|
||||
)
|
||||
|
||||
# Always rebuild index for consistency (safety measure)
|
||||
rebuild_safetensors_index(target)
|
||||
```
|
||||
|
||||
**Supported bit depths:** 2, 3, 4, 6, 8 (same as mlx-lm)
|
||||
|
||||
**Files:**
|
||||
- `mlxk2/operations/convert.py` — `_quantize_text_model()`, CLI integration
|
||||
- Tests: 5-8 new tests
|
||||
|
||||
#### 2c: Content Hash
|
||||
|
||||
**Purpose:** Detect modifications after clone/convert for integrity tracking.
|
||||
|
||||
**Algorithm:** (from ADR-022)
|
||||
```python
|
||||
HASH_EXCLUDE = [
|
||||
".mlxk_workspace.json", # contains the hash itself
|
||||
".hf_cache/", # runtime artifacts
|
||||
".DS_Store",
|
||||
".git/",
|
||||
"__pycache__/",
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
]
|
||||
|
||||
def compute_workspace_hash(workspace_path: Path) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
for file in sorted(workspace_path.rglob("*")):
|
||||
if should_exclude(file):
|
||||
continue
|
||||
if file.is_file():
|
||||
hasher.update(file.relative_to(workspace_path).encode())
|
||||
hasher.update(file.read_bytes())
|
||||
return f"sha256:{hasher.hexdigest()}"
|
||||
```
|
||||
|
||||
**When computed:**
|
||||
- After `clone` (before declaring success)
|
||||
- After `convert` (before declaring success)
|
||||
|
||||
**Stored in:** `.mlxk_workspace.json`
|
||||
|
||||
#### 2d: Extended Sentinel Schema
|
||||
|
||||
```json
|
||||
{
|
||||
"mlxk_version": "2.0.5",
|
||||
"created_at": "2026-02-08T10:30:00Z",
|
||||
"source_repo": "mlx-community/whisper-large-v3-mlx",
|
||||
"source_revision": "abc123def456",
|
||||
"managed": true,
|
||||
"operation": "convert",
|
||||
"content_hash": "sha256:a1b2c3d4e5f6...",
|
||||
"hash_computed_at": "2026-02-08T10:30:05Z",
|
||||
"hash_excludes": [".mlxk_workspace.json", ".hf_cache/"],
|
||||
"convert_options": {
|
||||
"mode": "quantize",
|
||||
"bits": 4,
|
||||
"group_size": 64
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**New fields:**
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `content_hash` | `string` | SHA256 of workspace content |
|
||||
| `hash_computed_at` | `string` | ISO timestamp |
|
||||
| `hash_excludes` | `string[]` | Patterns excluded from hash |
|
||||
| `convert_options` | `object` | Quantization parameters (if convert) |
|
||||
|
||||
**Files:**
|
||||
- `mlxk2/operations/workspace.py` — `compute_workspace_hash()`, extended sentinel
|
||||
- `mlxk2/operations/clone.py` — Hash after clone
|
||||
- `mlxk2/operations/convert.py` — Hash after convert
|
||||
- Tests: 5-8 new tests
|
||||
|
||||
#### Phase 2 Effort
|
||||
|
||||
| Component | LOC | Tests |
|
||||
|-----------|-----|-------|
|
||||
| ALPHA gate removal | ~10 | 1 |
|
||||
| Quantize | ~80 | 5-8 |
|
||||
| Content hash | ~50 | 5-8 |
|
||||
| Sentinel extension | ~30 | 3-5 |
|
||||
| **Total** | ~170 | 14-22 |
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
# ADR-022: Workspace-First Paradigm
|
||||
|
||||
**Status:** Draft (Discussion)
|
||||
**Created:** 2026-02-06
|
||||
**Related:** ADR-018 (Convert Operation), SECURITY.md
|
||||
**Target:** 2.0.5
|
||||
|
||||
---
|
||||
|
||||
## Context
|
||||
|
||||
### The HuggingFace Cache Problem
|
||||
|
||||
The HF cache (`$HF_HOME/hub/`) is a **shared mutable namespace** used by multiple uncoordinated actors:
|
||||
|
||||
```
|
||||
$HF_HOME/hub/
|
||||
├── models--mlx-community--whisper-large-v3-mlx/ ← mlx-knife pull
|
||||
├── models--Qwen--Qwen2.5-7B/ ← mlx-audio runtime (!!)
|
||||
└── ...
|
||||
```
|
||||
|
||||
**Actors writing to the cache:**
|
||||
- `transformers` (AutoTokenizer, AutoModel)
|
||||
- `mlx-lm` (model loading)
|
||||
- `mlx-vlm` (vision model loading)
|
||||
- `mlx-audio` (audio model loading, **including undeclared dependencies**)
|
||||
- `huggingface_hub` (downloads)
|
||||
|
||||
**This creates classic shared-state problems:**
|
||||
|
||||
| Problem | Description | Example |
|
||||
|---------|-------------|---------|
|
||||
| Undeclared dependencies | Runtime downloads not visible at pull time | VibeVoice needs Qwen2.5-7B tokenizer |
|
||||
| Write pollution | Upstream libs modify cache during inference | mlx-audio downloads during `run` |
|
||||
| No isolation | All libs see and write same namespace | Cross-model interference possible |
|
||||
| Implicit state | "Works after first run" syndrome | Cache state determines behavior |
|
||||
|
||||
### The Broken Promise
|
||||
|
||||
SECURITY.md currently states:
|
||||
> "Network activity is limited to explicit interactions with Hugging Face: downloading models (pull)"
|
||||
|
||||
This promise is **broken** when upstream libraries download during `run`:
|
||||
|
||||
```bash
|
||||
mlxk pull VibeVoice-ASR-4bit # ✓ Model downloaded
|
||||
# Network disabled
|
||||
mlxk run VibeVoice --audio x.wav # ✗ Fails - needs Qwen2.5-7B
|
||||
```
|
||||
|
||||
### What mlx-knife Controls
|
||||
|
||||
| Layer | Control | Can Guarantee |
|
||||
|-------|---------|---------------|
|
||||
| mlx-knife CLI | Full | Own behavior |
|
||||
| mlx-lm / mlx-vlm / mlx-audio | None | Nothing |
|
||||
| HuggingFace Hub | None | Nothing |
|
||||
| Model repositories | None | Nothing |
|
||||
|
||||
**Reality:** mlx-knife is an integration layer. It can recommend models but cannot guarantee their behavior remains constant.
|
||||
|
||||
---
|
||||
|
||||
## Decision
|
||||
|
||||
### Workspace as Primary Paradigm
|
||||
|
||||
Shift from HF-cache-centric to workspace-centric model management:
|
||||
|
||||
**Current (2.0.4):**
|
||||
```bash
|
||||
mlxk pull Model → $HF_HOME (shared, uncontrolled)
|
||||
mlxk run Model → reads from shared cache
|
||||
→ upstream may write to cache (hidden)
|
||||
```
|
||||
|
||||
**New (2.0.5):**
|
||||
```bash
|
||||
mlxk clone Model ./models/Model → local workspace (controlled)
|
||||
mlxk run ./models/Model → reads from workspace
|
||||
→ side effects visible in .hf_cache/
|
||||
```
|
||||
|
||||
### Workspace-Local Cache
|
||||
|
||||
Each workspace gets an isolated HF cache for runtime artifacts:
|
||||
|
||||
```
|
||||
./models/
|
||||
├── whisper-large-v3-mlx/ # cloned model
|
||||
├── VibeVoice-ASR-4bit/ # cloned model
|
||||
└── .hf_cache/ # workspace-local cache
|
||||
└── Qwen--Qwen2.5-7B/ # runtime artifact (VISIBLE!)
|
||||
```
|
||||
|
||||
**Implementation:** When running from workspace path (`./`), set:
|
||||
```bash
|
||||
HF_HOME=<workspace>/.hf_cache
|
||||
```
|
||||
|
||||
### Isolation Guarantees
|
||||
|
||||
| Guarantee | HF Cache | Workspace |
|
||||
|-----------|----------|-----------|
|
||||
| Model isolation | No | Yes (per-workspace) |
|
||||
| Side effects visible | No (hidden in ~/.cache) | Yes (.hf_cache/) |
|
||||
| Reproducible | No | Yes (tar/zip/archive) |
|
||||
| Auditable | Difficult | Trivial (`ls -la`) |
|
||||
| Offline after first run | Unknown | Yes (everything local) |
|
||||
|
||||
### What mlx-knife CAN and CANNOT Guarantee
|
||||
|
||||
**CAN guarantee (workspace mode):**
|
||||
- Models are isolated from each other
|
||||
- Runtime artifacts are visible in `.hf_cache/`
|
||||
- After successful first run, all dependencies are local
|
||||
- Workspace can be archived/transferred
|
||||
|
||||
**CANNOT guarantee:**
|
||||
- Upstream libraries won't attempt network access
|
||||
- First run won't download additional artifacts
|
||||
- Model behavior remains constant over time
|
||||
|
||||
### Revised Security Promise
|
||||
|
||||
Update SECURITY.md to reflect reality:
|
||||
|
||||
> **Network Activity**
|
||||
>
|
||||
> mlx-knife itself performs network activity only during explicit commands (`pull`, `clone`, `push`).
|
||||
>
|
||||
> **Important:** mlx-knife integrates upstream libraries (mlx-lm, mlx-vlm, mlx-audio) whose behavior is outside our control. These libraries may perform their own network requests during model loading or inference.
|
||||
>
|
||||
> **For offline/air-gapped environments:**
|
||||
> 1. Use `mlxk clone` to create isolated workspaces
|
||||
> 2. Run the model once (online) to capture all runtime dependencies
|
||||
> 3. Verify `.hf_cache/` contains all artifacts
|
||||
> 4. Subsequent runs will be fully offline
|
||||
>
|
||||
> We recommend tested models from `mlx-community/*` but cannot guarantee third-party code behavior.
|
||||
|
||||
---
|
||||
|
||||
## UX Changes
|
||||
|
||||
### Command Prominence
|
||||
|
||||
| Command | 2.0.4 Role | 2.0.5 Role |
|
||||
|---------|------------|------------|
|
||||
| `pull` | Primary download | Caching/convenience |
|
||||
| `clone` | Secondary | **Primary** for managed workflows |
|
||||
| `run Model` | Default | Legacy/quick testing |
|
||||
| `run ./path` | Supported | **Recommended** |
|
||||
|
||||
### Documentation Shift
|
||||
|
||||
**Before:** "Download models with `mlxk pull`"
|
||||
|
||||
**After:** "For reproducible workflows, use `mlxk clone` to create managed workspaces"
|
||||
|
||||
### New Flags/Behavior
|
||||
|
||||
```bash
|
||||
# Automatic workspace-local cache when path starts with ./
|
||||
mlxk run ./models/whisper "transcribe"
|
||||
# Internally: HF_HOME=./models/.hf_cache
|
||||
|
||||
# Explicit flag (optional, for cache models)
|
||||
mlxk run Model --workspace-cache ./cache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Relationship to ADR-018
|
||||
|
||||
ADR-018 defines workspace operations (clone, convert, push) and the workspace sentinel concept.
|
||||
|
||||
**ADR-022 extends this by:**
|
||||
1. Making workspace the **primary** paradigm, not secondary
|
||||
2. Adding workspace-local HF cache isolation
|
||||
3. Defining security/offline guarantees
|
||||
4. Driving UX changes (clone > pull)
|
||||
|
||||
**ADR-018 provides:** Infrastructure (sentinel, convert, workspace paths)
|
||||
**ADR-022 provides:** Philosophy and user-facing paradigm shift
|
||||
|
||||
---
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Workspace-Local Cache (2.0.5-beta.1)
|
||||
|
||||
**Goal:** Isolate runtime artifacts per workspace
|
||||
|
||||
**Changes:**
|
||||
- `run ./path` sets `HF_HOME=<workspace>/.hf_cache` before loading
|
||||
- `.hf_cache/` added to workspace structure
|
||||
- `.hf_cache/` documented in workspace sentinel
|
||||
|
||||
**Files:**
|
||||
- `mlxk2/core/runner/__init__.py` — HF_HOME redirect
|
||||
- `mlxk2/core/vision_runner.py` — HF_HOME redirect
|
||||
- `mlxk2/core/audio_runner.py` — HF_HOME redirect
|
||||
- `mlxk2/operations/workspace.py` — .hf_cache handling
|
||||
|
||||
**Tests:** ~10-15 new tests
|
||||
|
||||
### Phase 2: Testsuite Migration (2.0.5-beta.2)
|
||||
|
||||
**Goal:** Tests support both paradigms
|
||||
|
||||
**Changes:**
|
||||
- Fixtures for `cached_model` and `workspace_model`
|
||||
- E2E tests for workspace isolation
|
||||
- Tests for .hf_cache artifact capture
|
||||
|
||||
**Effort:** High (many fixtures affected)
|
||||
|
||||
### Phase 3: Documentation & UX (2.0.5-beta.3)
|
||||
|
||||
**Goal:** Shift user guidance to workspace-first
|
||||
|
||||
**Changes:**
|
||||
- README: clone as primary workflow
|
||||
- SECURITY.md: revised guarantees
|
||||
- Tutorials: workspace-based examples
|
||||
- `mlxk pull` help text: "For caching; use clone for managed workflows"
|
||||
|
||||
### Phase 4: SECURITY.md Update (2.0.5 stable)
|
||||
|
||||
**Goal:** Honest, defensible security claims
|
||||
|
||||
**Changes:**
|
||||
- Clear separation: mlx-knife behavior vs upstream behavior
|
||||
- Workspace-based offline workflow documented
|
||||
- Disclaimer for third-party library behavior
|
||||
|
||||
---
|
||||
|
||||
## Risks and Mitigations
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|------------|
|
||||
| Breaking change for pull-centric users | pull still works, just de-emphasized |
|
||||
| Testsuite complexity | Phased migration, both modes supported |
|
||||
| Disk space (workspace + cache duplication) | Document, user choice |
|
||||
| User confusion (two paradigms) | Clear docs, gradual deprecation of pull-first |
|
||||
|
||||
---
|
||||
|
||||
## Open Questions
|
||||
|
||||
1. **Should `pull` warn about workspace-first?** → No, just document
|
||||
2. **Auto-create .hf_cache/?** → Yes, automatic
|
||||
3. **Workspace health include .hf_cache scan?** → Yes, with `--verbose`
|
||||
4. **Archive format?** → Deferred to 2.0.6+
|
||||
|
||||
---
|
||||
|
||||
## MLXK_WORKSPACE_HOME
|
||||
|
||||
Single workspace path (like `HF_HOME`):
|
||||
|
||||
```bash
|
||||
export MLXK_WORKSPACE_HOME=~/mlx-models
|
||||
|
||||
mlxk clone whisper-large-v3
|
||||
# → ~/mlx-models/whisper-large-v3/
|
||||
|
||||
mlxk list
|
||||
# Shows: HF cache + MLXK_WORKSPACE_HOME
|
||||
|
||||
mlxk run whisper-large-v3
|
||||
# Search order: 1. MLXK_WORKSPACE_HOME 2. HF cache
|
||||
```
|
||||
|
||||
**Implementation:**
|
||||
- `mlxk2/core/cache.py` — new `get_workspace_home()` function
|
||||
- `mlxk2/operations/clone.py` — default target if no path given
|
||||
- `mlxk2/operations/list.py` — include MLXK_WORKSPACE_HOME in scan
|
||||
- `mlxk2/core/model_resolution.py` — search MLXK_WORKSPACE_HOME first
|
||||
|
||||
**Future:** `MLXK_MODEL_PATH` for multi-path search (2.0.6+)
|
||||
|
||||
---
|
||||
|
||||
## UX Details
|
||||
|
||||
### list: Source Column
|
||||
|
||||
```
|
||||
Name | Source | Size | Type
|
||||
whisper-large-v3 | ws | 400MB | audio
|
||||
phi-3-mini | cache | 2.1GB | chat
|
||||
```
|
||||
|
||||
### list --full-paths
|
||||
|
||||
```
|
||||
Name | Source | Size
|
||||
/Users/.../models/whisper-large-v3| ws | 400MB
|
||||
```
|
||||
|
||||
### list --origin
|
||||
|
||||
```
|
||||
Name | Source | Origin | Size
|
||||
whisper-large-v3 | ws | mlx-community/whisper-large-v3 | 400MB
|
||||
```
|
||||
|
||||
### show: Workspace Metadata
|
||||
|
||||
```
|
||||
Model: whisper-large-v3
|
||||
Framework: MLX
|
||||
...
|
||||
Workspace:
|
||||
Source: mlx-community/whisper-large-v3-mlx
|
||||
Operation: clone
|
||||
Created: 2026-02-08
|
||||
Content Hash: sha256:a1b2c3...
|
||||
Modified: no
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## JSON API Schema 0.2.0
|
||||
|
||||
New fields in `modelObject`:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "whisper-large-v3",
|
||||
"source": "workspace",
|
||||
"origin": "mlx-community/whisper-large-v3-mlx",
|
||||
"content_hash": "sha256:a1b2c3...",
|
||||
"hash_modified": false,
|
||||
"cached": false
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `source` | `"cache" \| "workspace"` | Where model lives |
|
||||
| `origin` | `string \| null` | HF origin (from sentinel) |
|
||||
| `content_hash` | `string \| null` | SHA256 of workspace content |
|
||||
| `hash_modified` | `boolean` | True if hash changed since clone/convert |
|
||||
|
||||
**Breaking Changes:** None (additive)
|
||||
|
||||
---
|
||||
|
||||
## Content Hash
|
||||
|
||||
### Exclude List
|
||||
|
||||
```python
|
||||
HASH_EXCLUDE = [
|
||||
".mlxk_workspace.json", # contains the hash itself
|
||||
".hf_cache/", # runtime artifacts
|
||||
".DS_Store",
|
||||
".git/",
|
||||
"__pycache__/",
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
]
|
||||
```
|
||||
|
||||
### Algorithm
|
||||
|
||||
```python
|
||||
def compute_workspace_hash(workspace_path: Path) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
for file in sorted(workspace_path.rglob("*")):
|
||||
if should_exclude(file):
|
||||
continue
|
||||
if file.is_file():
|
||||
# Hash: relative path + content
|
||||
hasher.update(file.relative_to(workspace_path).encode())
|
||||
hasher.update(file.read_bytes())
|
||||
return f"sha256:{hasher.hexdigest()}"
|
||||
```
|
||||
|
||||
### When Computed
|
||||
|
||||
- After `clone` (before declaring success)
|
||||
- After `convert` (before declaring success)
|
||||
- Stored in `.mlxk_workspace.json`
|
||||
|
||||
---
|
||||
|
||||
## Sentinel Schema (Extended)
|
||||
|
||||
```json
|
||||
{
|
||||
"mlxk_version": "2.0.5",
|
||||
"created_at": "2026-02-08T10:30:00Z",
|
||||
"source_repo": "mlx-community/whisper-large-v3-mlx",
|
||||
"source_revision": "abc123def456",
|
||||
"managed": true,
|
||||
"operation": "clone",
|
||||
"content_hash": "sha256:a1b2c3d4e5f6...",
|
||||
"hash_computed_at": "2026-02-08T10:30:05Z",
|
||||
"hash_excludes": [".mlxk_workspace.json", ".hf_cache/"]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Code-Findings (Session 2026-02-08)
|
||||
|
||||
### Bug 1: PyTorch Warning bei Workspace-Pfaden
|
||||
|
||||
**Symptom:** `mlxk list ./path` zeigt "PyTorch was not found" Warnung
|
||||
|
||||
**Root Cause:** `vision_runtime_compatibility()` (common.py:456) importiert `transformers` als erstes bei healthy Vision-Modellen. Bei HF-Cache wird `mlx_lm` vorher importiert (unterdrückt Warnung).
|
||||
|
||||
**Betroffene Befehle:** `list`, `show` (nicht `run`, `health`)
|
||||
|
||||
**Fix:**
|
||||
```python
|
||||
# ALT (common.py:456)
|
||||
import transformers
|
||||
tf_version = getattr(transformers, "__version__", "0.0.0")
|
||||
|
||||
# NEU
|
||||
from importlib.metadata import version
|
||||
tf_version = version("transformers")
|
||||
```
|
||||
|
||||
### Bug 2: Clone ohne HF_HOME
|
||||
|
||||
**Symptom:** `clone` schlägt fehl wenn `HF_HOME=""` (unset)
|
||||
|
||||
**Root Cause:** `_validate_same_volume()` (clone.py:100) prüft `volume(workspace) == volume(HF_HOME)`. Aber temp_cache wird sowieso auf Workspace-Volume erstellt (Zeile 439).
|
||||
|
||||
**Fix:** Check entfernen — ist überflüssig.
|
||||
|
||||
### Bug 3: Empty HF_HOME String
|
||||
|
||||
**Symptom:** `get_current_cache_root()` gibt `Path("")` → `PosixPath(".")` zurück
|
||||
|
||||
**Root Cause:** `os.environ.get("HF_HOME", DEFAULT)` gibt `""` zurück wenn Key existiert aber leer ist.
|
||||
|
||||
**Fix:**
|
||||
```python
|
||||
def get_current_cache_root() -> Path:
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if not hf_home: # None or ""
|
||||
return DEFAULT_CACHE_ROOT
|
||||
return Path(hf_home)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- ADR-018: Convert Operation (workspace infrastructure)
|
||||
- SECURITY.md (current promises)
|
||||
- VibeVoice tokenizer issue (docs/ISSUES/vibevoice-missing-tokenizer.md)
|
||||
- HuggingFace Hub caching behavior
|
||||
+30
-2
@@ -2,14 +2,42 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Cache path constants - copied from mlx_knife/cache_utils.py
|
||||
DEFAULT_CACHE_ROOT = Path.home() / ".cache/huggingface"
|
||||
|
||||
|
||||
def get_workspace_home() -> Optional[Path]:
|
||||
"""Get workspace home directory from MLXK_WORKSPACE_HOME env var.
|
||||
|
||||
Returns:
|
||||
Path to workspace home if set and valid, None otherwise.
|
||||
|
||||
Example:
|
||||
export MLXK_WORKSPACE_HOME=~/mlx-models
|
||||
→ Path("/Users/me/mlx-models")
|
||||
"""
|
||||
workspace_home = os.environ.get("MLXK_WORKSPACE_HOME")
|
||||
if not workspace_home:
|
||||
return None
|
||||
path = Path(workspace_home).expanduser()
|
||||
# Only return if directory exists (don't auto-create)
|
||||
if path.is_dir():
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def get_current_cache_root() -> Path:
|
||||
"""Get current cache root (respects runtime HF_HOME changes)."""
|
||||
return Path(os.environ.get("HF_HOME", DEFAULT_CACHE_ROOT))
|
||||
"""Get current cache root (respects runtime HF_HOME changes).
|
||||
|
||||
Note: Returns DEFAULT_CACHE_ROOT if HF_HOME is unset OR empty string.
|
||||
This handles `export HF_HOME=""` edge case gracefully.
|
||||
"""
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if not hf_home: # None or ""
|
||||
return DEFAULT_CACHE_ROOT
|
||||
return Path(hf_home)
|
||||
|
||||
|
||||
def get_current_model_cache() -> Path:
|
||||
|
||||
@@ -187,13 +187,25 @@ class MLXRunner:
|
||||
if commit_hash:
|
||||
model_path = model_cache_dir / "snapshots" / commit_hash
|
||||
else:
|
||||
# Try to find a snapshot directory; tolerate missing during tests
|
||||
# Find a snapshot directory
|
||||
snapshots_dir = model_cache_dir / "snapshots"
|
||||
if snapshots_dir.exists():
|
||||
snapshots = [d for d in snapshots_dir.iterdir() if d.is_dir()]
|
||||
model_path = snapshots[0] if snapshots else snapshots_dir / "mock"
|
||||
if snapshots:
|
||||
# Prefer most recently modified snapshot
|
||||
model_path = max(snapshots, key=lambda x: x.stat().st_mtime)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model '{resolved_name}' has no snapshots in cache. "
|
||||
f"The model directory exists at {model_cache_dir} but contains no "
|
||||
f"downloaded snapshots. Try running: mlxk pull {resolved_name}"
|
||||
)
|
||||
else:
|
||||
model_path = snapshots_dir / "mock"
|
||||
raise RuntimeError(
|
||||
f"Model '{resolved_name}' not found in cache. "
|
||||
f"Expected at: {model_cache_dir}. "
|
||||
f"Try running: mlxk pull {resolved_name}"
|
||||
)
|
||||
else:
|
||||
# Non path-like cache (likely a Mock in unit tests) → pass a synthetic path to load()
|
||||
model_path = Path("/mock") / hf_to_cache_dir(resolved_name) / "snapshots" / (commit_hash or "mock")
|
||||
|
||||
@@ -95,17 +95,9 @@ def clone_operation(model_spec: str, target_dir: str, health_check: bool = True,
|
||||
result["data"]["clone_status"] = "filesystem_error"
|
||||
return result
|
||||
|
||||
# Phase 1b: Validate same-volume requirement (ADR-007)
|
||||
try:
|
||||
_validate_same_volume(target_path.parent)
|
||||
except FilesystemError as e:
|
||||
result["status"] = "error"
|
||||
result["error"] = {
|
||||
"type": "FilesystemError",
|
||||
"message": str(e)
|
||||
}
|
||||
result["data"]["clone_status"] = "filesystem_error"
|
||||
return result
|
||||
# Phase 1b: Removed - same-volume validation obsolete (ADR-022)
|
||||
# Temp cache is always created on workspace volume (line ~440), so cross-volume
|
||||
# HF_HOME doesn't matter. Removed _validate_same_volume check.
|
||||
|
||||
# Phase 2: Create or resume temp cache on same volume as workspace (ADR-018 Phase 0b)
|
||||
result["data"]["clone_status"] = "preparing"
|
||||
@@ -326,24 +318,6 @@ def _validate_apfs_filesystem(path: Path) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _validate_same_volume(workspace_path: Path) -> None:
|
||||
"""Validate that workspace and HF_HOME cache are on same volume (ADR-007 Phase 1)."""
|
||||
cache_root = get_current_cache_root()
|
||||
|
||||
# Get volume mount points for both paths
|
||||
workspace_volume = _get_volume_mount_point(workspace_path)
|
||||
cache_volume = _get_volume_mount_point(cache_root)
|
||||
|
||||
if workspace_volume != cache_volume:
|
||||
raise FilesystemError(
|
||||
f"Phase 1 requires workspace and cache on same volume.\n"
|
||||
f"Workspace volume: {workspace_volume}\n"
|
||||
f"Cache volume (HF_HOME): {cache_volume}\n"
|
||||
f"Solution: Set HF_HOME to same volume as workspace:\n"
|
||||
f" export HF_HOME={workspace_volume}/huggingface/cache"
|
||||
)
|
||||
|
||||
|
||||
def _is_apfs_filesystem(path: Path) -> bool:
|
||||
"""Simple APFS check - returns True/False only.
|
||||
|
||||
|
||||
@@ -453,8 +453,8 @@ def vision_runtime_compatibility(probe: Optional[Path] = None) -> tuple[bool, Op
|
||||
# with temporal_patch_size (video-capable models like Qwen2-VL)
|
||||
if probe is not None:
|
||||
try:
|
||||
import transformers
|
||||
tf_version = getattr(transformers, "__version__", "0.0.0")
|
||||
from importlib.metadata import version
|
||||
tf_version = version("transformers")
|
||||
# Check if transformers 5.x (RC or early release with potential bugs)
|
||||
if tf_version.startswith("5."):
|
||||
preprocessor_path = probe / "preprocessor_config.json"
|
||||
|
||||
+27
-17
@@ -25,6 +25,8 @@ from .test_utils import (
|
||||
discover_audio_models,
|
||||
parse_vm_stat_page_size,
|
||||
TEST_MODELS,
|
||||
VISION_TEST_MODELS,
|
||||
AUDIO_TEST_MODELS,
|
||||
)
|
||||
|
||||
# Import the real MLX modules fixture from parent test module
|
||||
@@ -132,10 +134,10 @@ def pytest_generate_tests(metafunc):
|
||||
if vision_models:
|
||||
model_keys = [f"vision_{i:02d}" for i in range(len(vision_models))]
|
||||
else:
|
||||
# No fallback for vision (needs real models)
|
||||
model_keys = []
|
||||
# Fallback to hardcoded VISION_TEST_MODELS (pixtral)
|
||||
model_keys = list(VISION_TEST_MODELS.keys())
|
||||
|
||||
# If no vision models, parametrize with skip marker
|
||||
# If still no vision models, parametrize with skip marker
|
||||
if not model_keys:
|
||||
model_keys = ["_no_vision_models"]
|
||||
|
||||
@@ -153,10 +155,10 @@ def pytest_generate_tests(metafunc):
|
||||
if audio_models:
|
||||
model_keys = [f"audio_{i:02d}" for i in range(len(audio_models))]
|
||||
else:
|
||||
# No fallback for audio (needs real models)
|
||||
model_keys = []
|
||||
# Fallback to hardcoded AUDIO_TEST_MODELS (whisper)
|
||||
model_keys = list(AUDIO_TEST_MODELS.keys())
|
||||
|
||||
# If no audio models, parametrize with skip marker
|
||||
# If still no audio models, parametrize with skip marker
|
||||
if not model_keys:
|
||||
model_keys = ["_no_audio_models"]
|
||||
|
||||
@@ -306,9 +308,9 @@ def vision_portfolio():
|
||||
print(f"\n👁️ Vision Portfolio: Found {len(result)} vision-capable models")
|
||||
return result
|
||||
else:
|
||||
# No fallback for vision - requires real models
|
||||
print(f"\n⚠️ Vision Portfolio: No vision models found in cache")
|
||||
return {}
|
||||
# Fallback to hardcoded vision test model (pixtral)
|
||||
print(f"\n📋 Vision Portfolio: Using fallback VISION_TEST_MODELS (1 model)")
|
||||
return VISION_TEST_MODELS
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -348,9 +350,9 @@ def audio_portfolio():
|
||||
print(f"\n🔊 Audio Portfolio: Found {len(result)} audio-capable models")
|
||||
return result
|
||||
else:
|
||||
# No fallback for audio - requires real models
|
||||
print(f"\n⚠️ Audio Portfolio: No audio models found in cache")
|
||||
return {}
|
||||
# Fallback to hardcoded audio test model (whisper)
|
||||
print(f"\n📋 Audio Portfolio: Using fallback AUDIO_TEST_MODELS (1 model)")
|
||||
return AUDIO_TEST_MODELS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -399,7 +401,11 @@ def text_model_info(text_portfolio, text_model_key):
|
||||
- ram_needed_gb: Estimated RAM requirement (1.2x text formula)
|
||||
- expected_issue: Known issue or None
|
||||
- description: Human-readable description
|
||||
|
||||
Returns None for skip markers (_skipped, _no_text_models).
|
||||
"""
|
||||
if text_model_key.startswith("_"):
|
||||
return None
|
||||
return text_portfolio[text_model_key]
|
||||
|
||||
|
||||
@@ -423,7 +429,11 @@ def vision_model_info(vision_portfolio, vision_model_key):
|
||||
- ram_needed_gb: Estimated RAM requirement (0.70 threshold vision formula)
|
||||
- expected_issue: Known issue or None
|
||||
- description: Human-readable description
|
||||
|
||||
Returns None for skip markers (_skipped, _no_vision_models).
|
||||
"""
|
||||
if vision_model_key.startswith("_"):
|
||||
return None
|
||||
return vision_portfolio[vision_model_key]
|
||||
|
||||
|
||||
@@ -498,14 +508,14 @@ def _auto_report_vision_model(request):
|
||||
|
||||
# Type 2: CLI vision tests (test_vision_e2e_live.py)
|
||||
# These tests use subprocess.run(["mlxk", "run", VISION_MODEL, ...])
|
||||
# VISION_MODEL is explicitly set to "pixtral-12b-8bit" to avoid ambiguity
|
||||
# VISION_MODEL is "pixtral-12b-4bit" (matches VISION_TEST_MODELS fallback)
|
||||
if 'test_vision_e2e_live.py' in request.node.nodeid:
|
||||
# All CLI vision tests use explicit pixtral-12b-8bit
|
||||
# All CLI vision tests use pixtral-12b-4bit
|
||||
request.node.user_properties.append(("model", {
|
||||
"id": "pixtral-12b-8bit", # Explicit model (not shorthand)
|
||||
"size_gb": 13.5, # Actual disk size of 8bit variant
|
||||
"id": "mlx-community/pixtral-12b-4bit",
|
||||
"size_gb": 7.0, # 12B 4-bit (~7GB empirical)
|
||||
"family": "pixtral",
|
||||
"variant": "12b-8bit",
|
||||
"variant": "12b-4bit",
|
||||
}))
|
||||
# Explicit inference_modality for CLI vision tests (v0.2.1)
|
||||
# Required because these tests don't use vision_model_key fixture
|
||||
|
||||
@@ -15,9 +15,9 @@ def test_text_portfolio_contains_only_text_models(text_portfolio):
|
||||
if not text_portfolio:
|
||||
pytest.skip("No text models found (HF_HOME not set or no models in cache)")
|
||||
|
||||
# All models should have text_ prefix
|
||||
for key in text_portfolio.keys():
|
||||
assert key.startswith("text_"), f"Expected text_XX key, got: {key}"
|
||||
# Keys are either text_XX (discovered) or fallback names (mxfp4, qwen25, etc.)
|
||||
# Just verify we have keys, not their format
|
||||
assert len(text_portfolio) > 0, "Portfolio should not be empty"
|
||||
|
||||
# All models should have required fields
|
||||
for key, model_info in text_portfolio.items():
|
||||
@@ -34,9 +34,9 @@ def test_vision_portfolio_contains_only_vision_models(vision_portfolio):
|
||||
if not vision_portfolio:
|
||||
pytest.skip("No vision models found in cache")
|
||||
|
||||
# All models should have vision_ prefix
|
||||
for key in vision_portfolio.keys():
|
||||
assert key.startswith("vision_"), f"Expected vision_XX key, got: {key}"
|
||||
# Keys are either vision_XX (discovered) or fallback names (pixtral, etc.)
|
||||
# Just verify we have keys, not their format
|
||||
assert len(vision_portfolio) > 0, "Portfolio should not be empty"
|
||||
|
||||
# All models should have required fields
|
||||
for key, model_info in vision_portfolio.items():
|
||||
|
||||
@@ -206,8 +206,6 @@ def discover_text_models() -> list[Dict[str, Any]]:
|
||||
|
||||
# Get capabilities from mlxk list --json
|
||||
env = os.environ.copy()
|
||||
if not env.get("HF_HOME"):
|
||||
return all_models # Fall back to all models if HF_HOME not set
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@@ -265,8 +263,6 @@ def discover_vision_models() -> list[Dict[str, Any]]:
|
||||
|
||||
# Get capabilities and size_bytes from mlxk list --json
|
||||
env = os.environ.copy()
|
||||
if not env.get("HF_HOME"):
|
||||
return [] # Vision models need HF_HOME
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@@ -341,7 +337,7 @@ def discover_audio_models() -> list[Dict[str, Any]]:
|
||||
|
||||
env = os.environ.copy()
|
||||
if not env.get("HF_HOME"):
|
||||
return []
|
||||
return [] # Audio discovery requires HF_HOME (see TESTING.md)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@@ -393,6 +389,42 @@ def discover_audio_models() -> list[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FALLBACK TEST MODELS - Minimum Required Models for Testing Without HF_HOME
|
||||
# =============================================================================
|
||||
# When HF_HOME is not set, Portfolio Discovery returns []. These fallback models
|
||||
# provide a baseline for testing when the user has these specific models in
|
||||
# their default cache (~/.cache/huggingface).
|
||||
#
|
||||
# These models must be downloaded manually if testing without HF_HOME:
|
||||
# mlxk pull mlx-community/gpt-oss-20b-MXFP4-Q8
|
||||
# mlxk pull mlx-community/Qwen2.5-0.5B-Instruct-4bit
|
||||
# mlxk pull mlx-community/Llama-3.2-3B-Instruct-4bit
|
||||
# mlxk pull mlx-community/pixtral-12b-4bit
|
||||
# mlxk pull mlx-community/whisper-large-v3-turbo-4bit
|
||||
# =============================================================================
|
||||
|
||||
# Vision fallback model (for tests without HF_HOME)
|
||||
VISION_TEST_MODELS = {
|
||||
"pixtral": {
|
||||
"id": "mlx-community/pixtral-12b-4bit",
|
||||
"expected_issue": None,
|
||||
"description": "Pixtral 12B - general-purpose vision model",
|
||||
"ram_needed_gb": 7.0 # 12B 4-bit (~7GB empirical)
|
||||
}
|
||||
}
|
||||
|
||||
# Audio fallback model (for tests without HF_HOME)
|
||||
AUDIO_TEST_MODELS = {
|
||||
"whisper": {
|
||||
"id": "mlx-community/whisper-large-v3-turbo-4bit",
|
||||
"expected_issue": None,
|
||||
"description": "Whisper large-v3-turbo - STT baseline",
|
||||
"ram_needed_gb": 1.5 # Large-v3 4-bit (~1.5GB)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Re-export for convenience
|
||||
__all__ = [
|
||||
"discover_mlx_models_in_user_cache",
|
||||
@@ -406,6 +438,8 @@ __all__ = [
|
||||
"get_system_ram_gb",
|
||||
"should_skip_model",
|
||||
"TEST_MODELS",
|
||||
"VISION_TEST_MODELS",
|
||||
"AUDIO_TEST_MODELS",
|
||||
"TEST_PROMPT",
|
||||
"MAX_TOKENS",
|
||||
"TEST_TEMPERATURE",
|
||||
|
||||
@@ -6,9 +6,9 @@ to validate actual image understanding (not just hallucination).
|
||||
|
||||
Requires:
|
||||
- Python 3.10+ (mlx-vlm requirement)
|
||||
- Vision model in cache (e.g., pixtral-12b-4bit or pixtral-12b-8bit)
|
||||
- Vision model in cache (default: pixtral-12b-4bit, see VISION_TEST_MODELS)
|
||||
- Test assets in tests_2.0/assets/
|
||||
- HF_HOME set to model cache location
|
||||
- HF_HOME optional (uses default cache if not set)
|
||||
|
||||
Run with:
|
||||
HF_HOME=/path/to/cache pytest -m live_e2e tests_2.0/live/test_vision_e2e_live.py
|
||||
@@ -19,8 +19,8 @@ import pytest
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
# Explicit model name to avoid ambiguity when multiple pixtral variants in cache
|
||||
VISION_MODEL = "pixtral-12b-8bit"
|
||||
# Must match VISION_TEST_MODELS fallback (see tests_2.0/live/test_utils.py)
|
||||
VISION_MODEL = "pixtral-12b-4bit"
|
||||
|
||||
# Vision support requires Python 3.10+ (mlx-vlm requirement)
|
||||
pytestmark = [
|
||||
|
||||
@@ -471,7 +471,6 @@ class TestCloneOperationIntegration:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -544,7 +543,6 @@ class TestCloneOperationIntegration:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull:
|
||||
|
||||
@@ -579,7 +577,6 @@ class TestCloneOperationIntegration:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -649,7 +646,6 @@ class TestCloneOperationIntegration:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -695,7 +691,6 @@ class TestCloneJSONAPICompliance:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -804,7 +799,6 @@ class TestCloneCoreFeatures:
|
||||
model_spec = "org/model"
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -865,7 +859,6 @@ class TestCloneCoreFeatures:
|
||||
user_cache.mkdir()
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -911,7 +904,6 @@ class TestCloneEdgeCases:
|
||||
temp_cache.mkdir()
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -949,7 +941,6 @@ class TestCloneEdgeCases:
|
||||
sentinel.write_text("mlxk2_temp_cache_created_test")
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve:
|
||||
@@ -979,7 +970,6 @@ class TestCloneEdgeCases:
|
||||
temp_cache.mkdir()
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'), \
|
||||
patch('mlxk2.operations.clone._validate_same_volume'), \
|
||||
patch('mlxk2.operations.clone._create_temp_cache_same_volume') as mock_create_cache, \
|
||||
patch('mlxk2.operations.clone.pull_to_cache') as mock_pull, \
|
||||
patch('mlxk2.operations.clone._resolve_latest_snapshot') as mock_resolve, \
|
||||
@@ -1029,7 +1019,6 @@ class TestUnhealthyModelClone:
|
||||
"""
|
||||
|
||||
@patch('mlxk2.operations.clone._validate_apfs_filesystem')
|
||||
@patch('mlxk2.operations.clone._validate_same_volume')
|
||||
@patch('mlxk2.operations.clone._create_temp_cache_same_volume')
|
||||
@patch('mlxk2.operations.clone.pull_to_cache')
|
||||
@patch('mlxk2.operations.clone._resolve_latest_snapshot')
|
||||
@@ -1039,7 +1028,7 @@ class TestUnhealthyModelClone:
|
||||
@patch('mlxk2.operations.clone.write_workspace_sentinel')
|
||||
def test_unhealthy_model_clone_succeeds(
|
||||
self, mock_sentinel, mock_cleanup, mock_clone, mock_health,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_vol, mock_validate_apfs,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_apfs,
|
||||
tmp_path
|
||||
):
|
||||
"""Test that unhealthy models are still cloned successfully."""
|
||||
@@ -1083,7 +1072,6 @@ class TestUnhealthyModelClone:
|
||||
mock_sentinel.assert_called_once()
|
||||
|
||||
@patch('mlxk2.operations.clone._validate_apfs_filesystem')
|
||||
@patch('mlxk2.operations.clone._validate_same_volume')
|
||||
@patch('mlxk2.operations.clone._create_temp_cache_same_volume')
|
||||
@patch('mlxk2.operations.clone.pull_to_cache')
|
||||
@patch('mlxk2.operations.clone._resolve_latest_snapshot')
|
||||
@@ -1093,7 +1081,7 @@ class TestUnhealthyModelClone:
|
||||
@patch('mlxk2.operations.clone.write_workspace_sentinel')
|
||||
def test_healthy_model_clone_records_status(
|
||||
self, mock_sentinel, mock_cleanup, mock_clone, mock_health,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_vol, mock_validate_apfs,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_apfs,
|
||||
tmp_path
|
||||
):
|
||||
"""Test that healthy models record health status correctly."""
|
||||
@@ -1131,7 +1119,6 @@ class TestUnhealthyModelClone:
|
||||
assert result["data"]["health_reason"] == "Multi-file model complete"
|
||||
|
||||
@patch('mlxk2.operations.clone._validate_apfs_filesystem')
|
||||
@patch('mlxk2.operations.clone._validate_same_volume')
|
||||
@patch('mlxk2.operations.clone._create_temp_cache_same_volume')
|
||||
@patch('mlxk2.operations.clone.pull_to_cache')
|
||||
@patch('mlxk2.operations.clone._resolve_latest_snapshot')
|
||||
@@ -1140,7 +1127,7 @@ class TestUnhealthyModelClone:
|
||||
@patch('mlxk2.operations.clone.write_workspace_sentinel')
|
||||
def test_no_health_check_skips_health_status(
|
||||
self, mock_sentinel, mock_cleanup, mock_clone,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_vol, mock_validate_apfs,
|
||||
mock_snapshot, mock_pull, mock_temp_cache, mock_validate_apfs,
|
||||
tmp_path
|
||||
):
|
||||
"""Test that --no-health-check skips health status entirely."""
|
||||
@@ -1355,9 +1342,8 @@ class TestResumableClone:
|
||||
model_spec = "test/model"
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'):
|
||||
with patch('mlxk2.operations.clone._validate_same_volume'):
|
||||
with patch('mlxk2.operations.clone._get_volume_mount_point', return_value=tmp_path):
|
||||
# Mock pull_to_cache to raise KeyboardInterrupt
|
||||
with patch('mlxk2.operations.clone._get_volume_mount_point', return_value=tmp_path):
|
||||
# Mock pull_to_cache to raise KeyboardInterrupt
|
||||
with patch('mlxk2.operations.clone.pull_to_cache', side_effect=KeyboardInterrupt()):
|
||||
result = clone_operation(model_spec, str(target))
|
||||
|
||||
@@ -1384,9 +1370,8 @@ class TestResumableClone:
|
||||
model_spec = "test/model"
|
||||
|
||||
with patch('mlxk2.operations.clone._validate_apfs_filesystem'):
|
||||
with patch('mlxk2.operations.clone._validate_same_volume'):
|
||||
# Raise KeyboardInterrupt during volume mount check
|
||||
with patch('mlxk2.operations.clone._get_volume_mount_point', side_effect=KeyboardInterrupt()):
|
||||
# Raise KeyboardInterrupt during volume mount check
|
||||
with patch('mlxk2.operations.clone._get_volume_mount_point', side_effect=KeyboardInterrupt()):
|
||||
result = clone_operation(model_spec, str(target))
|
||||
|
||||
# Should handle gracefully even without temp cache
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestMLXRunnerInterruption:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
def test_signal_handler_setup(self, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that signal handler is properly set up"""
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
@@ -88,7 +88,7 @@ class TestMLXRunnerInterruption:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
def test_interrupt_flag_setting(self, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that interrupt handler sets the flag correctly"""
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
@@ -107,7 +107,7 @@ class TestMLXRunnerInterruption:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.generate_step')
|
||||
def test_streaming_interruption_detection(self, mock_gen, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that streaming generation checks for interruption"""
|
||||
@@ -159,7 +159,7 @@ class TestMLXRunnerInterruption:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.generate_step')
|
||||
def test_batch_interruption_detection(self, mock_gen, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that batch generation also checks for interruption"""
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestInterruptionRecovery:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
def test_interruption_flag_reset_streaming(self, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that interruption flag is reset for new streaming generation"""
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
@@ -94,7 +94,7 @@ class TestInterruptionRecovery:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
def test_interruption_flag_reset_batch(self, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that interruption flag is reset for new batch generation"""
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
@@ -63,22 +63,23 @@ class TestTextModelsDiscovery:
|
||||
assert "mlx-community/Phi-3-mini-4k-instruct-4bit" in model_ids
|
||||
assert "mlx-community/Llama-3.2-11B-Vision-Instruct-4bit" not in model_ids
|
||||
|
||||
def test_discover_text_models_returns_all_when_no_hf_home(self, monkeypatch):
|
||||
"""Verify fallback behavior when HF_HOME not set."""
|
||||
mock_all_models = [
|
||||
{"model_id": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", "ram_needed_gb": 1.0, "snapshot_path": None, "weight_count": None},
|
||||
]
|
||||
def test_discover_text_models_returns_empty_when_no_hf_home(self, monkeypatch):
|
||||
"""Verify fallback behavior when HF_HOME not set.
|
||||
|
||||
with patch("live.test_utils.discover_mlx_models_in_user_cache", return_value=mock_all_models):
|
||||
Without HF_HOME, discover_mlx_models_in_user_cache returns [] (by design).
|
||||
This ensures tests fall back to TEST_MODELS hardcoded models.
|
||||
See TESTING.md for Portfolio Discovery requirements.
|
||||
"""
|
||||
# Mock discover_mlx_models_in_user_cache to return [] (simulates no HF_HOME)
|
||||
with patch("live.test_utils.discover_mlx_models_in_user_cache", return_value=[]):
|
||||
# Unset HF_HOME
|
||||
monkeypatch.delenv("HF_HOME", raising=False)
|
||||
|
||||
from live.test_utils import discover_text_models
|
||||
result = discover_text_models()
|
||||
|
||||
# Should return all models (fallback)
|
||||
assert len(result) == 1
|
||||
assert result == mock_all_models
|
||||
# Should return empty (triggers fallback to TEST_MODELS in portfolio fixture)
|
||||
assert result == []
|
||||
|
||||
def test_discover_text_models_handles_empty_portfolio(self):
|
||||
"""Verify behavior when no models discovered."""
|
||||
@@ -149,20 +150,39 @@ class TestVisionModelsDiscovery:
|
||||
assert "mlx-community/pixtral-12b-8bit" in model_ids
|
||||
assert "mlx-community/Qwen2.5-0.5B-Instruct-4bit" not in model_ids
|
||||
|
||||
def test_discover_vision_models_returns_empty_when_no_hf_home(self, monkeypatch):
|
||||
"""Verify that vision models require HF_HOME."""
|
||||
def test_discover_vision_models_uses_default_cache_when_no_hf_home(self, monkeypatch):
|
||||
"""Verify that vision models use default cache when HF_HOME not set."""
|
||||
mock_all_models = [
|
||||
{"model_id": "mlx-community/Llama-3.2-11B-Vision-Instruct-4bit", "ram_needed_gb": 24.0, "snapshot_path": None, "weight_count": None},
|
||||
]
|
||||
|
||||
mock_list_output = {
|
||||
"status": "success",
|
||||
"command": "list",
|
||||
"data": {
|
||||
"models": [
|
||||
{"name": "mlx-community/Llama-3.2-11B-Vision-Instruct-4bit", "capabilities": ["text-generation", "chat", "vision"], "size_bytes": 12000000000},
|
||||
],
|
||||
"count": 1
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
with patch("live.test_utils.discover_mlx_models_in_user_cache", return_value=mock_all_models):
|
||||
monkeypatch.delenv("HF_HOME", raising=False)
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout=json.dumps(mock_list_output)
|
||||
)
|
||||
# No HF_HOME set - should still work with default cache
|
||||
monkeypatch.delenv("HF_HOME", raising=False)
|
||||
|
||||
from live.test_utils import discover_vision_models
|
||||
result = discover_vision_models()
|
||||
from live.test_utils import discover_vision_models
|
||||
result = discover_vision_models()
|
||||
|
||||
# Should return empty (vision needs HF_HOME)
|
||||
assert result == []
|
||||
# Should return vision models (using default cache)
|
||||
assert len(result) == 1
|
||||
assert result[0]["model_id"] == "mlx-community/Llama-3.2-11B-Vision-Instruct-4bit"
|
||||
|
||||
def test_discover_vision_models_handles_empty_portfolio(self):
|
||||
"""Verify behavior when no models discovered."""
|
||||
|
||||
@@ -48,10 +48,13 @@ class MockDetokenizer:
|
||||
@contextmanager
|
||||
def mock_runner_environment(temp_cache_dir, model_name="test-model"):
|
||||
"""Mock the environment needed for MLXRunner tests."""
|
||||
# IMPORTANT: Patch in the runner module where the functions are imported,
|
||||
# not in the cache module where they're defined. This ensures the patched
|
||||
# references are used by MLXRunner.
|
||||
with patch('mlxk2.core.runner.load') as mock_load, \
|
||||
patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve, \
|
||||
patch('mlxk2.core.cache.get_current_model_cache') as mock_cache, \
|
||||
patch('mlxk2.core.cache.hf_to_cache_dir') as mock_hf_to_cache, \
|
||||
patch('mlxk2.core.runner.get_current_model_cache') as mock_cache, \
|
||||
patch('mlxk2.core.runner.hf_to_cache_dir') as mock_hf_to_cache, \
|
||||
patch('mlxk2.core.runner.get_model_context_length') as mock_context:
|
||||
|
||||
# Mock successful model resolution
|
||||
@@ -296,128 +299,99 @@ class TestMLXRunnerStopTokens:
|
||||
|
||||
class TestMLXRunnerMemorySafety:
|
||||
"""Test memory management and cleanup"""
|
||||
|
||||
|
||||
def test_model_cleanup_on_context_exit(self, temp_cache_dir):
|
||||
"""Test that model is properly cleaned up"""
|
||||
model_name = "test-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_model = Mock()
|
||||
mock_tokenizer = Mock()
|
||||
mock_load.return_value = (mock_model, mock_tokenizer)
|
||||
|
||||
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
runner = None
|
||||
with MLXRunner(model_name) as r:
|
||||
runner = r
|
||||
assert runner.model is not None
|
||||
assert runner.tokenizer is not None
|
||||
|
||||
|
||||
# After context exit, model should be cleaned up
|
||||
assert runner.model is None
|
||||
assert runner.tokenizer is None
|
||||
|
||||
|
||||
def test_multiple_context_managers(self, temp_cache_dir):
|
||||
"""Test that multiple runners can be used sequentially"""
|
||||
model_name = "test-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_model = Mock()
|
||||
mock_tokenizer = Mock()
|
||||
mock_tokenizer.encode.return_value = [1]
|
||||
mock_tokenizer.decode.return_value = "ok"
|
||||
mock_tokenizer.eos_token_id = 2
|
||||
mock_tokenizer.eos_token_ids = {mock_tokenizer.eos_token_id}
|
||||
mock_tokenizer.additional_special_tokens = []
|
||||
mock_tokenizer.added_tokens_decoder = {}
|
||||
mock_load.return_value = (mock_model, mock_tokenizer)
|
||||
|
||||
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
# First runner
|
||||
with MLXRunner(model_name) as runner1:
|
||||
assert runner1 is not None
|
||||
|
||||
|
||||
# Second runner should work independently
|
||||
with MLXRunner(model_name) as runner2:
|
||||
assert runner2 is not None
|
||||
|
||||
|
||||
# Should have loaded model twice
|
||||
assert mock_load.call_count == 2
|
||||
assert mocks['mock_load'].call_count == 2
|
||||
|
||||
|
||||
class TestMLXRunnerDynamicTokens:
|
||||
"""Test dynamic token limit functionality"""
|
||||
|
||||
|
||||
def test_no_max_tokens_uses_dynamic(self, temp_cache_dir):
|
||||
"""Test that None max_tokens uses dynamic limit based on model context"""
|
||||
model_name = "test-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_load.return_value = (Mock(), Mock())
|
||||
|
||||
# Mock config reading for context length
|
||||
with patch('mlxk2.core.runner.get_model_context_length', return_value=8192):
|
||||
with MLXRunner(model_name) as runner:
|
||||
# Should calculate dynamic limit from context length
|
||||
dynamic_limit = runner._calculate_dynamic_max_tokens()
|
||||
|
||||
# Should be a reasonable fraction of context (server-mode default)
|
||||
# Accept half-context on 8K models as reasonable
|
||||
assert 1000 <= dynamic_limit <= 4096
|
||||
|
||||
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
with MLXRunner(model_name) as runner:
|
||||
# Should calculate dynamic limit from context length (8192 from mock)
|
||||
dynamic_limit = runner._calculate_dynamic_max_tokens()
|
||||
|
||||
# Should be a reasonable fraction of context (server-mode default)
|
||||
# Accept half-context on 8K models as reasonable
|
||||
assert 1000 <= dynamic_limit <= 4096
|
||||
|
||||
def test_respects_explicit_max_tokens(self, temp_cache_dir):
|
||||
"""Test that explicit max_tokens is respected"""
|
||||
model_name = "test-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_model = Mock()
|
||||
mock_tokenizer = Mock()
|
||||
mock_tokenizer.encode.return_value = [1]
|
||||
mock_tokenizer.decode.return_value = "ok"
|
||||
mock_tokenizer.eos_token_id = 2
|
||||
mock_tokenizer.eos_token_ids = {mock_tokenizer.eos_token_id}
|
||||
mock_tokenizer.additional_special_tokens = []
|
||||
mock_tokenizer.added_tokens_decoder = {}
|
||||
mock_load.return_value = (mock_model, mock_tokenizer)
|
||||
|
||||
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
# Update mock tokenizer with extra methods needed for generation
|
||||
mocks['mock_tokenizer'].encode.return_value = [1]
|
||||
mocks['mock_tokenizer'].decode.return_value = "ok"
|
||||
|
||||
with MLXRunner(model_name) as runner:
|
||||
# When max_tokens is explicitly set, should respect it
|
||||
with patch('mlxk2.core.runner.generate_step') as mock_gen:
|
||||
mock_gen.return_value = iter([(mx.array([1]), mx.zeros(1))])
|
||||
|
||||
|
||||
# Mock to check that max_tokens is passed through
|
||||
result = runner.generate_batch("test", max_tokens=100)
|
||||
|
||||
|
||||
# Should have respected the explicit limit
|
||||
# (Details depend on implementation)
|
||||
|
||||
|
||||
class TestMLXRunnerErrorHandling:
|
||||
"""Test error handling and edge cases"""
|
||||
|
||||
|
||||
def test_model_loading_failure(self, temp_cache_dir):
|
||||
"""Test handling of model loading failures"""
|
||||
model_path = "nonexistent-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_load.side_effect = FileNotFoundError("Model not found")
|
||||
|
||||
model_name = "test-model"
|
||||
|
||||
# Create the mock environment but configure load to raise an error
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
mocks['mock_load'].side_effect = FileNotFoundError("Model not found")
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
with MLXRunner(model_path):
|
||||
with MLXRunner(model_name):
|
||||
pass
|
||||
|
||||
|
||||
def test_generation_interruption(self, temp_cache_dir):
|
||||
"""Test Ctrl-C interruption handling"""
|
||||
model_name = "test-model"
|
||||
|
||||
with patch('mlxk2.core.runner.load') as mock_load:
|
||||
mock_model, mock_tokenizer = Mock(), Mock()
|
||||
# Minimal tokenizer stubs to satisfy runner
|
||||
mock_tokenizer.encode.return_value = [1]
|
||||
mock_tokenizer.decode.return_value = "ok"
|
||||
mock_tokenizer.eos_token_id = 2
|
||||
mock_tokenizer.eos_token_ids = {mock_tokenizer.eos_token_id}
|
||||
mock_tokenizer.additional_special_tokens = []
|
||||
mock_tokenizer.added_tokens_decoder = {}
|
||||
mock_load.return_value = (mock_model, mock_tokenizer)
|
||||
|
||||
with mock_runner_environment(temp_cache_dir, model_name) as mocks:
|
||||
# Update mock tokenizer with extra methods needed for generation
|
||||
mocks['mock_tokenizer'].encode.return_value = [1]
|
||||
mocks['mock_tokenizer'].decode.return_value = "ok"
|
||||
|
||||
# With new recovery semantics, a pre-existing interruption flag
|
||||
# is cleared at the start of a new generation.
|
||||
|
||||
@@ -224,7 +224,8 @@ def discover_mlx_models_in_user_cache() -> List[Dict[str, Any]]:
|
||||
except ImportError:
|
||||
KNOWN_BROKEN_MODELS = set() # Fallback if import fails
|
||||
|
||||
# Check HF_HOME is set (required for mlxk list)
|
||||
# Check HF_HOME is set (required for Portfolio Discovery - see TESTING.md)
|
||||
# Without HF_HOME: tests fall back to TEST_MODELS/VISION_TEST_MODELS/AUDIO_TEST_MODELS
|
||||
env = os.environ.copy()
|
||||
if not env.get("HF_HOME"):
|
||||
return []
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestDynamicTokenLimits:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
# Create runner and test calculation
|
||||
@@ -86,7 +86,7 @@ class TestDynamicTokenLimits:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
# Create runner and test calculation
|
||||
@@ -105,7 +105,7 @@ class TestDynamicTokenLimits:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
# Create runner with no context length
|
||||
@@ -125,7 +125,7 @@ class TestTokenLimitApplication:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_model_context_length')
|
||||
def test_generate_streaming_uses_dynamic_limits(self, mock_context, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that generate_streaming uses dynamic limits when max_tokens=None"""
|
||||
@@ -158,7 +158,7 @@ class TestTokenLimitApplication:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_model_context_length')
|
||||
def test_generate_streaming_respects_explicit_limits(self, mock_context, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that explicit max_tokens is respected"""
|
||||
@@ -191,7 +191,7 @@ class TestTokenLimitApplication:
|
||||
|
||||
@patch('mlxk2.core.runner.load')
|
||||
@patch('mlxk2.core.runner.resolve_model_for_operation')
|
||||
@patch('mlxk2.core.cache.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_current_model_cache')
|
||||
@patch('mlxk2.core.runner.get_model_context_length')
|
||||
def test_generate_batch_uses_dynamic_limits(self, mock_context, mock_cache, mock_resolve, mock_load):
|
||||
"""Test that generate_batch also uses dynamic limits"""
|
||||
@@ -238,7 +238,7 @@ class TestLargeContextModels:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("large-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("large-model")
|
||||
@@ -263,7 +263,7 @@ class TestLargeContextModels:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("huge-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("huge-model")
|
||||
@@ -288,7 +288,7 @@ class TestTokenLimitEdgeCases:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("test-model")
|
||||
@@ -315,7 +315,7 @@ class TestTokenLimitEdgeCases:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("test-model")
|
||||
@@ -337,7 +337,7 @@ class TestServerVsRunDifferences:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("test-model")
|
||||
@@ -376,7 +376,7 @@ class TestServerVsRunDifferences:
|
||||
with patch('mlxk2.core.runner.resolve_model_for_operation') as mock_resolve:
|
||||
mock_resolve.return_value = ("test-model", None, None)
|
||||
|
||||
with patch('mlxk2.core.cache.get_current_model_cache') as mock_cache:
|
||||
with patch('mlxk2.core.runner.get_current_model_cache') as mock_cache:
|
||||
mock_cache.return_value = Mock()
|
||||
|
||||
runner = MLXRunner("test-model")
|
||||
|
||||
@@ -0,0 +1,562 @@
|
||||
"""Unit tests for mlxk2/audio/whisper_tokenizer.py.
|
||||
|
||||
Tests the bundled Whisper tokenizer implementation (mlx-audio Issue #479 workaround).
|
||||
|
||||
Coverage:
|
||||
- get_encoding(): Load tiktoken encodings from bundled assets
|
||||
- get_tokenizer(): Create Tokenizer instances for various configurations
|
||||
- Tokenizer class: Special tokens, encode/decode, properties
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _mlx_audio_available():
|
||||
"""Check if mlx-audio is available and functional."""
|
||||
try:
|
||||
import mlx_audio.stt.models.whisper.tokenizer # noqa: F401
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
requires_mlx_audio = pytest.mark.skipif(
|
||||
not _mlx_audio_available(),
|
||||
reason="mlx-audio not available or MLX incompatible"
|
||||
)
|
||||
|
||||
|
||||
class TestGetEncoding:
|
||||
"""Tests for get_encoding() function."""
|
||||
|
||||
def test_get_encoding_gpt2(self):
|
||||
"""Load gpt2 encoding from bundled assets."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc = get_encoding("gpt2")
|
||||
|
||||
assert enc is not None
|
||||
assert enc.name == "gpt2.tiktoken"
|
||||
# Verify it can encode/decode basic text
|
||||
tokens = enc.encode("Hello world")
|
||||
assert len(tokens) > 0
|
||||
decoded = enc.decode(tokens)
|
||||
assert decoded == "Hello world"
|
||||
|
||||
def test_get_encoding_multilingual(self):
|
||||
"""Load multilingual encoding from bundled assets."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc = get_encoding("multilingual")
|
||||
|
||||
assert enc is not None
|
||||
assert enc.name == "multilingual.tiktoken"
|
||||
# Verify it can encode/decode multilingual text
|
||||
tokens = enc.encode("Guten Tag")
|
||||
assert len(tokens) > 0
|
||||
decoded = enc.decode(tokens)
|
||||
assert decoded == "Guten Tag"
|
||||
|
||||
def test_get_encoding_nonexistent_raises(self):
|
||||
"""Unknown encoding name should raise FileNotFoundError."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
with pytest.raises(FileNotFoundError) as exc_info:
|
||||
get_encoding("nonexistent_encoding")
|
||||
|
||||
assert "Tiktoken vocabulary file not found" in str(exc_info.value)
|
||||
assert "mlx-audio Issue #479" in str(exc_info.value)
|
||||
|
||||
def test_get_encoding_is_cached(self):
|
||||
"""get_encoding() should be cached (lru_cache)."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc1 = get_encoding("gpt2")
|
||||
enc2 = get_encoding("gpt2")
|
||||
|
||||
# Same object due to caching
|
||||
assert enc1 is enc2
|
||||
|
||||
def test_get_encoding_has_special_tokens(self):
|
||||
"""Encoding should have Whisper special tokens."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc = get_encoding("gpt2")
|
||||
|
||||
# Check Whisper-specific special tokens exist
|
||||
special_tokens = enc.special_tokens_set
|
||||
assert "<|endoftext|>" in special_tokens
|
||||
assert "<|startoftranscript|>" in special_tokens
|
||||
assert "<|transcribe|>" in special_tokens
|
||||
assert "<|translate|>" in special_tokens
|
||||
assert "<|nospeech|>" in special_tokens
|
||||
assert "<|notimestamps|>" in special_tokens
|
||||
|
||||
def test_get_encoding_has_language_tokens(self):
|
||||
"""Encoding should have language tokens (<|en|>, <|de|>, etc.)."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc = get_encoding("gpt2", num_languages=99)
|
||||
|
||||
special_tokens = enc.special_tokens_set
|
||||
assert "<|en|>" in special_tokens
|
||||
assert "<|de|>" in special_tokens
|
||||
assert "<|fr|>" in special_tokens
|
||||
assert "<|es|>" in special_tokens
|
||||
|
||||
def test_get_encoding_has_timestamp_tokens(self):
|
||||
"""Encoding should have timestamp tokens (<|0.00|> to <|30.00|>)."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
enc = get_encoding("gpt2")
|
||||
|
||||
special_tokens = enc.special_tokens_set
|
||||
assert "<|0.00|>" in special_tokens
|
||||
assert "<|0.02|>" in special_tokens
|
||||
assert "<|10.00|>" in special_tokens
|
||||
assert "<|30.00|>" in special_tokens
|
||||
|
||||
|
||||
class TestGetTokenizer:
|
||||
"""Tests for get_tokenizer() function."""
|
||||
|
||||
def test_get_tokenizer_multilingual_default(self):
|
||||
"""Multilingual tokenizer with default settings."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=True)
|
||||
|
||||
assert tok is not None
|
||||
assert tok.language == "en" # Default language
|
||||
assert tok.task == "transcribe" # Default task
|
||||
assert tok.encoding.name == "multilingual.tiktoken"
|
||||
|
||||
def test_get_tokenizer_multilingual_german(self):
|
||||
"""Multilingual tokenizer with German language."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=True, language="de")
|
||||
|
||||
assert tok.language == "de"
|
||||
assert tok.task == "transcribe"
|
||||
|
||||
def test_get_tokenizer_multilingual_translate_task(self):
|
||||
"""Multilingual tokenizer with translate task."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=True, language="de", task="translate")
|
||||
|
||||
assert tok.language == "de"
|
||||
assert tok.task == "translate"
|
||||
|
||||
def test_get_tokenizer_english_only(self):
|
||||
"""English-only (non-multilingual) tokenizer."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=False)
|
||||
|
||||
assert tok.language is None # English-only has no language
|
||||
assert tok.task is None # English-only has no task
|
||||
assert tok.encoding.name == "gpt2.tiktoken"
|
||||
|
||||
def test_get_tokenizer_invalid_language_raises(self):
|
||||
"""Invalid language code should raise ValueError."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_tokenizer(multilingual=True, language="xyz")
|
||||
|
||||
assert "Unsupported language: xyz" in str(exc_info.value)
|
||||
|
||||
def test_get_tokenizer_language_alias(self):
|
||||
"""Language aliases should be resolved (e.g., 'german' -> 'de')."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=True, language="german")
|
||||
|
||||
assert tok.language == "de"
|
||||
|
||||
def test_get_tokenizer_language_case_insensitive(self):
|
||||
"""Language codes should be case-insensitive."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=True, language="DE")
|
||||
|
||||
assert tok.language == "de"
|
||||
|
||||
def test_get_tokenizer_is_cached(self):
|
||||
"""get_tokenizer() should be cached."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok1 = get_tokenizer(multilingual=True, language="fr")
|
||||
tok2 = get_tokenizer(multilingual=True, language="fr")
|
||||
|
||||
assert tok1 is tok2
|
||||
|
||||
def test_get_tokenizer_various_languages(self):
|
||||
"""Test tokenizer with various supported languages."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
# Sample of supported languages
|
||||
languages = ["en", "de", "fr", "es", "ja", "zh", "ru", "ar", "ko", "pt"]
|
||||
|
||||
for lang in languages:
|
||||
tok = get_tokenizer(multilingual=True, language=lang)
|
||||
assert tok.language == lang, f"Language {lang} not set correctly"
|
||||
|
||||
|
||||
class TestTokenizerClass:
|
||||
"""Tests for Tokenizer class methods and properties."""
|
||||
|
||||
@pytest.fixture
|
||||
def multilingual_tokenizer(self):
|
||||
"""Create a multilingual tokenizer for testing."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer(multilingual=True, language="en", task="transcribe")
|
||||
|
||||
@pytest.fixture
|
||||
def german_tokenizer(self):
|
||||
"""Create a German tokenizer for testing."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer(multilingual=True, language="de", task="transcribe")
|
||||
|
||||
def test_special_tokens_populated(self, multilingual_tokenizer):
|
||||
"""Special tokens dict should be populated after init."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert len(tok.special_tokens) > 0
|
||||
assert "<|startoftranscript|>" in tok.special_tokens
|
||||
assert "<|transcribe|>" in tok.special_tokens
|
||||
assert "<|translate|>" in tok.special_tokens
|
||||
assert "<|endoftext|>" in tok.special_tokens
|
||||
|
||||
def test_sot_property(self, multilingual_tokenizer):
|
||||
"""sot property should return start of transcript token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.sot == tok.special_tokens["<|startoftranscript|>"]
|
||||
assert isinstance(tok.sot, int)
|
||||
|
||||
def test_eot_property(self, multilingual_tokenizer):
|
||||
"""eot property should return end of text token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.eot == tok.encoding.eot_token
|
||||
assert isinstance(tok.eot, int)
|
||||
|
||||
def test_transcribe_property(self, multilingual_tokenizer):
|
||||
"""transcribe property should return transcribe token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.transcribe == tok.special_tokens["<|transcribe|>"]
|
||||
assert isinstance(tok.transcribe, int)
|
||||
|
||||
def test_translate_property(self, multilingual_tokenizer):
|
||||
"""translate property should return translate token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.translate == tok.special_tokens["<|translate|>"]
|
||||
assert isinstance(tok.translate, int)
|
||||
|
||||
def test_no_timestamps_property(self, multilingual_tokenizer):
|
||||
"""no_timestamps property should return notimestamps token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.no_timestamps == tok.special_tokens["<|notimestamps|>"]
|
||||
|
||||
def test_timestamp_begin_property(self, multilingual_tokenizer):
|
||||
"""timestamp_begin property should return first timestamp token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.timestamp_begin == tok.special_tokens["<|0.00|>"]
|
||||
|
||||
def test_no_speech_property(self, multilingual_tokenizer):
|
||||
"""no_speech property should return nospeech token."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
assert tok.no_speech == tok.special_tokens["<|nospeech|>"]
|
||||
|
||||
def test_sot_sequence_multilingual(self, multilingual_tokenizer):
|
||||
"""sot_sequence should contain sot + language + task tokens."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
# Should have: sot, language token, task token
|
||||
assert len(tok.sot_sequence) == 3
|
||||
assert tok.sot_sequence[0] == tok.sot
|
||||
# Last token should be transcribe (for transcribe task)
|
||||
assert tok.sot_sequence[2] == tok.transcribe
|
||||
|
||||
def test_sot_sequence_including_notimestamps(self, multilingual_tokenizer):
|
||||
"""sot_sequence_including_notimestamps should append notimestamps."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
seq = tok.sot_sequence_including_notimestamps
|
||||
assert seq[-1] == tok.no_timestamps
|
||||
assert len(seq) == len(tok.sot_sequence) + 1
|
||||
|
||||
def test_encode_decode_roundtrip(self, multilingual_tokenizer):
|
||||
"""encode() and decode() should roundtrip text correctly."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
original = "Hello, this is a test."
|
||||
tokens = tok.encode(original)
|
||||
decoded = tok.decode(tokens)
|
||||
|
||||
assert decoded == original
|
||||
|
||||
def test_decode_filters_timestamp_tokens(self, multilingual_tokenizer):
|
||||
"""decode() should filter out timestamp tokens."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
# Encode some text and add a timestamp token
|
||||
tokens = tok.encode("Hello")
|
||||
# Add a timestamp token (should be filtered)
|
||||
tokens_with_timestamp = tokens + [tok.timestamp_begin]
|
||||
|
||||
# decode() filters tokens >= timestamp_begin
|
||||
decoded = tok.decode(tokens_with_timestamp)
|
||||
assert decoded == "Hello"
|
||||
|
||||
def test_decode_with_timestamps_preserves_all(self, multilingual_tokenizer):
|
||||
"""decode_with_timestamps() should preserve timestamp tokens."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
# Encode text that includes timestamp-like content
|
||||
tokens = tok.encode("Hello")
|
||||
decoded = tok.decode_with_timestamps(tokens)
|
||||
assert decoded == "Hello"
|
||||
|
||||
def test_language_token_property(self, german_tokenizer):
|
||||
"""language_token property should return correct language token."""
|
||||
tok = german_tokenizer
|
||||
|
||||
lang_token = tok.language_token
|
||||
assert lang_token == tok.special_tokens["<|de|>"]
|
||||
|
||||
def test_language_token_raises_when_none(self):
|
||||
"""language_token should raise ValueError when language is None."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
tok = get_tokenizer(multilingual=False) # English-only has no language
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_ = tok.language_token
|
||||
|
||||
assert "language token configured" in str(exc_info.value)
|
||||
|
||||
def test_to_language_token(self, multilingual_tokenizer):
|
||||
"""to_language_token() should return token for given language."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
de_token = tok.to_language_token("de")
|
||||
assert de_token == tok.special_tokens["<|de|>"]
|
||||
|
||||
def test_to_language_token_invalid_raises(self, multilingual_tokenizer):
|
||||
"""to_language_token() should raise KeyError for invalid language."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
tok.to_language_token("xyz")
|
||||
|
||||
assert "Language xyz not found" in str(exc_info.value)
|
||||
|
||||
def test_all_language_tokens(self, multilingual_tokenizer):
|
||||
"""all_language_tokens should return tuple of language token IDs."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
lang_tokens = tok.all_language_tokens
|
||||
|
||||
assert isinstance(lang_tokens, tuple)
|
||||
assert len(lang_tokens) > 0
|
||||
assert all(isinstance(t, int) for t in lang_tokens)
|
||||
|
||||
def test_all_language_codes(self, multilingual_tokenizer):
|
||||
"""all_language_codes should return tuple of language codes."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
lang_codes = tok.all_language_codes
|
||||
|
||||
assert isinstance(lang_codes, tuple)
|
||||
assert len(lang_codes) > 0
|
||||
assert "en" in lang_codes
|
||||
assert "de" in lang_codes
|
||||
|
||||
def test_non_speech_tokens(self, multilingual_tokenizer):
|
||||
"""non_speech_tokens should return tokens to suppress."""
|
||||
tok = multilingual_tokenizer
|
||||
|
||||
non_speech = tok.non_speech_tokens
|
||||
|
||||
assert isinstance(non_speech, tuple)
|
||||
assert len(non_speech) > 0
|
||||
assert all(isinstance(t, int) for t in non_speech)
|
||||
|
||||
|
||||
class TestTokenizerWordSplitting:
|
||||
"""Tests for word splitting methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def english_tokenizer(self):
|
||||
"""English tokenizer for word splitting tests."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer(multilingual=True, language="en")
|
||||
|
||||
@pytest.fixture
|
||||
def chinese_tokenizer(self):
|
||||
"""Chinese tokenizer for unicode splitting tests."""
|
||||
from mlxk2.audio.whisper_tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer(multilingual=True, language="zh")
|
||||
|
||||
def test_split_to_word_tokens_english(self, english_tokenizer):
|
||||
"""split_to_word_tokens should split on spaces for English."""
|
||||
tok = english_tokenizer
|
||||
|
||||
tokens = tok.encode("Hello world")
|
||||
words, word_tokens = tok.split_to_word_tokens(tokens)
|
||||
|
||||
assert len(words) >= 1
|
||||
assert len(word_tokens) == len(words)
|
||||
# Each word should have associated tokens
|
||||
for word, wtokens in zip(words, word_tokens):
|
||||
assert len(wtokens) > 0
|
||||
|
||||
def test_split_to_word_tokens_chinese(self, chinese_tokenizer):
|
||||
"""split_to_word_tokens should use unicode splitting for Chinese."""
|
||||
tok = chinese_tokenizer
|
||||
|
||||
tokens = tok.encode("Hello") # Just test it doesn't crash
|
||||
words, word_tokens = tok.split_to_word_tokens(tokens)
|
||||
|
||||
assert len(words) >= 1
|
||||
assert len(word_tokens) == len(words)
|
||||
|
||||
def test_split_tokens_on_unicode(self, english_tokenizer):
|
||||
"""split_tokens_on_unicode should handle unicode characters."""
|
||||
tok = english_tokenizer
|
||||
|
||||
tokens = tok.encode("Caf\u00e9")
|
||||
words, word_tokens = tok.split_tokens_on_unicode(tokens)
|
||||
|
||||
assert len(words) >= 1
|
||||
# Reconstructed should match
|
||||
reconstructed = "".join(words)
|
||||
assert "Caf" in reconstructed
|
||||
|
||||
def test_split_tokens_on_spaces(self, english_tokenizer):
|
||||
"""split_tokens_on_spaces should split on whitespace."""
|
||||
tok = english_tokenizer
|
||||
|
||||
tokens = tok.encode("Hello world test")
|
||||
words, word_tokens = tok.split_tokens_on_spaces(tokens)
|
||||
|
||||
assert len(words) >= 1
|
||||
assert len(word_tokens) == len(words)
|
||||
|
||||
|
||||
class TestLanguageConstants:
|
||||
"""Tests for LANGUAGES and TO_LANGUAGE_CODE constants."""
|
||||
|
||||
def test_languages_dict_exists(self):
|
||||
"""LANGUAGES dict should be importable."""
|
||||
from mlxk2.audio.whisper_tokenizer import LANGUAGES
|
||||
|
||||
assert isinstance(LANGUAGES, dict)
|
||||
assert len(LANGUAGES) > 90 # Whisper supports ~99 languages
|
||||
|
||||
def test_languages_contains_common(self):
|
||||
"""LANGUAGES should contain common language codes."""
|
||||
from mlxk2.audio.whisper_tokenizer import LANGUAGES
|
||||
|
||||
assert "en" in LANGUAGES
|
||||
assert LANGUAGES["en"] == "english"
|
||||
assert "de" in LANGUAGES
|
||||
assert LANGUAGES["de"] == "german"
|
||||
assert "fr" in LANGUAGES
|
||||
assert LANGUAGES["fr"] == "french"
|
||||
assert "ja" in LANGUAGES
|
||||
assert LANGUAGES["ja"] == "japanese"
|
||||
assert "zh" in LANGUAGES
|
||||
assert LANGUAGES["zh"] == "chinese"
|
||||
|
||||
def test_to_language_code_dict_exists(self):
|
||||
"""TO_LANGUAGE_CODE dict should be importable."""
|
||||
from mlxk2.audio.whisper_tokenizer import TO_LANGUAGE_CODE
|
||||
|
||||
assert isinstance(TO_LANGUAGE_CODE, dict)
|
||||
|
||||
def test_to_language_code_aliases(self):
|
||||
"""TO_LANGUAGE_CODE should contain language name aliases."""
|
||||
from mlxk2.audio.whisper_tokenizer import TO_LANGUAGE_CODE
|
||||
|
||||
assert TO_LANGUAGE_CODE["english"] == "en"
|
||||
assert TO_LANGUAGE_CODE["german"] == "de"
|
||||
assert TO_LANGUAGE_CODE["french"] == "fr"
|
||||
# Check some special aliases
|
||||
assert TO_LANGUAGE_CODE.get("mandarin") == "zh"
|
||||
assert TO_LANGUAGE_CODE.get("castilian") == "es"
|
||||
|
||||
|
||||
class TestAssetsPaths:
|
||||
"""Tests for bundled tiktoken assets."""
|
||||
|
||||
def test_assets_directory_exists(self):
|
||||
"""Assets directory should exist."""
|
||||
from mlxk2.audio.whisper_tokenizer import _ASSETS_DIR
|
||||
|
||||
assert _ASSETS_DIR.exists(), f"Assets dir not found: {_ASSETS_DIR}"
|
||||
assert _ASSETS_DIR.is_dir()
|
||||
|
||||
def test_gpt2_tiktoken_exists(self):
|
||||
"""gpt2.tiktoken asset should exist."""
|
||||
from mlxk2.audio.whisper_tokenizer import _ASSETS_DIR
|
||||
|
||||
gpt2_path = _ASSETS_DIR / "gpt2.tiktoken"
|
||||
assert gpt2_path.exists(), f"gpt2.tiktoken not found: {gpt2_path}"
|
||||
assert gpt2_path.stat().st_size > 100000 # Should be ~800KB
|
||||
|
||||
def test_multilingual_tiktoken_exists(self):
|
||||
"""multilingual.tiktoken asset should exist."""
|
||||
from mlxk2.audio.whisper_tokenizer import _ASSETS_DIR
|
||||
|
||||
multilingual_path = _ASSETS_DIR / "multilingual.tiktoken"
|
||||
assert multilingual_path.exists(), f"multilingual.tiktoken not found: {multilingual_path}"
|
||||
assert multilingual_path.stat().st_size > 100000 # Should be ~800KB
|
||||
|
||||
|
||||
@requires_mlx_audio
|
||||
class TestPatchIntegration:
|
||||
"""Tests for mlx-audio patch integration."""
|
||||
|
||||
def test_patch_applied_to_mlx_audio(self):
|
||||
"""Verify patch is applied when audio_runner is imported."""
|
||||
# Import audio_runner which applies the patch
|
||||
from mlxk2.core.audio_runner import AudioRunner # noqa: F401
|
||||
from mlxk2.audio.whisper_tokenizer import get_encoding
|
||||
|
||||
# Import the patched module
|
||||
import mlx_audio.stt.models.whisper.tokenizer as mlx_tok
|
||||
|
||||
# Our get_encoding should be installed
|
||||
assert mlx_tok.get_encoding is get_encoding
|
||||
|
||||
def test_patched_get_encoding_works(self):
|
||||
"""Verify patched get_encoding produces valid encodings."""
|
||||
# Import to apply patch
|
||||
from mlxk2.core.audio_runner import AudioRunner # noqa: F401
|
||||
|
||||
# Use the patched version
|
||||
import mlx_audio.stt.models.whisper.tokenizer as mlx_tok
|
||||
|
||||
enc = mlx_tok.get_encoding("gpt2")
|
||||
assert enc.name == "gpt2.tiktoken"
|
||||
|
||||
# Verify encode/decode works
|
||||
tokens = enc.encode("Test patch")
|
||||
assert len(tokens) > 0
|
||||
decoded = enc.decode(tokens)
|
||||
assert decoded == "Test patch"
|
||||
Reference in New Issue
Block a user