mirror of
https://github.com/cloudstack-llc/mlx-knife.git
synced 2026-07-01 20:44:14 -04:00
Release MLX Knife 1.1.0-beta1 - Dynamic Token Limits & Enhanced Web Client
Issues Resolved:
• Issue #15: Token limits vs natural stop tokens race condition - FIXED
• Issue #16: Interactive vs server token limit policies - FIXED

Major Improvements:
• Automatic optimal token limits - no configuration needed
• Manual --max-tokens control still available when desired
• Eliminates old hardcoded 500/2000 token restrictions
• Performance gains: Up to 524x improvement for large context models
• Enhanced web client with model capabilities display and better UX

Additional Enhancements:
• Enhanced /v1/models API with context_length field
• Comprehensive test expansion: 114 → 131 tests (131/131 passing)
• Python 3.9-3.13 compatibility verified

Known Issues (Beta Status):
• Server deadlock possible under extreme concurrent model loading stress
• Workaround: Avoid simultaneous heavy model operations
@@ -1,5 +1,68 @@
# Changelog

## [1.1.0-beta1] - 2025-08-21

### Major Features 🚀
- **Issues #15 & #16**: Dynamic Model-Aware Token Limits
  - Replaced the hardcoded 500/2000 token defaults with limits derived from each model's context length
  - **Phi-3-mini**: 4096 context → 2048 server tokens, 4096 interactive (about 8x the old 500-token default)
  - **Qwen3-30B**: 262,144 context → 131,072 server tokens, 262,144 interactive (about 524x the old 500-token default)
  - Context-aware policies: Interactive mode uses the full context, server mode uses context/2 for DoS protection
  - Automatic adaptation to new models with larger context windows (future-proof)
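
The policy in the bullets above can be sketched as a small standalone function. This is a minimal illustration rather than the project's implementation: `effective_max_tokens` and its parameter names are assumptions made for this example, and it assumes the context length is already known.

```python
from typing import Optional


def effective_max_tokens(
    context_length: int,
    requested: Optional[int] = None,
    interactive: bool = False,
) -> int:
    """Hypothetical helper mirroring the policy described in the changelog."""
    # Interactive mode may use the full context; server mode is capped at half.
    limit = context_length if interactive else context_length // 2
    # None means "no explicit request": fall back to the limit for this mode.
    if requested is None:
        return limit
    # An explicit request is respected but never exceeds the limit.
    return min(requested, limit)


print(effective_max_tokens(4096))                       # 4096-token model, server mode
print(effective_max_tokens(262_144, interactive=True))  # 262,144-token model, interactive
print(effective_max_tokens(4096, requested=500))        # explicit request below the cap
```

Keeping `None` distinct from an explicit number is what lets the caller tell "use the model's limit" apart from "the user asked for exactly this many tokens".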

### Enhanced Web Client 🌐
- **Model Token Capacity Display**: Shows "Ready with Mistral-7B (32,768 tokens)" in header
- **Enhanced `/v1/models` API**: Now exposes `context_length` field for model capabilities
- **Button State Management**: Clear Chat properly disabled during streaming with CSS styling
- **Streaming Status Tracking**: Added `isStreaming` flag with "Generating response..." feedback
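
As a rough illustration of how a client might turn the `context_length` field into the header text quoted above, here is a Python sketch. The web client itself is JavaScript; `ready_label` is a hypothetical helper, and the sample payload is made up to match the `{"object": "list", "data": [...]}` shape returned by `/v1/models`.

```python
import json


def ready_label(models_payload: dict, model_id: str) -> str:
    """Build a header label from a /v1/models-style payload (hypothetical helper)."""
    for model in models_payload.get("data", []):
        if model.get("id") == model_id:
            context_length = model.get("context_length")
            # context_length is optional, so fall back to a plain label.
            if context_length is None:
                return f"Ready with {model_id}"
            return f"Ready with {model_id} ({context_length:,} tokens)"
    return f"Model {model_id} not found"


# Sample payload invented for this example.
sample = json.loads(
    '{"object": "list", "data": ['
    '{"id": "Mistral-7B", "object": "model", '
    '"owned_by": "mlx-knife", "context_length": 32768}]}'
)
print(ready_label(sample, "Mistral-7B"))
```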

### Interactive Mode Improvements 💡
- **Smart CLI Defaults**: `mlxk run <model> "prompt"` automatically uses optimal token limits per model
- **No Configuration Needed**: Users benefit immediately without changing usage patterns
- **Explicit Control Preserved**: `--max-tokens` arguments still respected and capped at model context
- **Clean Type Safety**: Proper `Optional[int]` handling eliminates fragile CLI guessing
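
The `Optional[int]` handling mentioned above rests on giving `--max-tokens` a default of `None` instead of a number. The following is a small sketch using only the standard `argparse` module; the parser name `run-sketch` and the `describe` helper are assumptions for illustration, not part of the `mlxk` CLI.

```python
import argparse

# Stand-in parser for illustration; only the --max-tokens option is modeled.
parser = argparse.ArgumentParser(prog="run-sketch")
parser.add_argument("--max-tokens", type=int, default=None)


def describe(argv: list) -> str:
    """Report whether the user supplied --max-tokens (hypothetical helper)."""
    args = parser.parse_args(argv)
    # Compare against None rather than truthiness, so an explicit 0 is
    # still recognised as a user-supplied value.
    if args.max_tokens is None:
        return "auto"
    return f"explicit:{args.max_tokens}"


print(describe([]))
print(describe(["--max-tokens", "256"]))
```

With a numeric default such as 500, the program could not distinguish "the user typed 500" from "the user typed nothing", which is the guessing the `None` default avoids.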

### Technical Architecture 🏗️
- **`get_model_context_length()` function**: Extracts context length from model configs with multiple fallback keys
- **Enhanced MLXRunner**: `get_effective_max_tokens()` method for context-aware token limiting
- **Server API Updates**: All endpoints use model-aware limits with DoS protection
- **Unified Token Logic**: Single source of truth through MLXRunner eliminates duplicate code
- **Backward Compatible**: All existing CLI arguments and APIs work unchanged
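
To make the "multiple fallback keys" lookup concrete, here is a self-contained sketch that reads a `config.json` and tries several keys in order. The key names and the 4096 default follow this commit's code; `read_context_length` itself is a hypothetical stand-in, not the project's function, and the config written to the temporary directory is invented for the demo.

```python
import json
import os
import tempfile

# Keys tried in order; the first one present in the config wins.
CONTEXT_KEYS = (
    "max_position_embeddings",
    "n_positions",
    "context_length",
    "max_sequence_length",
    "seq_len",
)
DEFAULT_CONTEXT = 4096


def read_context_length(model_dir: str) -> int:
    """Return the context length from config.json, or a default (sketch)."""
    config_path = os.path.join(model_dir, "config.json")
    try:
        with open(config_path) as f:
            config = json.load(f)
    except (OSError, json.JSONDecodeError):
        # Missing, unreadable, or malformed config: use the default.
        return DEFAULT_CONTEXT
    if not isinstance(config, dict):
        return DEFAULT_CONTEXT
    for key in CONTEXT_KEYS:
        if key in config:
            return int(config[key])
    return DEFAULT_CONTEXT


with tempfile.TemporaryDirectory() as model_dir:
    # Demo config that only provides a secondary key.
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"n_positions": 2048}, f)
    print(read_context_length(model_dir))
```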

### Performance Impact 📊
- **Modern Models Unleashed**: Large-context models can now use their full capabilities
- **Real-World Benefits**: No more artificial 500-token truncation for 100K+ context models
- **Smart Server Limits**: Automatic DoS protection while maximizing usable context
- **Zero Magic Numbers**: Clean architecture with clear `None` vs explicit value semantics

### Testing & Quality Assurance ✅
- **Comprehensive Coverage**: 131/131 tests passing (expansion from 114 tests)
- **20 new unit tests**: Covering CLI None-handling, model context extraction, effective token calculation
- **5 server integration tests**: Real-world validation with actual MLX models
- **Extreme Model Testing**: Validated with models from 1B to 30B parameters, up to 256K context
- **Edge Case Handling**: Unknown models, missing configs, CLI argument combinations

### Issue #14 Model Compatibility Validation
**Chat Self-Conversation Fix tested across model spectrum:**

| Model | Size | RAM (GB) | Context | Status | Architecture |
|-------|------|----------|---------|--------|--------------|
| **Llama-3.2-1B-Instruct-4bit** | 1B | 2 | 131,072 | ✅ PASSED | Llama |
| **Llama-3.2-3B-Instruct-4bit** | 3B | 4 | 131,072 | ✅ PASSED | Llama |
| **Phi-3-mini-4k-instruct-4bit** | 4B | 5 | 4,096 | ✅ PASSED | Phi-3 |
| **Mistral-7B-Instruct-v0.2-4bit** | 7B | 8 | 32,768 | ✅ PASSED | Mistral |
| **Mixtral-8x7B-Instruct-v0.1-4bit** | 8x7B | 16 | 32,768 | ✅ PASSED | Mixtral MoE |
| **Mistral-Small-3.2-24B-Instruct-2506-4bit** | 24B | 20 | 32,768 | ✅ PASSED | Mistral |
| **Qwen3-30B-A3B-Instruct-2507-4bit** | 30B | 24 | 262,144 | ✅ PASSED | Qwen |

**Validation Results**: 7/7 models passed. Coverage spans 1B to 30B parameters across the Llama, Phi-3, Mistral, Mixtral MoE, and Qwen architectures, confirming chat stop token handling on each.

### Beta Status Notes ⚠️
- **Core Functionality**: Solid foundation with comprehensive test coverage
- **Known Limitation**: Server deadlock possible under extreme concurrent model loading stress
- **Workaround**: Avoid simultaneous heavy model operations (normal usage unaffected)
- **Real-World Ready**: Significant improvements ready for community testing and feedback

## [1.0.4] - 2025-08-19

### Fixed
@@ -8,7 +8,7 @@ A lightweight, ollama-like CLI for managing and running MLX models on Apple Sili
|
|||||||
|
|
||||||
> **Note**: MLX Knife is designed as a command-line interface tool only. While some internal functions are accessible via Python imports, only CLI usage is officially supported.
|
> **Note**: MLX Knife is designed as a command-line interface tool only. While some internal functions are accessible via Python imports, only CLI usage is officially supported.
|
||||||
|
|
||||||
**Current Version**: 1.0.4 (August 2025)
|
**Current Version**: 1.1.0-beta1 (August 2025) - Dynamic Token Limits & Web UI Enhancements
|
||||||
|
|
||||||
[](https://github.com/mzau/mlx-knife/releases)
|
[](https://github.com/mzau/mlx-knife/releases)
|
||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
@@ -16,7 +16,7 @@ A lightweight, ollama-like CLI for managing and running MLX models on Apple Sili
|
|||||||
[](https://www.python.org/downloads/)
|
[](https://www.python.org/downloads/)
|
||||||
[](https://support.apple.com/en-us/HT211814)
|
[](https://support.apple.com/en-us/HT211814)
|
||||||
[](https://github.com/ml-explore/mlx)
|
[](https://github.com/ml-explore/mlx)
|
||||||
[](#testing)
|
[](#testing)
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -325,6 +325,6 @@ Copyright (c) 2025 The BROKE team 🦫
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<b>Made with ❤️ by The BROKE team <img src="broke-logo.png" alt="BROKE Logo" width="30" style="vertical-align: middle;"></b><br>
|
<b>Made with ❤️ by The BROKE team <img src="broke-logo.png" alt="BROKE Logo" width="30" style="vertical-align: middle;"></b><br>
|
||||||
<i>Version 1.0.4 | August 2025</i><br>
|
<i>Version 1.1.0-beta1 | August 2025</i><br>
|
||||||
<a href="https://github.com/mzau/broke-cluster">🔮 Next: BROKE Cluster for multi-node deployments</a>
|
<a href="https://github.com/mzau/broke-cluster">🔮 Next: BROKE Cluster for multi-node deployments</a>
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
+11
-9
@@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
## Current Status
|
## Current Status
|
||||||
|
|
||||||
✅ **114/114 tests passing** (August 2025)
|
✅ **131/131 tests passing** (August 2025)
|
||||||
✅ **Apple Silicon verified** (M1/M2/M3)
|
✅ **Apple Silicon verified** (M1/M2/M3)
|
||||||
✅ **Python 3.9-3.13 compatible**
|
✅ **Python 3.9-3.13 compatible**
|
||||||
✅ **Production ready** - real model execution validated
|
✅ **Beta ready** - comprehensive testing with real model execution
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -42,13 +42,15 @@ This approach ensures our tests reflect real-world usage, not mocked behavior.
|
|||||||
```
|
```
|
||||||
tests/
|
tests/
|
||||||
├── conftest.py # Shared fixtures and utilities
|
├── conftest.py # Shared fixtures and utilities
|
||||||
├── integration/ # System-level integration tests (62 tests)
|
├── integration/ # System-level integration tests (85+ tests)
|
||||||
│ ├── test_core_functionality.py # Basic CLI operations
|
│ ├── test_core_functionality.py # Basic CLI operations
|
||||||
│ ├── test_health_checks.py # Model corruption detection
|
│ ├── test_health_checks.py # Model corruption detection
|
||||||
|
│ ├── test_issue_14.py # Issue #14: Chat self-conversation fix
|
||||||
|
│ ├── test_issue_15_16.py # Issues #15/#16: Dynamic token limits
|
||||||
│ ├── test_process_lifecycle.py # Process management & cleanup
|
│ ├── test_process_lifecycle.py # Process management & cleanup
|
||||||
│ ├── test_run_command_advanced.py # Run command edge cases
|
│ ├── test_run_command_advanced.py # Run command edge cases
|
||||||
│ └── test_server_functionality.py # OpenAI API server tests
|
│ └── test_server_functionality.py # OpenAI API server tests
|
||||||
└── unit/ # Module-level unit tests (52 tests)
|
└── unit/ # Module-level unit tests (45+ tests)
|
||||||
├── test_cache_utils.py # Cache management functions
|
├── test_cache_utils.py # Cache management functions
|
||||||
├── test_cli.py # CLI argument parsing
|
├── test_cli.py # CLI argument parsing
|
||||||
└── test_mlx_runner_memory.py # Memory management tests
|
└── test_mlx_runner_memory.py # Memory management tests
|
||||||
@@ -158,11 +160,11 @@ pytest -n auto
|
|||||||
|
|
||||||
| Python Version | Status | Tests Passing |
|
| Python Version | Status | Tests Passing |
|
||||||
|----------------|--------|---------------|
|
|----------------|--------|---------------|
|
||||||
| 3.9.6 (macOS) | ✅ Verified | 114/114 |
|
| 3.9.6 (macOS) | ✅ Verified | 131/131 |
|
||||||
| 3.10.x | ✅ Verified | 114/114 |
|
| 3.10.x | ✅ Verified | 131/131 |
|
||||||
| 3.11.x | ✅ Verified | 114/114 |
|
| 3.11.x | ✅ Verified | 131/131 |
|
||||||
| 3.12.x | ✅ Verified | 114/114 |
|
| 3.12.x | ✅ Verified | 131/131 |
|
||||||
| 3.13.x | ✅ Verified | 114/114 |
|
| 3.13.x | ✅ Verified | 131/131 |
|
||||||
|
|
||||||
All versions tested with real MLX model execution (Phi-3-mini-4k-instruct-4bit).
|
All versions tested with real MLX model execution (Phi-3-mini-4k-instruct-4bit).
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ A lightweight, ollama-like CLI for managing and running MLX models on Apple Sili
|
|||||||
Provides native MLX execution with streaming output and interactive chat capabilities.
|
Provides native MLX execution with streaming output and interactive chat capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "1.0.4"
|
__version__ = "1.1.0-beta1"
|
||||||
__author__ = "The BROKE team"
|
__author__ = "The BROKE team"
|
||||||
__email__ = "broke@gmx.eu"
|
__email__ = "broke@gmx.eu"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
@@ -12,7 +12,7 @@ __description__ = "ollama-style CLI for MLX models on Apple Silicon"
|
|||||||
__url__ = "https://github.com/mzau/mlx-knife"
|
__url__ = "https://github.com/mzau/mlx-knife"
|
||||||
|
|
||||||
# Version tuple for programmatic access (major, minor, patch)
|
# Version tuple for programmatic access (major, minor, patch)
|
||||||
VERSION = (1, 0, 4)
|
VERSION = (1, 1, 0)
|
||||||
|
|
||||||
# Core functionality imports
|
# Core functionality imports
|
||||||
from .cache_utils import (
|
from .cache_utils import (
|
||||||
|
|||||||
+1
-1
@@ -41,7 +41,7 @@ def main():
|
|||||||
run_p.add_argument("prompt", nargs="?", default=None, help="Prompt text (if not provided, enters interactive mode)")
|
run_p.add_argument("prompt", nargs="?", default=None, help="Prompt text (if not provided, enters interactive mode)")
|
||||||
run_p.add_argument("--interactive", "-i", action="store_true", help="Force interactive dialog mode")
|
run_p.add_argument("--interactive", "-i", action="store_true", help="Force interactive dialog mode")
|
||||||
run_p.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (default: 0.7)")
|
run_p.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (default: 0.7)")
|
||||||
run_p.add_argument("--max-tokens", type=int, default=500, help="Maximum tokens to generate (default: 500)")
|
run_p.add_argument("--max-tokens", type=int, default=None, help="Maximum tokens to generate (default: model context length)")
|
||||||
run_p.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling parameter (default: 0.9)")
|
run_p.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling parameter (default: 0.9)")
|
||||||
run_p.add_argument("--repetition-penalty", type=float, default=1.1, help="Penalty for repeated tokens (default: 1.1)")
|
run_p.add_argument("--repetition-penalty", type=float, default=1.1, help="Penalty for repeated tokens (default: 1.1)")
|
||||||
run_p.add_argument("--no-stream", action="store_true", help="Disable streaming output")
|
run_p.add_argument("--no-stream", action="store_true", help="Disable streaming output")
|
||||||
|
|||||||
+93
-2
@@ -4,6 +4,8 @@ Enhanced MLX model runner with direct API integration.
|
|||||||
Provides ollama-like run experience with streaming and interactive chat.
|
Provides ollama-like run experience with streaming and interactive chat.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -15,6 +17,42 @@ from mlx_lm.generate import generate_step
|
|||||||
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler
|
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_context_length(model_path: str) -> int:
    """Extract the context length from the model's config.json.

    Args:
        model_path: Path to the MLX model directory

    Returns:
        Maximum context length for the model (defaults to 4096 if not found)
    """
    config_path = os.path.join(model_path, "config.json")

    try:
        with open(config_path) as f:
            config = json.load(f)
    except (OSError, json.JSONDecodeError):
        # Return default if config is missing, unreadable, or malformed
        return 4096

    if not isinstance(config, dict):
        # A valid JSON file that is not an object carries no usable keys
        return 4096

    # Try various common config keys for context length
    context_keys = [
        "max_position_embeddings",
        "n_positions",
        "context_length",
        "max_sequence_length",
        "seq_len",
    ]

    for key in context_keys:
        value = config.get(key)
        # Skip null or non-integer values so callers always get an int
        if isinstance(value, int) and value > 0:
            return value

    # If no context length found, return reasonable default
    return 4096
|
||||||
|
|
||||||
|
|
||||||
class MLXRunner:
|
class MLXRunner:
|
||||||
"""Direct MLX model runner with streaming and interactive capabilities."""
|
"""Direct MLX model runner with streaming and interactive capabilities."""
|
||||||
|
|
||||||
@@ -33,6 +71,7 @@ class MLXRunner:
|
|||||||
self._memory_baseline = None
|
self._memory_baseline = None
|
||||||
self._stop_tokens = None # Will be populated from tokenizer
|
self._stop_tokens = None # Will be populated from tokenizer
|
||||||
self._chat_stop_tokens = None # Chat-specific stop tokens
|
self._chat_stop_tokens = None # Chat-specific stop tokens
|
||||||
|
self._context_length = None # Will be populated from model config
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self._model_loaded = False
|
self._model_loaded = False
|
||||||
self._context_entered = False # Prevent nested context usage
|
self._context_entered = False # Prevent nested context usage
|
||||||
@@ -93,6 +132,13 @@ class MLXRunner:
|
|||||||
|
|
||||||
# Extract stop tokens from tokenizer
|
# Extract stop tokens from tokenizer
|
||||||
self._extract_stop_tokens()
|
self._extract_stop_tokens()
|
||||||
|
|
||||||
|
# Extract context length from model config
|
||||||
|
self._context_length = get_model_context_length(str(self.model_path))
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"Model context length: {self._context_length} tokens")
|
||||||
|
|
||||||
self._model_loaded = True
|
self._model_loaded = True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -181,6 +227,7 @@ class MLXRunner:
|
|||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self._stop_tokens = None
|
self._stop_tokens = None
|
||||||
self._chat_stop_tokens = None
|
self._chat_stop_tokens = None
|
||||||
|
self._context_length = None
|
||||||
self._model_loaded = False
|
self._model_loaded = False
|
||||||
|
|
||||||
# Force garbage collection and clear MLX cache
|
# Force garbage collection and clear MLX cache
|
||||||
@@ -199,6 +246,38 @@ class MLXRunner:
|
|||||||
else:
|
else:
|
||||||
print(f"Cleanup complete (memory after: {memory_after:.1f}GB)")
|
print(f"Cleanup complete (memory after: {memory_after:.1f}GB)")
|
||||||
|
|
||||||
|
    def get_effective_max_tokens(self, requested_tokens: Optional[int], interactive: bool = False) -> int:
        """Get effective max tokens based on model context and usage mode.

        Args:
            requested_tokens: The requested max tokens (None if user didn't specify --max-tokens)
            interactive: True if this is interactive mode (gets full context length)

        Returns:
            Effective max tokens to use
        """
        if not self._context_length:
            # Fallback when context length is unknown
            fallback = 4096 if interactive else 2048
            if self.verbose:
                if requested_tokens is None:
                    print(f"[WARNING] Model context length unknown, using fallback: {fallback} tokens")
                else:
                    print(f"[WARNING] Model context length unknown, using user specified: {requested_tokens} tokens")
            return requested_tokens if requested_tokens is not None else fallback

        if interactive:
            if requested_tokens is None:
                # User didn't specify --max-tokens: use full model context
                return self._context_length
            else:
                # User specified --max-tokens explicitly: respect their choice but cap at context
                return min(requested_tokens, self._context_length)
        else:
            # Server/batch mode uses half context length for DoS protection
            server_limit = self._context_length // 2
            return min(requested_tokens or server_limit, server_limit)
|
||||||
|
|
||||||
def generate_streaming(
|
def generate_streaming(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -209,6 +288,7 @@ class MLXRunner:
|
|||||||
repetition_context_size: int = 20,
|
repetition_context_size: int = 20,
|
||||||
use_chat_template: bool = True,
|
use_chat_template: bool = True,
|
||||||
use_chat_stop_tokens: bool = False,
|
use_chat_stop_tokens: bool = False,
|
||||||
|
interactive: bool = False,
|
||||||
) -> Iterator[str]:
|
) -> Iterator[str]:
|
||||||
"""Generate text with streaming output.
|
"""Generate text with streaming output.
|
||||||
|
|
||||||
@@ -221,6 +301,7 @@ class MLXRunner:
|
|||||||
repetition_context_size: Context size for repetition penalty
|
repetition_context_size: Context size for repetition penalty
|
||||||
use_chat_template: Apply tokenizer's chat template if available
|
use_chat_template: Apply tokenizer's chat template if available
|
||||||
use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode)
|
use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode)
|
||||||
|
interactive: True if this is interactive mode (affects token limits)
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generated tokens as they are produced
|
Generated tokens as they are produced
|
||||||
@@ -228,6 +309,9 @@ class MLXRunner:
|
|||||||
if not self.model or not self.tokenizer:
|
if not self.model or not self.tokenizer:
|
||||||
raise RuntimeError("Model not loaded. Call load_model() first.")
|
raise RuntimeError("Model not loaded. Call load_model() first.")
|
||||||
|
|
||||||
|
# Apply context-aware token limits
|
||||||
|
effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive)
|
||||||
|
|
||||||
# Apply chat template if available and requested
|
# Apply chat template if available and requested
|
||||||
if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
|
if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
|
||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
@@ -261,7 +345,7 @@ class MLXRunner:
|
|||||||
generator = generate_step(
|
generator = generate_step(
|
||||||
prompt=prompt_array,
|
prompt=prompt_array,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_tokens=max_tokens,
|
max_tokens=effective_max_tokens,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
logits_processors=logits_processors if logits_processors else None,
|
logits_processors=logits_processors if logits_processors else None,
|
||||||
)
|
)
|
||||||
@@ -367,6 +451,7 @@ class MLXRunner:
|
|||||||
repetition_penalty: float = 1.1,
|
repetition_penalty: float = 1.1,
|
||||||
repetition_context_size: int = 20,
|
repetition_context_size: int = 20,
|
||||||
use_chat_template: bool = True,
|
use_chat_template: bool = True,
|
||||||
|
interactive: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate text in batch mode (non-streaming).
|
"""Generate text in batch mode (non-streaming).
|
||||||
|
|
||||||
@@ -377,6 +462,8 @@ class MLXRunner:
|
|||||||
top_p: Top-p sampling parameter
|
top_p: Top-p sampling parameter
|
||||||
repetition_penalty: Penalty for repeated tokens
|
repetition_penalty: Penalty for repeated tokens
|
||||||
repetition_context_size: Context size for repetition penalty
|
repetition_context_size: Context size for repetition penalty
|
||||||
|
use_chat_template: Apply tokenizer's chat template if available
|
||||||
|
interactive: True if this is interactive mode (affects token limits)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generated text
|
Generated text
|
||||||
@@ -384,6 +471,9 @@ class MLXRunner:
|
|||||||
if not self.model or not self.tokenizer:
|
if not self.model or not self.tokenizer:
|
||||||
raise RuntimeError("Model not loaded. Call load_model() first.")
|
raise RuntimeError("Model not loaded. Call load_model() first.")
|
||||||
|
|
||||||
|
# Apply context-aware token limits
|
||||||
|
effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive)
|
||||||
|
|
||||||
# Apply chat template if available and requested
|
# Apply chat template if available and requested
|
||||||
if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
|
if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
|
||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
@@ -418,7 +508,7 @@ class MLXRunner:
|
|||||||
generator = generate_step(
|
generator = generate_step(
|
||||||
prompt=prompt_array,
|
prompt=prompt_array,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_tokens=max_tokens,
|
max_tokens=effective_max_tokens,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
logits_processors=logits_processors if logits_processors else None,
|
logits_processors=logits_processors if logits_processors else None,
|
||||||
)
|
)
|
||||||
@@ -506,6 +596,7 @@ class MLXRunner:
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode
|
use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode
|
||||||
|
interactive=True, # Enable full context length for interactive mode
|
||||||
):
|
):
|
||||||
print(token, end="", flush=True)
|
print(token, end="", flush=True)
|
||||||
response_tokens.append(token)
|
response_tokens.append(token)
|
||||||
|
|||||||
+18
-10
@@ -76,13 +76,9 @@ class ModelInfo(BaseModel):
|
|||||||
object: str = "model"
|
object: str = "model"
|
||||||
owned_by: str = "mlx-knife"
|
owned_by: str = "mlx-knife"
|
||||||
permission: List = []
|
permission: List = []
|
||||||
|
context_length: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
def get_effective_max_tokens(request_max_tokens: Optional[int]) -> int:
|
|
||||||
"""Get effective max_tokens value, using global default if not specified."""
|
|
||||||
global _default_max_tokens
|
|
||||||
return request_max_tokens if request_max_tokens is not None else _default_max_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner:
|
def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner:
|
||||||
"""Get model from cache or load it if not cached."""
|
"""Get model from cache or load it if not cached."""
|
||||||
@@ -155,7 +151,7 @@ async def generate_completion_stream(
|
|||||||
token_count = 0
|
token_count = 0
|
||||||
for token in runner.generate_streaming(
|
for token in runner.generate_streaming(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=get_effective_max_tokens(request.max_tokens),
|
max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False),
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
repetition_penalty=request.repetition_penalty,
|
repetition_penalty=request.repetition_penalty,
|
||||||
@@ -257,7 +253,7 @@ async def generate_chat_stream(
|
|||||||
try:
|
try:
|
||||||
for token in runner.generate_streaming(
|
for token in runner.generate_streaming(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=get_effective_max_tokens(request.max_tokens),
|
max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False),
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
repetition_penalty=request.repetition_penalty,
|
repetition_penalty=request.repetition_penalty,
|
||||||
@@ -393,10 +389,22 @@ async def list_models():
|
|||||||
framework = detect_framework(model_dir, model_name)
|
framework = detect_framework(model_dir, model_name)
|
||||||
|
|
||||||
if framework == "MLX" and is_model_healthy(model_name):
|
if framework == "MLX" and is_model_healthy(model_name):
|
||||||
|
# Get model context length
|
||||||
|
context_length = None
|
||||||
|
try:
|
||||||
|
from .cache_utils import get_model_path
|
||||||
|
from .mlx_runner import get_model_context_length
|
||||||
|
model_path_tuple = get_model_path(model_name)
|
||||||
|
if model_path_tuple and model_path_tuple[0]:
|
||||||
|
context_length = get_model_context_length(str(model_path_tuple[0]))
|
||||||
|
except Exception:
|
||||||
|
pass # Fallback to None if context length cannot be determined
|
||||||
|
|
||||||
model_list.append(ModelInfo(
|
model_list.append(ModelInfo(
|
||||||
id=model_name,
|
id=model_name,
|
||||||
object="model",
|
object="model",
|
||||||
owned_by="mlx-knife"
|
owned_by="mlx-knife",
|
||||||
|
context_length=context_length
|
||||||
))
|
))
|
||||||
|
|
||||||
return {"object": "list", "data": model_list}
|
return {"object": "list", "data": model_list}
|
||||||
@@ -430,7 +438,7 @@ async def create_completion(request: CompletionRequest):
|
|||||||
|
|
||||||
generated_text = runner.generate_batch(
|
generated_text = runner.generate_batch(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=get_effective_max_tokens(request.max_tokens),
|
max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False),
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
repetition_penalty=request.repetition_penalty,
|
repetition_penalty=request.repetition_penalty,
|
||||||
@@ -486,7 +494,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||||||
|
|
||||||
generated_text = runner.generate_batch(
|
generated_text = runner.generate_batch(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
max_tokens=get_effective_max_tokens(request.max_tokens),
|
max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False),
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
repetition_penalty=request.repetition_penalty,
|
repetition_penalty=request.repetition_penalty,
|
||||||
|
|||||||
+1
-1
@@ -13,7 +13,7 @@ authors = [
|
|||||||
{name = "The BROKE team", email = "broke@gmx.eu"},
|
{name = "The BROKE team", email = "broke@gmx.eu"},
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 5 - Production/Stable",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
|||||||
+48
-10
@@ -107,16 +107,21 @@
|
|||||||
color: white;
|
color: white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#clearButton:hover {
|
#clearButton:hover:not(:disabled) {
|
||||||
background: #5a6268;
|
background: #5a6268;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#clearButton:disabled {
|
||||||
|
background: #ccc;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
#sendButton {
|
#sendButton {
|
||||||
background: #007AFF;
|
background: #007AFF;
|
||||||
color: white;
|
color: white;
|
||||||
}
|
}
|
||||||
|
|
||||||
#sendButton:hover {
|
#sendButton:hover:not(:disabled) {
|
||||||
background: #0056b3;
|
background: #0056b3;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,6 +285,7 @@
|
|||||||
let currentModel = localStorage.getItem('mlxk_selected_model') || '';
|
let currentModel = localStorage.getItem('mlxk_selected_model') || '';
|
||||||
let isConnected = false;
|
let isConnected = false;
|
||||||
let conversationHistory = JSON.parse(localStorage.getItem('mlxk_chat_history') || '[]');
|
let conversationHistory = JSON.parse(localStorage.getItem('mlxk_chat_history') || '[]');
|
||||||
|
let isStreaming = false;
|
||||||
|
|
||||||
// DOM elements
|
// DOM elements
|
||||||
const statusEl = document.getElementById('status');
|
const statusEl = document.getElementById('status');
|
||||||
@@ -287,6 +293,7 @@
|
|||||||
const chatMessages = document.getElementById('chatMessages');
|
const chatMessages = document.getElementById('chatMessages');
|
||||||
const messageInput = document.getElementById('messageInput');
|
const messageInput = document.getElementById('messageInput');
|
||||||
const sendButton = document.getElementById('sendButton');
|
const sendButton = document.getElementById('sendButton');
|
||||||
|
const clearButton = document.getElementById('clearButton');
|
||||||
|
|
||||||
// Update status
|
// Update status
|
||||||
function updateStatus(connected, message = '') {
|
function updateStatus(connected, message = '') {
|
||||||
@@ -294,8 +301,9 @@
|
|||||||
statusEl.textContent = message || (connected ? 'Connected' : 'Disconnected');
|
statusEl.textContent = message || (connected ? 'Connected' : 'Disconnected');
|
||||||
statusEl.className = `status ${connected ? 'connected' : 'disconnected'}`;
|
statusEl.className = `status ${connected ? 'connected' : 'disconnected'}`;
|
||||||
|
|
||||||
messageInput.disabled = !connected || !currentModel;
|
messageInput.disabled = !connected || !currentModel || isStreaming;
|
||||||
sendButton.disabled = !connected || !currentModel;
|
sendButton.disabled = !connected || !currentModel || isStreaming;
|
||||||
|
clearButton.disabled = isStreaming || conversationHistory.length === 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check API health
|
// Check API health
|
||||||
@@ -321,6 +329,9 @@
|
|||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
modelSelect.innerHTML = '<option value="">Select a model...</option>';
|
modelSelect.innerHTML = '<option value="">Select a model...</option>';
|
||||||
|
|
||||||
|
// Store models data for context length lookup
|
||||||
|
window.modelsData = data.data;
|
||||||
|
|
||||||
data.data.forEach(model => {
|
data.data.forEach(model => {
|
||||||
const option = document.createElement('option');
|
const option = document.createElement('option');
|
||||||
option.value = model.id;
|
option.value = model.id;
|
||||||
@@ -332,13 +343,13 @@
|
|||||||
if (currentModel && data.data.some(m => m.id === currentModel)) {
|
if (currentModel && data.data.some(m => m.id === currentModel)) {
|
||||||
// Restore previously selected model
|
// Restore previously selected model
|
||||||
modelSelect.value = currentModel;
|
modelSelect.value = currentModel;
|
||||||
updateStatus(true, `Ready with ${currentModel.replace('mlx-community/', '')}`);
|
updateStatus(true, getModelDisplayText(currentModel));
|
||||||
} else if (data.data.length > 0) {
|
} else if (data.data.length > 0) {
|
||||||
// Auto-select first model if no valid stored model
|
// Auto-select first model if no valid stored model
|
||||||
currentModel = data.data[0].id;
|
currentModel = data.data[0].id;
|
||||||
modelSelect.value = currentModel;
|
modelSelect.value = currentModel;
|
||||||
localStorage.setItem('mlxk_selected_model', currentModel);
|
localStorage.setItem('mlxk_selected_model', currentModel);
|
||||||
updateStatus(true, `Ready with ${currentModel.replace('mlx-community/', '')}`);
|
updateStatus(true, getModelDisplayText(currentModel));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -346,6 +357,23 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update status with model info including token capacity
|
||||||
|
function getModelDisplayText(modelId) {
|
||||||
|
if (!modelId) return 'Select a model';
|
||||||
|
|
||||||
|
const modelName = modelId.replace('mlx-community/', '');
|
||||||
|
|
||||||
|
if (!window.modelsData) {
|
||||||
|
// Fallback before models data is loaded
|
||||||
|
return `Ready with ${modelName}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const model = window.modelsData.find(m => m.id === modelId);
|
||||||
|
const contextLength = model?.context_length || 4096;
|
||||||
|
|
||||||
|
return `Ready with ${modelName} (${contextLength.toLocaleString()} tokens)`;
|
||||||
|
}
|
||||||
|
|
||||||
// Add message to chat
|
// Add message to chat
|
||||||
function addMessage(content, isUser = false) {
|
function addMessage(content, isUser = false) {
|
||||||
const messageDiv = document.createElement('div');
|
const messageDiv = document.createElement('div');
|
||||||
@@ -370,7 +398,8 @@
|
|||||||
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
||||||
addMessage(message, true);
|
addMessage(message, true);
|
||||||
messageInput.value = '';
|
messageInput.value = '';
|
||||||
sendButton.disabled = true;
|
isStreaming = true;
|
||||||
|
updateStatus(isConnected, `Generating response with ${currentModel.replace('mlx-community/', '')}...`);
|
||||||
|
|
||||||
// Add typing indicator
|
// Add typing indicator
|
||||||
const typingDiv = addMessage('Typing...', false);
|
const typingDiv = addMessage('Typing...', false);
|
||||||
@@ -442,9 +471,10 @@
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
typingDiv.remove();
|
typingDiv.remove();
|
||||||
addMessage(`Error: ${error.message}`, false);
|
addMessage(`Error: ${error.message}`, false);
|
||||||
|
} finally {
|
||||||
|
isStreaming = false;
|
||||||
|
updateStatus(isConnected, currentModel ? getModelDisplayText(currentModel) : 'Select a model');
|
||||||
}
|
}
|
||||||
|
|
||||||
sendButton.disabled = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear conversation
|
// Clear conversation
|
||||||
@@ -453,6 +483,10 @@
|
|||||||
return; // Nothing to clear
|
return; // Nothing to clear
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isStreaming) {
|
||||||
|
return; // Don't allow clearing during streaming
|
||||||
|
}
|
||||||
|
|
||||||
showModal(
|
showModal(
|
||||||
'Clear Chat',
|
'Clear Chat',
|
||||||
'Are you sure you want to clear the entire chat history?',
|
'Are you sure you want to clear the entire chat history?',
|
||||||
@@ -462,6 +496,7 @@
|
|||||||
conversationHistory = [];
|
conversationHistory = [];
|
||||||
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
||||||
chatMessages.innerHTML = '<div class="message assistant-message"><strong>MLX Assistant:</strong> Hi! I\'m ready to chat. Select a model and send me a message!</div>';
|
chatMessages.innerHTML = '<div class="message assistant-message"><strong>MLX Assistant:</strong> Hi! I\'m ready to chat. Select a model and send me a message!</div>';
|
||||||
|
updateStatus(isConnected, currentModel ? getModelDisplayText(currentModel) : 'Select a model');
|
||||||
},
|
},
|
||||||
() => {
|
() => {
|
||||||
// Cancel - do nothing
|
// Cancel - do nothing
|
||||||
@@ -512,11 +547,12 @@
|
|||||||
conversationHistory = [];
|
conversationHistory = [];
|
||||||
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
localStorage.setItem('mlxk_chat_history', JSON.stringify(conversationHistory));
|
||||||
chatMessages.innerHTML = '<div class="message assistant-message"><strong>MLX Assistant:</strong> Hi! I\'m ready to chat with the new model!</div>';
|
chatMessages.innerHTML = '<div class="message assistant-message"><strong>MLX Assistant:</strong> Hi! I\'m ready to chat with the new model!</div>';
|
||||||
|
updateStatus(isConnected, currentModel ? getModelDisplayText(currentModel) : 'Select a model');
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
updateStatus(isConnected, currentModel ? `Ready with ${currentModel.replace('mlx-community/', '')}` : 'Select a model');
|
updateStatus(isConnected, currentModel ? getModelDisplayText(currentModel) : 'Select a model');
|
||||||
});
|
});
|
||||||
|
|
||||||
messageInput.addEventListener('keypress', (e) => {
|
messageInput.addEventListener('keypress', (e) => {
|
||||||
@@ -535,6 +571,8 @@
|
|||||||
addMessage(msg.content, msg.role === 'user');
|
addMessage(msg.content, msg.role === 'user');
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// Update button states after restoring
|
||||||
|
updateStatus(isConnected, currentModel ? getModelDisplayText(currentModel) : 'Select a model');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
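Review note: the status text built by `getModelDisplayText` can be mirrored in Python for anyone checking the `/v1/models` payload from the test side. This is a minimal sketch, assuming each model entry is a mapping with `id` and an optional `context_length`. The name `model_display_text` is hypothetical and is not part of the web client or of MLX Knife. Python's `:,` format always uses commas, whereas the browser's `toLocaleString()` depends on the user's locale.

```python
from typing import Any, Mapping, Optional, Sequence

DEFAULT_CONTEXT_LENGTH = 4096  # same fallback the client uses


def model_display_text(
    model_id: str,
    models: Optional[Sequence[Mapping[str, Any]]],
) -> str:
    """Build the header text shown for the selected model (hypothetical helper)."""
    if not model_id:
        return "Select a model"

    name = model_id.replace("mlx-community/", "")

    # Before the model list has been fetched there is no capacity to show.
    if models is None:
        return f"Ready with {name}"

    entry = next((m for m in models if m.get("id") == model_id), None)
    # A missing entry or a missing/zero field falls back to the default.
    context_length = (entry or {}).get("context_length") or DEFAULT_CONTEXT_LENGTH
    return f"Ready with {name} ({context_length:,} tokens)"
```

With a payload entry of `{"id": "mlx-community/Mistral-7B", "context_length": 32768}` this yields the header text quoted in the changelog.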
|
|||||||
@@ -0,0 +1,404 @@
|
|||||||
|
"""
|
||||||
|
Test for Issues #15 & #16: Dynamic Model-Aware Token Limits
|
||||||
|
|
||||||
|
Issue #15: Token-Limit vs Stop-Token Race Condition
|
||||||
|
- Models cut off by artificial token limits before natural stopping
|
||||||
|
- Solution: Context-aware token policies based on model capabilities
|
||||||
|
|
||||||
|
Issue #16: Interactive vs Server Token Limit Policies
|
||||||
|
- Interactive mode should allow unlimited tokens for natural completion
|
||||||
|
- Server mode needs DoS protection with reasonable limits
|
||||||
|
- Solution: Different token policies per usage context
|
||||||
|
|
||||||
|
This test is self-contained and manages its own MLX Knife server instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Realistic RAM requirements for 4-bit quantized models (in GB)
|
||||||
|
MODEL_RAM_REQUIREMENTS = {
|
||||||
|
"0.5B": 1, "1B": 2, "3B": 4, "4B": 5,
|
||||||
|
"7B": 8, "8x7B": 16, "24B": 20, "30B": 24,
|
||||||
|
"70B": 40, "480B": 180
|
||||||
|
}
|
||||||
|
|
||||||
|
SERVER_BASE_URL = "http://localhost:8001" # Different port to avoid conflicts
|
||||||
|
SERVER_PORT = 8001
|
||||||
|
|
||||||
|
|
||||||
|
def extract_model_size(model_name: str) -> str:
|
||||||
|
"""Extract model size from model name."""
|
||||||
|
# Match patterns like "30B", "8x7B", "480B", "0.5B", "3.2B", "Phi-3-mini" etc.
|
||||||
|
size_patterns = [
|
||||||
|
r'(\d+x\d+B)', # MoE models like "8x7B"
|
||||||
|
r'(\d+\.?\d*B)', # Standard like "30B", "0.5B", "3.2B"
|
||||||
|
r'(mini|small|medium|large)', # Qualitative sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in size_patterns:
|
||||||
|
match = re.search(pattern, model_name, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
size = match.group(1).lower()
|
||||||
|
# Map qualitative sizes to quantitative
|
||||||
|
if size == 'mini':
|
||||||
|
return '3B' # Phi-3-mini (~3.8B params) is mapped to the 3B RAM bucket
|
||||||
|
elif size == 'small':
|
||||||
|
return '1B'
|
||||||
|
elif size == 'medium':
|
||||||
|
return '7B'
|
||||||
|
elif size == 'large':
|
||||||
|
return '30B'
|
||||||
|
return size.upper()
|
||||||
|
|
||||||
|
return "3B" # Default fallback
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_ram_gb() -> int:
|
||||||
|
"""Get available system RAM in GB."""
|
||||||
|
try:
|
||||||
|
return int(psutil.virtual_memory().available / (1024**3))
|
||||||
|
except Exception:
|
||||||
|
return 8 # Conservative fallback
|
||||||
|
|
||||||
|
|
||||||
|
def get_suitable_models(available_models: List[str]) -> List[str]:
|
||||||
|
"""Filter models based on available RAM."""
|
||||||
|
available_ram = get_available_ram_gb()
|
||||||
|
logger.info(f"Available RAM: {available_ram}GB")
|
||||||
|
|
||||||
|
suitable = []
|
||||||
|
for model in available_models:
|
||||||
|
size = extract_model_size(model)
|
||||||
|
required_ram = MODEL_RAM_REQUIREMENTS.get(size, 8)
|
||||||
|
|
||||||
|
if required_ram <= available_ram:
|
||||||
|
suitable.append(model)
|
||||||
|
logger.info(f"✓ {model} ({size}, {required_ram}GB) - Suitable")
|
||||||
|
else:
|
||||||
|
logger.info(f"✗ {model} ({size}, {required_ram}GB) - Too large")
|
||||||
|
|
||||||
|
return suitable
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_models() -> List[str]:
|
||||||
|
"""Get list of cached MLX models."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["mlxk", "list", "--framework", "mlx"],
|
||||||
|
capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for line in result.stdout.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if line and not line.startswith('MODEL') and not line.startswith('NAME'):
|
||||||
|
# Extract model name from table format
|
||||||
|
parts = line.split()
|
||||||
|
if parts and parts[0] not in ('MODEL', 'NAME'):
|
||||||
|
models.append(parts[0])
|
||||||
|
|
||||||
|
return models
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get cached models: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def extract_context_length_from_model(model_name: str) -> int:
|
||||||
|
"""Extract context length from a real model's config."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["mlxk", "show", model_name, "--config"],
|
||||||
|
capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
# Extract JSON from the output (it comes after "Config:")
|
||||||
|
config_text = result.stdout
|
||||||
|
|
||||||
|
# Find the JSON part after "Config:"
|
||||||
|
config_start = config_text.find("Config:")
|
||||||
|
if config_start == -1:
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
json_text = config_text[config_start + 7:].strip() # Skip "Config:"
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = json.loads(json_text)
|
||||||
|
context_keys = [
|
||||||
|
"max_position_embeddings",
|
||||||
|
"n_positions",
|
||||||
|
"context_length",
|
||||||
|
"max_sequence_length",
|
||||||
|
"seq_len"
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in context_keys:
|
||||||
|
if key in config:
|
||||||
|
return config[key]
|
||||||
|
|
||||||
|
return 4096
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
|
||||||
|
class MLXKnifeServer:
|
||||||
|
"""Manages MLX Knife server lifecycle for testing."""
|
||||||
|
|
||||||
|
def __init__(self, port: int = SERVER_PORT):
|
||||||
|
self.port = port
|
||||||
|
self.process = None
|
||||||
|
self.base_url = f"http://localhost:{port}"
|
||||||
|
|
||||||
|
def start(self) -> bool:
|
||||||
|
"""Start the MLX Knife server."""
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"mlxk", "server",
|
||||||
|
"--host", "127.0.0.1",
|
||||||
|
"--port", str(self.port),
|
||||||
|
"--max-tokens", "1000", # Conservative default for testing
|
||||||
|
"--log-level", "warning"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
for attempt in range(30):
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{self.base_url}/v1/models", timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info(f"MLX Knife server started on port {self.port}")
|
||||||
|
return True
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.process.poll() is not None:
|
||||||
|
logger.error("Server process died during startup")
|
||||||
|
return False
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
logger.error("Server failed to start within timeout")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start server: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the MLX Knife server."""
|
||||||
|
if self.process:
|
||||||
|
try:
|
||||||
|
# Try graceful shutdown first
|
||||||
|
self.process.terminate()
|
||||||
|
try:
|
||||||
|
self.process.wait(timeout=10)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
# Force kill if not responding
|
||||||
|
self.process.kill()
|
||||||
|
self.process.wait(timeout=5)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping server: {e}")
|
||||||
|
finally:
|
||||||
|
self.process = None
|
||||||
|
|
||||||
|
def chat_completion(self, model: str, messages: List[Dict], max_tokens: int = None) -> Dict:
|
||||||
|
"""Send chat completion request."""
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": 0.3,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
payload["max_tokens"] = max_tokens
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/v1/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mlx_server():
|
||||||
|
"""Provide an MLX Knife server shared by all tests in this module."""
|
||||||
|
server = MLXKnifeServer()
|
||||||
|
|
||||||
|
if not server.start():
|
||||||
|
pytest.skip("Failed to start MLX Knife server")
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield server
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def available_models():
|
||||||
|
"""Get available models suitable for current system."""
|
||||||
|
all_models = get_cached_models()
|
||||||
|
if not all_models:
|
||||||
|
pytest.skip("No MLX models found in cache")
|
||||||
|
|
||||||
|
suitable = get_suitable_models(all_models)
|
||||||
|
if not suitable:
|
||||||
|
pytest.skip("No suitable models found for current RAM")
|
||||||
|
|
||||||
|
return suitable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.server
|
||||||
|
class TestIssue15TokenLimitVsStopTokenRace:
|
||||||
|
"""Test Issue #15: Token-Limit vs Stop-Token Race Condition Resolution."""
|
||||||
|
|
||||||
|
def test_model_context_length_extraction(self, available_models):
|
||||||
|
"""Test that we can extract context length from real models."""
|
||||||
|
model = available_models[0]
|
||||||
|
context_length = extract_context_length_from_model(model)
|
||||||
|
|
||||||
|
assert context_length >= 512, f"Context length too small for {model}: {context_length}"
|
||||||
|
assert context_length <= 1048576, f"Context length unrealistic for {model}: {context_length}" # 1M tokens max
|
||||||
|
|
||||||
|
logger.info(f"Model {model} has context length: {context_length}")
|
||||||
|
|
||||||
|
def test_realistic_token_limits_prevent_race_condition(self, mlx_server, available_models):
|
||||||
|
"""Test that realistic token limits prevent race conditions."""
|
||||||
|
model = available_models[0]
|
||||||
|
context_length = extract_context_length_from_model(model)
|
||||||
|
|
||||||
|
# Request a budget safely under the expected server limit (context/2), capped at 500 to keep the test fast
|
||||||
|
server_limit = context_length // 2
|
||||||
|
test_tokens = min(server_limit - 100, 500) # Conservative test
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Write a short story about a robot."}]
|
||||||
|
|
||||||
|
response = mlx_server.chat_completion(model, messages, max_tokens=test_tokens)
|
||||||
|
|
||||||
|
assert "choices" in response
|
||||||
|
assert len(response["choices"]) > 0
|
||||||
|
choice = response["choices"][0]
|
||||||
|
assert "message" in choice
|
||||||
|
assert "content" in choice["message"]
|
||||||
|
|
||||||
|
content = choice["message"]["content"]
|
||||||
|
assert len(content) > 0, "No content generated"
|
||||||
|
|
||||||
|
# The key test: model should generate reasonable content within limits
|
||||||
|
# without being cut off mid-sentence due to race conditions
|
||||||
|
logger.info(f"Generated {len(content)} characters with {test_tokens} token limit")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.server
|
||||||
|
class TestIssue16InteractiveVsServerTokenPolicies:
|
||||||
|
"""Test Issue #16: Interactive vs Server Token Limit Policies Resolution."""
|
||||||
|
|
||||||
|
def test_server_mode_uses_dos_protection_limits(self, mlx_server, available_models):
|
||||||
|
"""Test that server mode uses DoS protection (context/2)."""
|
||||||
|
model = available_models[0]
|
||||||
|
context_length = extract_context_length_from_model(model)
|
||||||
|
server_limit = context_length // 2
|
||||||
|
|
||||||
|
# Request server_limit + 200 tokens, capped at 800 to keep the test fast (only exceeds the limit for small-context models)
|
||||||
|
excessive_tokens = min(server_limit + 200, 800) # Keep reasonable for testing
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Write a brief summary of machine learning."}]
|
||||||
|
|
||||||
|
# This should work without errors - the server should internally
|
||||||
|
# limit tokens to the DoS protection limit
|
||||||
|
response = mlx_server.chat_completion(model, messages, max_tokens=excessive_tokens)
|
||||||
|
|
||||||
|
assert "choices" in response
|
||||||
|
assert len(response["choices"]) > 0
|
||||||
|
choice = response["choices"][0]
|
||||||
|
assert "message" in choice
|
||||||
|
assert "content" in choice["message"]
|
||||||
|
|
||||||
|
content = choice["message"]["content"]
|
||||||
|
assert len(content) > 0
|
||||||
|
|
||||||
|
# The response should be successful, proving the server handles
|
||||||
|
# excessive token requests gracefully
|
||||||
|
logger.info(f"Server handled excessive token request ({excessive_tokens}) gracefully")
|
||||||
|
logger.info(f"Model context: {context_length}, Server limit: {server_limit}, Generated content length: {len(content)}")
|
||||||
|
|
||||||
|
def test_server_honors_reasonable_token_requests(self, mlx_server, available_models):
|
||||||
|
"""Test that server honors reasonable token requests."""
|
||||||
|
model = available_models[0]
|
||||||
|
context_length = extract_context_length_from_model(model)
|
||||||
|
server_limit = context_length // 2
|
||||||
|
|
||||||
|
# Request reasonable number of tokens (well under limit)
|
||||||
|
reasonable_tokens = min(server_limit // 4, 200)
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Say hello."}]
|
||||||
|
|
||||||
|
response = mlx_server.chat_completion(model, messages, max_tokens=reasonable_tokens)
|
||||||
|
|
||||||
|
assert "choices" in response
|
||||||
|
assert len(response["choices"]) > 0
|
||||||
|
choice = response["choices"][0]
|
||||||
|
assert "message" in choice
|
||||||
|
assert "content" in choice["message"]
|
||||||
|
|
||||||
|
content = choice["message"]["content"]
|
||||||
|
assert len(content) > 0
|
||||||
|
assert "hello" in content.lower() or "hi" in content.lower()
|
||||||
|
|
||||||
|
logger.info(f"Server honored reasonable token request ({reasonable_tokens})")
|
||||||
|
|
||||||
|
def test_model_capabilities_vs_hardcoded_limits(self, available_models):
|
||||||
|
"""Test that models with different context lengths get appropriate limits."""
|
||||||
|
if len(available_models) < 2:
|
||||||
|
pytest.skip("Need multiple models to compare context lengths")
|
||||||
|
|
||||||
|
model_contexts = []
|
||||||
|
for model in available_models[:3]: # Test up to 3 models
|
||||||
|
context_length = extract_context_length_from_model(model)
|
||||||
|
model_contexts.append((model, context_length))
|
||||||
|
|
||||||
|
# Verify that different models have different context lengths
|
||||||
|
# (or at least our system recognizes their individual capabilities)
|
||||||
|
contexts = [ctx for _, ctx in model_contexts]
|
||||||
|
|
||||||
|
# At minimum, verify context extraction worked
|
||||||
|
for model, context in model_contexts:
|
||||||
|
assert context >= 1024, f"Model {model} context too small: {context}"
|
||||||
|
logger.info(f"Model {model}: {context} tokens context")
|
||||||
|
|
||||||
|
# The key insight: No hardcoded 500/2000 token limits!
|
||||||
|
# Each model gets limits based on its actual capabilities
|
||||||
|
for model, context in model_contexts:
|
||||||
|
server_limit = context // 2
|
||||||
|
# Server limits should be much higher than old hardcoded limits
|
||||||
|
# for models with large context windows
|
||||||
|
if context >= 4096:
|
||||||
|
assert server_limit >= 2048, f"Model {model} should have server limit >= 2048, got {server_limit}"
|
||||||
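Review note: the context-length lookup these tests rely on can be sketched as below. This is a minimal sketch under stated assumptions: the key order is the one listed in `extract_context_length_from_model`, and 4096 is the fallback used there. `context_length_from_config` and `read_context_length` are hypothetical names. This is not the `get_model_context_length` implementation from `mlx_knife.mlx_runner`.

```python
import json
from pathlib import Path
from typing import Any, Mapping, Union

# Keys checked in priority order (taken from the test helper above).
CONTEXT_KEYS = (
    "max_position_embeddings",
    "n_positions",
    "context_length",
    "max_sequence_length",
    "seq_len",
)
DEFAULT_CONTEXT_LENGTH = 4096


def context_length_from_config(config: Mapping[str, Any]) -> int:
    """Return the first positive integer found under a known key."""
    for key in CONTEXT_KEYS:
        value = config.get(key)
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, int) and not isinstance(value, bool) and value > 0:
            return value
    return DEFAULT_CONTEXT_LENGTH


def read_context_length(model_dir: Union[str, Path]) -> int:
    """Read config.json from a model directory, falling back on any error."""
    try:
        config = json.loads((Path(model_dir) / "config.json").read_text())
    except (OSError, ValueError):  # missing file, bad encoding, or invalid JSON
        return DEFAULT_CONTEXT_LENGTH
    if not isinstance(config, dict):
        return DEFAULT_CONTEXT_LENGTH
    return context_length_from_config(config)
```

Keeping the dictionary lookup separate from the file read lets the priority order be tested without touching the filesystem.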
@@ -334,4 +334,93 @@ class TestRunCommandErrorConditions:
|
|||||||
pytest.fail(f"Concurrent run command {i} hung")
|
pytest.fail(f"Concurrent run command {i} hung")
|
||||||
|
|
||||||
# Each should complete independently
|
# Each should complete independently
|
||||||
assert return_code is not None, f"Concurrent run {i} did not complete"
|
assert return_code is not None, f"Concurrent run {i} did not complete"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(60)
|
||||||
|
class TestRunCommandContextAwareLimits:
|
||||||
|
"""Test context-aware token limits in Issues #15 and #16 resolution."""
|
||||||
|
|
||||||
|
def test_context_length_extraction_from_real_model(self, mlx_knife_process, mock_model_cache):
|
||||||
|
"""Test that context length is correctly extracted from real model configs."""
|
||||||
|
# Create a mock model with realistic config.json
|
||||||
|
model_path = mock_model_cache("test-model", healthy=True)
|
||||||
|
|
||||||
|
# Add custom config.json with specific context length
|
||||||
|
config_content = {
|
||||||
|
"max_position_embeddings": 4096,
|
||||||
|
"hidden_size": 768,
|
||||||
|
"num_attention_heads": 12
|
||||||
|
}
|
||||||
|
import json
|
||||||
|
(model_path / "config.json").write_text(json.dumps(config_content))
|
||||||
|
|
||||||
|
# Test that the model context length is accessible
|
||||||
|
# This is an indirect test - we test that the run command uses model-aware limits
|
||||||
|
# by checking that it doesn't hang with realistic models
|
||||||
|
proc = mlx_knife_process([
|
||||||
|
"run", "test-model", "Test prompt",
|
||||||
|
"--max-tokens", "8000", # Request more than typical model context
|
||||||
|
"--verbose"
|
||||||
|
])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Should complete within timeout (won't actually generate due to mock)
|
||||||
|
return_code = proc.wait(timeout=30)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
pytest.fail("Run command hung with high max-tokens")
|
||||||
|
|
||||||
|
# Should complete (may fail due to mock model, but shouldn't hang)
|
||||||
|
assert return_code is not None, "Run command did not complete with context-aware limits"
|
||||||
|
|
||||||
|
def test_server_vs_interactive_token_policies(self, mock_model_cache):
|
||||||
|
"""Test that server mode uses DoS protection while interactive mode uses full context."""
|
||||||
|
# This test validates the architectural decision:
|
||||||
|
# - Server mode: context_length / 2 (DoS protection)
|
||||||
|
# - Interactive mode: full context_length
|
||||||
|
|
||||||
|
from mlx_knife.mlx_runner import MLXRunner, get_model_context_length
|
||||||
|
import tempfile
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Create a temporary model directory with config
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
config_path = os.path.join(temp_dir, "config.json")
|
||||||
|
config = {"max_position_embeddings": 4096}
|
||||||
|
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
# Test context length extraction
|
||||||
|
context_length = get_model_context_length(temp_dir)
|
||||||
|
assert context_length == 4096, "Context length extraction failed"
|
||||||
|
|
||||||
|
# Test MLXRunner effective token calculation
|
||||||
|
runner = MLXRunner(temp_dir, verbose=False)
|
||||||
|
runner._context_length = 4096
|
||||||
|
|
||||||
|
# Interactive mode should use full context
|
||||||
|
interactive_tokens = runner.get_effective_max_tokens(8000, interactive=True)
|
||||||
|
assert interactive_tokens == 4096, f"Interactive mode should use full context: {interactive_tokens}"
|
||||||
|
|
||||||
|
# Server mode should use half context (DoS protection)
|
||||||
|
server_tokens = runner.get_effective_max_tokens(8000, interactive=False)
|
||||||
|
assert server_tokens == 2048, f"Server mode should use half context: {server_tokens}"
|
||||||
|
|
||||||
|
# User requests smaller than limits should be honored
|
||||||
|
small_interactive = runner.get_effective_max_tokens(1000, interactive=True)
|
||||||
|
assert small_interactive == 1000, "Small requests should be honored in interactive mode"
|
||||||
|
|
||||||
|
small_server = runner.get_effective_max_tokens(1000, interactive=False)
|
||||||
|
assert small_server == 1000, "Small requests should be honored in server mode"
|
||||||
|
|
||||||
|
# Test None behavior (new CLI default=None logic)
|
||||||
|
# Interactive mode with None should use full context
|
||||||
|
none_interactive = runner.get_effective_max_tokens(None, interactive=True)
|
||||||
|
assert none_interactive == 4096, "None in interactive mode should use full context"
|
||||||
|
|
||||||
|
# Server mode with None should use server limit
|
||||||
|
none_server = runner.get_effective_max_tokens(None, interactive=False)
|
||||||
|
assert none_server == 2048, "None in server mode should use server limit (context/2)"
|
||||||
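Review note: the policy asserted by `test_server_vs_interactive_token_policies` reduces to a clamp against a mode-dependent ceiling. The following is a minimal sketch of that rule, assuming only what the assertions state (full context in interactive mode, half in server mode, smaller requests honored, `None` meaning "use the ceiling"). `effective_max_tokens` is a hypothetical free function, not the `MLXRunner.get_effective_max_tokens` method itself.

```python
from typing import Optional


def effective_max_tokens(
    requested: Optional[int],
    context_length: int,
    interactive: bool,
) -> int:
    """Clamp a requested token budget to a model-aware ceiling."""
    # Interactive sessions may use the whole context window; server mode
    # halves it as a guard against oversized requests.
    ceiling = context_length if interactive else context_length // 2

    # No explicit request: use the ceiling for the current mode.
    if requested is None:
        return ceiling

    # Explicit request: honor it unless it exceeds the ceiling.
    return min(requested, ceiling)
```

For the Qwen example in the changelog, a 262,144-token context gives a server ceiling of 131,072 and an interactive ceiling of 262,144 under this rule.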
@@ -1,9 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for MLXRunner memory management robustness.
|
Unit tests for MLXRunner memory management robustness and context length handling.
|
||||||
|
|
||||||
Tests context manager implementation, exception handling,
|
Tests context manager implementation, exception handling, cleanup guarantees,
|
||||||
and cleanup guarantees without requiring actual MLX models.
|
and model context length extraction without requiring actual MLX models.
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch, PropertyMock
|
from unittest.mock import MagicMock, patch, PropertyMock
|
||||||
import gc
|
import gc
|
||||||
@@ -291,5 +294,258 @@ class TestMLXRunnerMemoryManagement(unittest.TestCase):
|
|||||||
self.assertFalse(runner._model_loaded)
|
self.assertFalse(runner._model_loaded)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelContextLength(unittest.TestCase):
    """Test model context length extraction functionality."""

    def test_get_model_context_length_with_max_position_embeddings(self):
        """Test context length extraction from max_position_embeddings."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "max_position_embeddings": 4096,
                "hidden_size": 768,
                "num_attention_heads": 12
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 4096)

    def test_get_model_context_length_with_n_positions(self):
        """Test context length extraction from n_positions (GPT-style)."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "n_positions": 2048,
                "n_embd": 512,
                "n_head": 8
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 2048)

    def test_get_model_context_length_with_context_length(self):
        """Test context length extraction from context_length field."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "context_length": 8192,
                "hidden_size": 1024
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 8192)

    def test_get_model_context_length_with_max_sequence_length(self):
        """Test context length extraction from max_sequence_length."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "max_sequence_length": 32768,
                "d_model": 2048
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 32768)

    def test_get_model_context_length_with_seq_len(self):
        """Test context length extraction from seq_len field."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "seq_len": 16384,
                "embedding_size": 1536
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 16384)

    def test_get_model_context_length_priority_order(self):
        """Test that max_position_embeddings takes priority over other fields."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "max_position_embeddings": 4096,  # Should be used (first in priority)
                "n_positions": 2048,
                "context_length": 8192,
                "max_sequence_length": 16384,
                "seq_len": 1024
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 4096)

    def test_get_model_context_length_missing_config_file(self):
        """Test default context length when config.json is missing."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            # No config.json file created
            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 4096)  # Default fallback

    def test_get_model_context_length_invalid_json(self):
        """Test default context length when config.json is malformed."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")

            # Write invalid JSON
            with open(config_path, 'w') as f:
                f.write("{ invalid json content")

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 4096)  # Default fallback

    def test_get_model_context_length_empty_config(self):
        """Test default context length when config.json has no context fields."""
        from mlx_knife.mlx_runner import get_model_context_length

        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.json")
            config = {
                "hidden_size": 768,
                "num_attention_heads": 12,
                "model_type": "test_model"
            }

            with open(config_path, 'w') as f:
                json.dump(config, f)

            context_length = get_model_context_length(temp_dir)
            self.assertEqual(context_length, 4096)  # Default fallback


class TestMLXRunnerContextAwareLimits(unittest.TestCase):
    """Test MLXRunner context-aware token limits."""

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_interactive_mode(self, mock_get_context):
        """Test effective max tokens in interactive mode (uses full context)."""
        from mlx_knife.mlx_runner import MLXRunner

        mock_get_context.return_value = 4096

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = 4096

        # Interactive mode: should use full context length
        effective = runner.get_effective_max_tokens(8000, interactive=True)
        self.assertEqual(effective, 4096)  # Limited by model context

        effective = runner.get_effective_max_tokens(2000, interactive=True)
        self.assertEqual(effective, 2000)  # User request is smaller

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_server_mode(self, mock_get_context):
        """Test effective max tokens in server mode (uses half context for DoS protection)."""
        from mlx_knife.mlx_runner import MLXRunner

        mock_get_context.return_value = 4096

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = 4096

        # Server mode: should use half context length
        effective = runner.get_effective_max_tokens(8000, interactive=False)
        self.assertEqual(effective, 2048)  # Limited by server limit (4096 / 2)

        effective = runner.get_effective_max_tokens(1000, interactive=False)
        self.assertEqual(effective, 1000)  # User request is smaller

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_no_context_length(self, mock_get_context):
        """Test effective max tokens when context length is unknown."""
        from mlx_knife.mlx_runner import MLXRunner

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = None  # Context length unknown

        # Should fallback to requested tokens
        effective = runner.get_effective_max_tokens(1500, interactive=True)
        self.assertEqual(effective, 1500)

        effective = runner.get_effective_max_tokens(2500, interactive=False)
        self.assertEqual(effective, 2500)

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_none_interactive_mode(self, mock_get_context):
        """Test that None (no --max-tokens) uses full context in interactive mode."""
        from mlx_knife.mlx_runner import MLXRunner

        mock_get_context.return_value = 4096

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = 4096

        # None (user didn't specify --max-tokens) should use full context
        effective = runner.get_effective_max_tokens(None, interactive=True)
        self.assertEqual(effective, 4096)

        # Explicit values should still be respected
        effective = runner.get_effective_max_tokens(500, interactive=True)
        self.assertEqual(effective, 500)  # Now 500 is treated as explicit user choice

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_none_server_mode(self, mock_get_context):
        """Test that None uses server default in server mode."""
        from mlx_knife.mlx_runner import MLXRunner

        mock_get_context.return_value = 4096

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = 4096

        # None in server mode should use server limit (context / 2)
        effective = runner.get_effective_max_tokens(None, interactive=False)
        self.assertEqual(effective, 2048)  # 4096 / 2

    @patch('mlx_knife.mlx_runner.get_model_context_length')
    def test_get_effective_max_tokens_none_unknown_context(self, mock_get_context):
        """Test None behavior when context length is unknown."""
        from mlx_knife.mlx_runner import MLXRunner

        runner = MLXRunner("test_model", verbose=False)
        runner._context_length = None

        # Interactive mode: should use 4096 fallback when None
        effective = runner.get_effective_max_tokens(None, interactive=True)
        self.assertEqual(effective, 4096)

        # Server mode: should use 2048 fallback when None
        effective = runner.get_effective_max_tokens(None, interactive=False)
        self.assertEqual(effective, 2048)


if __name__ == '__main__':
    unittest.main()
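
The tests above pin down the observable behavior of `get_model_context_length` and `MLXRunner.get_effective_max_tokens` without showing either implementation. The following is a minimal sketch of logic that would satisfy those assertions; it is not the code from `mlx_knife.mlx_runner`. The names `sketch_context_length` and `sketch_effective_max_tokens` are hypothetical, and the lookup order after `max_position_embeddings` is an assumption, since the tests only confirm that this key wins over the others.

```python
import json
import os
import tempfile
from typing import Optional

# Assumed lookup order; the tests only confirm max_position_embeddings is first.
CONTEXT_KEYS = (
    "max_position_embeddings",
    "n_positions",
    "context_length",
    "max_sequence_length",
    "seq_len",
)
DEFAULT_CONTEXT = 4096  # fallback value asserted by the tests


def sketch_context_length(model_dir: str) -> int:
    """Hypothetical stand-in for get_model_context_length."""
    try:
        with open(os.path.join(model_dir, "config.json")) as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Missing or malformed config.json: use the default.
        return DEFAULT_CONTEXT
    if not isinstance(config, dict):
        return DEFAULT_CONTEXT
    for key in CONTEXT_KEYS:
        value = config.get(key)
        if isinstance(value, int):
            return value
    # config.json exists but carries no known context field.
    return DEFAULT_CONTEXT


def sketch_effective_max_tokens(
    requested: Optional[int],
    context_length: Optional[int],
    interactive: bool,
) -> int:
    """Hypothetical stand-in for MLXRunner.get_effective_max_tokens."""
    if context_length is None:
        # Unknown context: honor an explicit request, else fall back.
        if requested is not None:
            return requested
        return DEFAULT_CONTEXT if interactive else DEFAULT_CONTEXT // 2
    # Interactive mode may use the full context; server mode uses half.
    limit = context_length if interactive else context_length // 2
    return limit if requested is None else min(requested, limit)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, "config.json"), "w") as f:
            json.dump({"n_positions": 2048, "seq_len": 1024}, f)
        ctx = sketch_context_length(d)
    # Server mode with no explicit request: half of the detected context.
    print(ctx, sketch_effective_max_tokens(None, ctx, interactive=False))
```

Passing the context length as an argument, rather than reading `runner._context_length`, keeps the sketch free of any `MLXRunner` state; the real method takes only the requested value and the `interactive` flag, as the tests show.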