Files
mlx-knife/tests_2.0/test_capabilities.py
T
The BROKE Cluster Team 7f10187bee fix: Runtime gates + unit tests; benchmark GPU analysis
Core:
- Run preflight passes probe/framework to audio_runtime_compatibility
- STT model_type gate extended (vibevoice, audio)
- MLX 0.30.x compat: catch Exception in whisper_tokenizer
- Embedding-gate unit tests (3 tests)
- Removed get_encoding duplication (-45 LOC)

Benchmark:
- GPU analysis section in reports
2026-02-07 23:38:34 +01:00

641 lines
24 KiB
Python

"""Tests for probe/policy architecture (capabilities.py).
Tests the unified capability detection and backend selection logic.
"""
import json
import pytest
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
from mlxk2.core.capabilities import (
Backend,
PolicyDecision,
MEMORY_THRESHOLD_PERCENT,
VISION_MODEL_TYPES,
AUDIO_MODEL_TYPES,
ModelCapabilities,
BackendPolicy,
probe_model_capabilities,
select_backend_policy,
probe_and_select,
_get_system_memory_bytes,
_get_model_size_bytes,
_format_bytes_gb,
_detect_vision_from_config,
_detect_vision_from_files,
)
from mlxk2.operations.common import detect_audio_capability
class TestVisionModelTypes:
"""Tests for VISION_MODEL_TYPES constant."""
def test_contains_known_types(self):
"""Should contain all known vision model types."""
expected = {"llava", "llava_next", "pixtral", "qwen2_vl", "phi3_v", "mllama", "paligemma", "idefics", "smolvlm"}
assert expected.issubset(VISION_MODEL_TYPES)
def test_is_frozenset(self):
"""Should be immutable."""
assert isinstance(VISION_MODEL_TYPES, frozenset)
class TestAudioModelTypes:
"""Tests for AUDIO_MODEL_TYPES constant (ADR-019)."""
def test_contains_gemma3n(self):
"""Should contain Gemma-3n model types."""
assert "gemma3n" in AUDIO_MODEL_TYPES
assert "gemma3n_audio" in AUDIO_MODEL_TYPES
def test_is_frozenset(self):
"""Should be immutable."""
assert isinstance(AUDIO_MODEL_TYPES, frozenset)
class TestDetectAudioCapability:
"""Tests for detect_audio_capability (ADR-019)."""
def test_detects_audio_config_in_config_json(self, tmp_path):
"""Should detect audio capability via audio_config key."""
config = {"model_type": "gemma3n", "audio_config": {"some": "config"}}
(tmp_path / "config.json").write_text(json.dumps(config))
assert detect_audio_capability(tmp_path, config) is True
def test_detects_gemma3n_model_type(self, tmp_path):
"""Should detect audio capability via model_type."""
config = {"model_type": "gemma3n"}
(tmp_path / "config.json").write_text(json.dumps(config))
assert detect_audio_capability(tmp_path, config) is True
def test_detects_audio_seq_length_in_processor_config(self, tmp_path):
"""Should detect audio capability via processor_config.json."""
config = {"model_type": "llama"} # Not an audio type
(tmp_path / "config.json").write_text(json.dumps(config))
processor_config = {"audio_seq_length": 188, "processor_class": "Gemma3nProcessor"}
(tmp_path / "processor_config.json").write_text(json.dumps(processor_config))
assert detect_audio_capability(tmp_path, config) is True
def test_non_audio_model(self, tmp_path):
"""Should return False for non-audio models."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
assert detect_audio_capability(tmp_path, config) is False
def test_none_config(self, tmp_path):
"""Should return False for None config (checks processor_config.json)."""
assert detect_audio_capability(tmp_path, None) is False
def test_empty_config(self, tmp_path):
"""Should return False for empty config."""
(tmp_path / "config.json").write_text("{}")
assert detect_audio_capability(tmp_path, {}) is False
class TestHelperFunctions:
"""Tests for helper functions."""
def test_format_bytes_gb(self):
"""Should format bytes as GB correctly."""
assert _format_bytes_gb(64 * 1024**3) == "64.0 GB"
assert _format_bytes_gb(1 * 1024**3) == "1.0 GB"
assert _format_bytes_gb(int(24.5 * 1024**3)) == "24.5 GB"
def test_get_system_memory_bytes_returns_integer(self):
"""On macOS, should return positive integer."""
result = _get_system_memory_bytes()
if result is not None:
assert isinstance(result, int)
assert result > 0
def test_get_system_memory_bytes_handles_error(self):
"""Should return None on error."""
with patch("mlxk2.core.capabilities.subprocess.run") as mock_run:
mock_run.side_effect = FileNotFoundError()
result = _get_system_memory_bytes()
assert result is None
def test_get_model_size_bytes(self, tmp_path):
"""Should calculate total model size."""
# Create some files
(tmp_path / "model.safetensors").write_bytes(b"x" * 1000)
(tmp_path / "config.json").write_text("{}")
result = _get_model_size_bytes(tmp_path)
assert result >= 1000 # At least the model file
def test_get_model_size_bytes_empty_dir(self, tmp_path):
"""Should return 0 for empty directory."""
result = _get_model_size_bytes(tmp_path)
assert result == 0
class TestDetectVisionFromConfig:
"""Tests for _detect_vision_from_config."""
def test_detects_mllama_type(self):
"""Should detect mllama as vision model."""
config = {"model_type": "mllama"}
assert _detect_vision_from_config(config) is True
def test_detects_llava_type(self):
"""Should detect llava as vision model."""
config = {"model_type": "llava"}
assert _detect_vision_from_config(config) is True
def test_detects_pixtral_type(self):
"""Should detect pixtral as vision model."""
config = {"model_type": "pixtral"}
assert _detect_vision_from_config(config) is True
def test_detects_image_processor(self):
"""Should detect vision via image_processor field."""
config = {"model_type": "llama", "image_processor": "CLIPImageProcessor"}
assert _detect_vision_from_config(config) is True
def test_detects_preprocessor_config(self):
"""Should detect vision via preprocessor_config field."""
config = {"model_type": "llama", "preprocessor_config": {"image_size": 224}}
assert _detect_vision_from_config(config) is True
def test_non_vision_model(self):
"""Should not detect plain llama as vision."""
config = {"model_type": "llama"}
assert _detect_vision_from_config(config) is False
def test_none_config(self):
"""Should return False for None config."""
assert _detect_vision_from_config(None) is False
def test_empty_config(self):
"""Should return False for empty config."""
assert _detect_vision_from_config({}) is False
class TestDetectVisionFromFiles:
"""Tests for _detect_vision_from_files."""
def test_detects_preprocessor_config_json(self, tmp_path):
"""Should detect vision via preprocessor_config.json presence."""
(tmp_path / "preprocessor_config.json").write_text("{}")
assert _detect_vision_from_files(tmp_path) is True
def test_detects_processor_config_json(self, tmp_path):
"""Should detect vision via processor_config.json presence."""
(tmp_path / "processor_config.json").write_text("{}")
assert _detect_vision_from_files(tmp_path) is True
def test_detects_image_processor_config_json(self, tmp_path):
"""Should detect vision via image_processor_config.json presence."""
(tmp_path / "image_processor_config.json").write_text("{}")
assert _detect_vision_from_files(tmp_path) is True
def test_no_vision_files(self, tmp_path):
"""Should return False when no vision files present."""
(tmp_path / "config.json").write_text("{}")
assert _detect_vision_from_files(tmp_path) is False
class TestProbeModelCapabilities:
"""Tests for probe_model_capabilities."""
def test_probes_text_model(self, tmp_path):
"""Should correctly probe a text-only model."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
(tmp_path / "model.safetensors").write_bytes(b"x" * 1000)
caps = probe_model_capabilities(tmp_path, "test/model")
assert caps.model_path == tmp_path
assert caps.model_name == "test/model"
assert caps.is_vision is False
assert caps.config_valid is True
assert caps.model_type == "llama"
assert "text-generation" in caps.capabilities_list
assert "vision" not in caps.capabilities_list
def test_probes_vision_model_by_config(self, tmp_path):
"""Should detect vision model from config.json."""
config = {"model_type": "mllama"}
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/vision-model")
assert caps.is_vision is True
assert "vision" in caps.capabilities_list
def test_probes_vision_model_by_files(self, tmp_path):
"""Should detect vision model from file presence."""
config = {"model_type": "llama"} # Not a vision type
(tmp_path / "config.json").write_text(json.dumps(config))
(tmp_path / "preprocessor_config.json").write_text("{}")
caps = probe_model_capabilities(tmp_path, "test/vision-model")
assert caps.is_vision is True
assert "vision" in caps.capabilities_list
def test_probes_chat_model_by_type(self, tmp_path):
"""Should detect chat capability from model_type."""
config = {"model_type": "chat"}
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/chat-model")
assert caps.is_chat is True
assert "chat" in caps.capabilities_list
def test_probes_chat_model_by_template(self, tmp_path):
"""Should detect chat capability from chat_template."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
tokenizer = {"chat_template": "{% for message in messages %}...{% endfor %}"}
(tmp_path / "tokenizer_config.json").write_text(json.dumps(tokenizer))
caps = probe_model_capabilities(tmp_path, "test/chat-model")
assert caps.is_chat is True
assert "chat" in caps.capabilities_list
def test_probes_chat_model_by_name(self, tmp_path):
"""Should detect chat capability from model name."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/Llama-3-Instruct")
assert caps.is_chat is True
assert "chat" in caps.capabilities_list
def test_probes_embedding_model(self, tmp_path):
"""Should detect embedding model from name."""
config = {"model_type": "bert"}
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/text-embedding-model")
assert caps.is_embedding is True
assert "embeddings" in caps.capabilities_list
assert "text-generation" not in caps.capabilities_list
def test_probes_memory_information(self, tmp_path):
"""Should probe memory information."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
(tmp_path / "model.safetensors").write_bytes(b"x" * 10000)
with patch("mlxk2.core.capabilities._get_system_memory_bytes") as mock_mem:
mock_mem.return_value = 64 * 1024**3 # 64GB
caps = probe_model_capabilities(tmp_path, "test/model")
assert caps.system_memory_bytes == 64 * 1024**3
assert caps.model_size_bytes >= 10000
assert caps.memory_ratio > 0
def test_probes_python_version(self, tmp_path):
"""Should record Python version."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/model")
assert caps.python_version == sys.version_info[:3]
def test_handles_missing_config(self, tmp_path):
"""Should handle missing config.json gracefully."""
caps = probe_model_capabilities(tmp_path, "test/model")
assert caps.config_valid is False
assert caps.config is None
def test_handles_invalid_config(self, tmp_path):
"""Should handle invalid JSON in config.json."""
(tmp_path / "config.json").write_text("not valid json")
caps = probe_model_capabilities(tmp_path, "test/model")
assert caps.config_valid is False
def test_accepts_preloaded_config(self, tmp_path):
"""Should use provided config instead of loading from file."""
config = {"model_type": "pixtral"}
caps = probe_model_capabilities(tmp_path, "test/model", config=config)
assert caps.config_valid is True
assert caps.model_type == "pixtral"
assert caps.is_vision is True
class TestSelectBackendPolicy:
"""Tests for select_backend_policy."""
def test_text_model_cli_allows(self, tmp_path):
"""Text model in CLI context should be allowed with MLX_LM backend."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/model",
is_vision=False,
mlx_lm_available=True,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.MLX_LM
assert policy.decision == PolicyDecision.ALLOW
assert policy.message is None
def test_vision_model_cli_allows(self, tmp_path):
"""Vision model in CLI context should be allowed with MLX_VLM backend."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/vision-model",
is_vision=True,
python_version=(3, 10, 0),
mlx_vlm_available=True,
memory_ratio=0.5, # Under threshold
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.MLX_VLM
assert policy.decision == PolicyDecision.ALLOW
def test_vision_model_server_allows(self, tmp_path):
"""Vision model in server context should be allowed (ADR-012 Phase 3)."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/vision-model",
is_vision=True,
python_version=(3, 10, 0),
mlx_vlm_available=True,
)
policy = select_backend_policy(caps, context="server")
# Vision is now supported in server (2.0.4-beta.1)
assert policy.backend == Backend.MLX_VLM
assert policy.decision == PolicyDecision.ALLOW
def test_vision_model_python_39_blocks(self, tmp_path):
"""Vision model with Python 3.9 should be blocked."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/vision-model",
is_vision=True,
python_version=(3, 9, 0),
mlx_vlm_available=True,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.UNSUPPORTED
assert policy.decision == PolicyDecision.BLOCK
assert policy.http_status == 501
assert "3.10" in policy.message
assert policy.error_type == "python_version_error"
def test_vision_model_no_mlx_vlm_blocks(self, tmp_path):
"""Vision model without mlx-vlm should be blocked."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/vision-model",
is_vision=True,
python_version=(3, 10, 0),
mlx_vlm_available=False,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.UNSUPPORTED
assert policy.decision == PolicyDecision.BLOCK
assert "mlx-vlm" in policy.message
assert policy.error_type == "missing_dependency"
def test_vision_model_memory_over_threshold_blocks(self, tmp_path):
"""Vision model over 70% memory should be blocked."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/vision-model",
is_vision=True,
python_version=(3, 10, 0),
mlx_vlm_available=True,
system_memory_bytes=64 * 1024**3,
model_size_bytes=50 * 1024**3, # 78% of 64GB
memory_ratio=0.78,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.UNSUPPORTED
assert policy.decision == PolicyDecision.BLOCK
assert policy.http_status == 507
assert "70%" in policy.message
assert "Metal OOM" in policy.message
assert policy.error_type == "insufficient_memory"
def test_text_model_memory_over_threshold_warns(self, tmp_path):
"""Text model over 70% memory should warn but allow."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/text-model",
is_vision=False,
mlx_lm_available=True,
system_memory_bytes=64 * 1024**3,
model_size_bytes=50 * 1024**3,
memory_ratio=0.78,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.MLX_LM
assert policy.decision == PolicyDecision.WARN
assert "70%" in policy.message
assert "swapping" in policy.message
def test_images_on_non_vision_model_blocks(self, tmp_path):
"""Passing images to non-vision model should be blocked."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/text-model",
is_vision=False,
python_version=(3, 10, 0),
mlx_vlm_available=True,
)
policy = select_backend_policy(caps, context="cli", has_images=True)
assert policy.backend == Backend.UNSUPPORTED
assert policy.decision == PolicyDecision.BLOCK
assert policy.http_status == 400
assert "vision" in policy.message.lower()
assert policy.error_type == "capability_mismatch"
def test_text_model_no_mlx_lm_blocks(self, tmp_path):
"""Text model without mlx-lm should be blocked."""
caps = ModelCapabilities(
model_path=tmp_path,
model_name="test/text-model",
is_vision=False,
mlx_lm_available=False,
)
policy = select_backend_policy(caps, context="cli")
assert policy.backend == Backend.UNSUPPORTED
assert policy.decision == PolicyDecision.BLOCK
assert "mlx-lm" in policy.message
assert policy.error_type == "missing_dependency"
class TestProbeAndSelect:
"""Tests for probe_and_select convenience function."""
def test_combines_probe_and_select(self, tmp_path):
"""Should run both probe and select in one call."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
# Create valid MLX weight file for runtime compatibility
(tmp_path / "model.safetensors").write_bytes(b"weights")
with patch("mlxk2.core.capabilities._check_mlx_lm_available") as mock_lm, \
patch("mlxk2.core.capabilities._check_text_runtime_compatibility") as mock_runtime:
mock_lm.return_value = True
mock_runtime.return_value = (True, None) # Compatible
caps, policy = probe_and_select(tmp_path, "mlx-community/test-model", context="cli")
assert caps.model_type == "llama"
assert caps.is_vision is False
assert policy.backend == Backend.MLX_LM
assert policy.decision == PolicyDecision.ALLOW
def test_context_parameter_forwarded_to_policy(self, tmp_path):
"""Context parameter should be forwarded to select_backend_policy (text model)."""
config = {"model_type": "llama"}
(tmp_path / "config.json").write_text(json.dumps(config))
with patch("mlxk2.core.capabilities._check_mlx_lm_available") as mock_lm, \
patch("mlxk2.core.capabilities._check_text_runtime_compatibility") as mock_runtime:
mock_lm.return_value = True
mock_runtime.return_value = (True, None)
_, policy = probe_and_select(tmp_path, "test/model", context="server")
# Text model should work in server context
assert policy.decision == PolicyDecision.ALLOW
assert policy.backend == Backend.MLX_LM
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Vision requires Python 3.10+")
def test_vision_server_context_allowed_on_python_310_plus(self, tmp_path):
"""Vision models should be allowed in server context on Python 3.10+ (ADR-012 Phase 3)."""
config = {"model_type": "mllama"}
(tmp_path / "config.json").write_text(json.dumps(config))
with patch("mlxk2.core.capabilities._check_mlx_vlm_available") as mock_vlm:
mock_vlm.return_value = True
_, policy = probe_and_select(tmp_path, "test/vision", context="server")
# Server allows vision on Python 3.10+ (ADR-012 Phase 3)
assert policy.decision == PolicyDecision.ALLOW
assert policy.backend == Backend.MLX_VLM
def test_respects_has_images_parameter(self, tmp_path):
"""Should pass has_images through to select_backend_policy."""
config = {"model_type": "llama"} # Not vision
(tmp_path / "config.json").write_text(json.dumps(config))
with patch("mlxk2.core.capabilities._check_mlx_vlm_available") as mock_vlm:
mock_vlm.return_value = True
_, policy = probe_and_select(tmp_path, "test/model", has_images=True)
# Images on non-vision should block
assert policy.decision == PolicyDecision.BLOCK
assert policy.error_type == "capability_mismatch"
class TestMemoryThreshold:
"""Tests for MEMORY_THRESHOLD_PERCENT constant."""
def test_threshold_is_70_percent(self):
"""Threshold should be 70% as per ADR-016."""
assert MEMORY_THRESHOLD_PERCENT == 0.70
def test_threshold_calculation(self):
"""Verify threshold calculation for 64GB system."""
system_memory = 64 * 1024**3
threshold = int(system_memory * MEMORY_THRESHOLD_PERCENT)
expected = int(44.8 * 1024**3) # 70% of 64GB
# Allow small rounding difference
assert abs(threshold - expected) < 1024**2 # Within 1MB
class TestEmbeddingGate:
"""Tests for embedding model runtime compatibility gate (common.py:636-639).
Embedding models should be detected but blocked from `mlxk run` with a
helpful message pointing to `mlxk embed`.
"""
def test_detect_capabilities_embedding_model_type(self, tmp_path):
"""model_type='embedding' should return only EMBEDDINGS capability."""
from mlxk2.operations.common import detect_capabilities
caps = detect_capabilities(
model_type="embedding",
hf_name="test/embedding-model",
tok_hints={},
config={},
probe=tmp_path,
)
assert caps == ["embeddings"]
assert "text-generation" not in caps
def test_embedding_model_runtime_incompatible(self, tmp_path):
"""Embedding models should have runtime_compatible=False with helpful reason."""
from mlxk2.operations.common import build_model_object
# Create minimal embedding model structure with workspace sentinel
config = {"model_type": "embedding"}
(tmp_path / "config.json").write_text(json.dumps(config))
(tmp_path / "model.safetensors").write_bytes(b"x" * 100)
# Add workspace sentinel so it's treated as workspace path (ADR-018)
sentinel = {"managed_by": "mlxk", "source": "test"}
(tmp_path / ".mlxk-workspace.json").write_text(json.dumps(sentinel))
# Add README with MLX library tag so framework is detected as MLX
readme = """---
library_name: mlx
---
# Test Embedding Model
"""
(tmp_path / "README.md").write_text(readme)
# Use absolute path so it's treated as workspace path
result = build_model_object(
hf_name=str(tmp_path), # Absolute path = workspace
model_root=tmp_path,
selected_path=tmp_path,
)
assert result["runtime_compatible"] is False
assert "mlxk embed" in result["reason"]
assert "embeddings" in result["capabilities"]
def test_embedding_model_detected_by_name_heuristic(self, tmp_path):
"""Models with 'embedding' in name should be detected as embedding models."""
config = {"model_type": "bert"} # Not explicitly "embedding" type
(tmp_path / "config.json").write_text(json.dumps(config))
caps = probe_model_capabilities(tmp_path, "test/text-embedding-3-small")
assert caps.is_embedding is True
assert "embeddings" in caps.capabilities_list